Browse Source

Rm const generics (#2)

* arith: get rid of constant generics. Reason:

using constant generics was great for allocating the arrays in the
stack, which is faster, but when started to use bigger parameter values,
in some cases it was overflowing the stack. This commit removes all the
constant generics in all of the `arith` crate, which in some cases slows
a bit the performance, but allows for bigger parameter values (on the
ones that affect lengths, like N and K).

* bfv: get rid of constant generics (reason in previous commit)

* ckks: get rid of constant generics (reason in two commits ago)

* group ring params under a single struct

* gfhe: get rid of constant generics

* tfhe: get rid of constant generics

* polish & clean a bit

* add methods for encoding constants for ct-pt-multiplication
main
arnaucube 2 months ago
committed by GitHub
parent
commit
fb1fb6b4e9
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
23 changed files with 2733 additions and 1883 deletions
  1. +22
    -17
      README.md
  2. +1
    -0
      arith/Cargo.toml
  3. +2
    -3
      arith/src/lib.rs
  4. +114
    -66
      arith/src/ntt.rs
  5. +187
    -0
      arith/src/ntt_fixedsize.rs
  6. +13
    -9
      arith/src/ring.rs
  7. +189
    -162
      arith/src/ring_n.rs
  8. +301
    -226
      arith/src/ring_nq.rs
  9. +178
    -99
      arith/src/ring_torus.rs
  10. +23
    -15
      arith/src/torus.rs
  11. +81
    -48
      arith/src/tuple_ring.rs
  12. +169
    -124
      arith/src/zq.rs
  13. +343
    -302
      bfv/src/lib.rs
  14. +17
    -14
      ckks/src/encoder.rs
  15. +98
    -79
      ckks/src/lib.rs
  16. +51
    -34
      gfhe/src/glev.rs
  17. +318
    -199
      gfhe/src/glwe.rs
  18. +2
    -0
      tfhe/src/lib.rs
  19. +75
    -58
      tfhe/src/tggsw.rs
  20. +183
    -141
      tfhe/src/tglwe.rs
  21. +67
    -61
      tfhe/src/tgsw.rs
  22. +71
    -45
      tfhe/src/tlev.rs
  23. +228
    -181
      tfhe/src/tlwe.rs

+ 22
- 17
README.md

@ -16,30 +16,34 @@ Implementations from scratch done while studying some FHE papers; do not use in
This example shows usage of TFHE, but the idea is that the same interface would This example shows usage of TFHE, but the idea is that the same interface would
work for using CKKS & BFV, the only thing to be changed would be the parameters work for using CKKS & BFV, the only thing to be changed would be the parameters
and the line `type S = TWLE<K>` to use `CKKS<Q, N>` or `BFV<Q, N, T>`.
and the usage of `TLWE` by `CKKS` or `BFV`.
```rust ```rust
const T: u64 = 128; // msg space (msg modulus)
type M = Rq<T, 1>; // msg space
type S = TLWE<256>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 256,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
// get two random msgs in Z_t
let m1 = M::rand_u64(&mut rng, msg_dist)?;
let m2 = M::rand_u64(&mut rng, msg_dist)?;
let m3 = M::rand_u64(&mut rng, msg_dist)?;
// get three random msgs in Rt
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m3 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
// encode the msgs into the plaintext space // encode the msgs into the plaintext space
let p1 = S::encode::<T>(&m1); // plaintext
let p2 = S::encode::<T>(&m2); // plaintext
let c3_const: Tn<1> = Tn(array::from_fn(|i| T64(m3.coeffs()[i].0))); // encode it as constant
let p1 = TLWE::encode(&param, &m1); // plaintext
let p2 = TLWE::encode(&param, &m2); // plaintext
let c3_const = TLWE::new_const(&param, &m3); // as constant/public value
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
// encrypt p1 and m2
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = TLWE::encrypt(&mut rng, &param, &pk, &p2)?;
// now we can do encrypted operations (notice that we do them using simple // now we can do encrypted operations (notice that we do them using simple
// operation notation by rust's operator overloading): // operation notation by rust's operator overloading):
@ -48,9 +52,10 @@ let c4 = c_12 * c3_const;
// decrypt & decode // decrypt & decode
let p4_recovered = c4.decrypt(&sk); let p4_recovered = c4.decrypt(&sk);
let m4 = S::decode::<T>(&p4_recovered);
let m4 = TLWE::decode(&param, &p4_recovered);
// m4 is equal to (m1+m2)*m3 // m4 is equal to (m1+m2)*m3
assert_eq!(((m1 + m2).to_r() * m3.to_r()).to_rq(param.t), m4);
``` ```
@ -62,7 +67,7 @@ let m4 = S::decode::(&p4_recovered);
- external products of ciphertexts - external products of ciphertexts
- TGSW x TLWE - TGSW x TLWE
- TGGSW x TGLWE - TGGSW x TGLWE
- TGSW & TGGSW CMux gate
- {TGSW, TGGSW} CMux gate
- blind rotation, key switching, mod switching - blind rotation, key switching, mod switching
- bootstrapping - bootstrapping
- CKKS - CKKS

+ 1
- 0
arith/Cargo.toml

@ -8,6 +8,7 @@ anyhow = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
rand_distr = { workspace = true } rand_distr = { workspace = true }
itertools = { workspace = true } itertools = { workspace = true }
lazy_static = "1.5.0"
# TMP: the next 4 imports are TMP, to solve systems of linear equations. Used # TMP: the next 4 imports are TMP, to solve systems of linear equations. Used
# for the CKKS encoding step, probably remvoed once in ckks the encoding is done # for the CKKS encoding step, probably remvoed once in ckks the encoding is done

+ 2
- 3
arith/src/lib.rs

@ -15,17 +15,16 @@ pub mod ring_nq;
pub mod ring_torus; pub mod ring_torus;
pub mod tuple_ring; pub mod tuple_ring;
mod naive_ntt; // note: for dev only
// mod naive_ntt; // note: for dev only
pub mod ntt; pub mod ntt;
// expose objects // expose objects
pub use complex::C; pub use complex::C;
pub use matrix::Matrix; pub use matrix::Matrix;
pub use torus::T64; pub use torus::T64;
pub use zq::Zq; pub use zq::Zq;
pub use ring::Ring;
pub use ring::{Ring, RingParam};
pub use ring_n::R; pub use ring_n::R;
pub use ring_nq::Rq; pub use ring_nq::Rq;
pub use ring_torus::Tn; pub use ring_torus::Tn;

+ 114
- 66
arith/src/ntt.rs

@ -1,34 +1,60 @@
//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in //! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in
//! https://eprint.iacr.org/2017/727.pdf, some notes at //! https://eprint.iacr.org/2017/727.pdf, some notes at
//! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf . //! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
use crate::zq::Zq;
//!
//! NOTE: initially I implemented it with fixed Q & N, given as constant
//! generics; but once using real-world parameters, the stack could not handle
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
//! implementation to that too.
use crate::{ring::RingParam, ring_nq::Rq, zq::Zq};
use std::collections::HashMap;
#[derive(Debug)] #[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {}
impl<const Q: u64, const N: usize> NTT<Q, N> {
const N_INV: Zq<Q> = Zq(const_inv_mod::<Q>(N as u64));
// since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity
pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::<Q>(2 * N);
pub(crate) const ROOTS_OF_UNITY: [Zq<Q>; N] = roots_of_unity(Self::ROOT_OF_UNITY);
const ROOTS_OF_UNITY_INV: [Zq<Q>; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY);
pub struct NTT {}
use std::sync::{Mutex, OnceLock};
static CACHE: OnceLock<Mutex<HashMap<(u64, usize), (Vec<Zq>, Vec<Zq>, Zq)>>> = OnceLock::new();
fn roots(q: u64, n: usize) -> (Vec<Zq>, Vec<Zq>, Zq) {
let cache_lock = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
let mut cache = cache_lock.lock().unwrap();
if let Some(value) = cache.get(&(q, n)) {
return value.clone();
}
let n_inv: Zq = Zq {
q,
v: const_inv_mod(q, n as u64),
};
let root_of_unity: u64 = primitive_root_of_unity(q, 2 * n);
let roots_of_unity: Vec<Zq> = roots_of_unity(q, n, root_of_unity);
let roots_of_unity_inv: Vec<Zq> = roots_of_unity_inv(q, n, roots_of_unity.clone());
let value = (roots_of_unity, roots_of_unity_inv, n_inv);
cache.insert((q, n), value.clone());
value
} }
impl<const Q: u64, const N: usize> NTT<Q, N> {
impl NTT {
/// implements the Cooley-Tukey (CT) algorithm. Details at /// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn ntt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = N / 2;
pub fn ntt(a: &Rq) -> Rq {
let (q, n) = (a.param.q, a.param.n);
let (roots_of_unity, _, _) = roots(q, n);
let mut t = n / 2;
let mut m = 1; let mut m = 1;
let mut r: [Zq<Q>; N] = a.clone();
while m < N {
let mut r: Vec<Zq> = a.coeffs.clone();
while m < n {
let mut k = 0; let mut k = 0;
for i in 0..m { for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY[m + i];
let S: Zq = roots_of_unity[m + i];
for j in k..k + t { for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t] * S;
let U: Zq = r[j];
let V: Zq = r[j + t] * S;
r[j] = U + V; r[j] = U + V;
r[j + t] = U - V; r[j + t] = U - V;
} }
@ -37,23 +63,32 @@ impl NTT {
t /= 2; t /= 2;
m *= 2; m *= 2;
} }
r
// TODO think if maybe not return a Rq type, or if returned Rq, maybe
// fill the `evals` field, which is what we're actually returning here
Rq {
param: RingParam { q, n },
coeffs: r,
evals: None,
}
} }
/// implements the Cooley-Tukey (CT) algorithm. Details at /// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn intt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
pub fn intt(a: &Rq) -> Rq {
let (q, n) = (a.param.q, a.param.n);
let (_, roots_of_unity_inv, n_inv) = roots(q, n);
let mut t = 1; let mut t = 1;
let mut m = N / 2;
let mut r: [Zq<Q>; N] = a.clone();
let mut m = n / 2;
let mut r: Vec<Zq> = a.coeffs.clone();
while m > 0 { while m > 0 {
let mut k = 0; let mut k = 0;
for i in 0..m { for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY_INV[m + i];
let S: Zq = roots_of_unity_inv[m + i];
for j in k..k + t { for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t];
let U: Zq = r[j];
let V: Zq = r[j + t];
r[j] = U + V; r[j] = U + V;
r[j + t] = (U - V) * S; r[j + t] = (U - V) * S;
} }
@ -62,26 +97,32 @@ impl NTT {
t *= 2; t *= 2;
m /= 2; m /= 2;
} }
for i in 0..N {
r[i] = r[i] * Self::N_INV;
for i in 0..n {
r[i] = r[i] * n_inv;
}
Rq {
param: RingParam { q, n },
coeffs: r,
// TODO maybe at `evals` place the inputed `a` which is the evals
// format
evals: None,
} }
r
} }
} }
/// computes a primitive N-th root of unity using the method described by Thomas /// computes a primitive N-th root of unity using the method described by Thomas
/// Pornin in https://crypto.stackexchange.com/a/63616 /// Pornin in https://crypto.stackexchange.com/a/63616
const fn primitive_root_of_unity<const Q: u64>(N: usize) -> u64 {
assert!(N.is_power_of_two());
assert!((Q - 1) % N as u64 == 0);
const fn primitive_root_of_unity(q: u64, n: usize) -> u64 {
assert!(n.is_power_of_two());
assert!((q - 1) % n as u64 == 0);
let n_u64 = n as u64;
let n: u64 = N as u64;
let mut k = 1; let mut k = 1;
while k < Q {
while k < q {
// alternatively could get a random k at each iteration, if so, add the following if: // alternatively could get a random k at each iteration, if so, add the following if:
// `if k == 0 { continue; }` // `if k == 0 { continue; }`
let w = const_exp_mod::<Q>(k, (Q - 1) / n);
if const_exp_mod::<Q>(w, n / 2) != 1 {
let w = const_exp_mod(q, k, (q - 1) / n_u64);
if const_exp_mod(q, w, n_u64 / 2) != 1 {
return w; // w is a primitive N-th root of unity return w; // w is a primitive N-th root of unity
} }
k += 1; k += 1;
@ -89,78 +130,85 @@ const fn primitive_root_of_unity(N: usize) -> u64 {
panic!("No primitive root of unity"); panic!("No primitive root of unity");
} }
const fn roots_of_unity<const Q: u64, const N: usize>(w: u64) -> [Zq<Q>; N] {
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec<Zq> {
let mut r: Vec<Zq> = vec![Zq { q, v: 0 }; n];
let mut i = 0; let mut i = 0;
let log_n = N.ilog2();
while i < N {
let log_n = n.ilog2();
while i < n {
// (return the roots in bit-reverset order) // (return the roots in bit-reverset order)
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize; let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
r[i] = Zq(const_exp_mod::<Q>(w, j as u64));
r[i] = Zq {
q,
v: const_exp_mod(q, w, j as u64),
};
i += 1; i += 1;
} }
r r
} }
const fn roots_of_unity_inv<const Q: u64, const N: usize>(v: [Zq<Q>; N]) -> [Zq<Q>; N] {
fn roots_of_unity_inv(q: u64, n: usize, v: Vec<Zq>) -> Vec<Zq> {
// assumes that the inputted roots are already in bit-reverset order // assumes that the inputted roots are already in bit-reverset order
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut r: Vec<Zq> = vec![Zq { q, v: 0 }; n];
let mut i = 0; let mut i = 0;
while i < N {
r[i] = Zq(const_inv_mod::<Q>(v[i].0));
while i < n {
r[i] = Zq {
q,
v: const_inv_mod(q, v[i].v),
};
i += 1; i += 1;
} }
r r
} }
/// returns x^k mod Q /// returns x^k mod Q
const fn const_exp_mod<const Q: u64>(x: u64, k: u64) -> u64 {
const fn const_exp_mod(q: u64, x: u64, k: u64) -> u64 {
// work on u128 to avoid overflow // work on u128 to avoid overflow
let mut r = 1u128; let mut r = 1u128;
let mut x = x as u128; let mut x = x as u128;
let mut k = k as u128; let mut k = k as u128;
x = x % Q as u128;
x = x % q as u128;
// exponentiation by square strategy // exponentiation by square strategy
while k > 0 { while k > 0 {
if k % 2 == 1 { if k % 2 == 1 {
r = (r * x) % Q as u128;
r = (r * x) % q as u128;
} }
x = (x * x) % Q as u128;
x = (x * x) % q as u128;
k /= 2; k /= 2;
} }
r as u64 r as u64
} }
/// returns x^-1 mod Q /// returns x^-1 mod Q
const fn const_inv_mod<const Q: u64>(x: u64) -> u64 {
const fn const_inv_mod(q: u64, x: u64) -> u64 {
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
const_exp_mod::<Q>(x, Q - 2)
const_exp_mod(q, x, q - 2)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::Ring;
use anyhow::Result; use anyhow::Result;
use std::array;
#[test] #[test]
fn test_ntt() -> Result<()> { fn test_ntt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 4;
let param = RingParam { q, n };
let a: [u64; N] = [1u64, 2, 3, 4];
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let a: Vec<u64> = vec![1u64, 2, 3, 4];
let a: Rq = Rq::from_vec_u64(&param, a);
let a_ntt = NTT::<Q, N>::ntt(a);
let a_ntt = NTT::ntt(&a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
let a_intt = NTT::intt(&a_ntt);
dbg!(&a); dbg!(&a);
dbg!(&a_ntt); dbg!(&a_ntt);
dbg!(&a_intt); dbg!(&a_intt);
dbg!(NTT::<Q, N>::ROOT_OF_UNITY);
dbg!(NTT::<Q, N>::ROOTS_OF_UNITY);
// dbg!(NTT::ROOT_OF_UNITY);
// dbg!(NTT::ROOTS_OF_UNITY);
assert_eq!(a, a_intt); assert_eq!(a, a_intt);
Ok(()) Ok(())
@ -168,18 +216,18 @@ mod tests {
#[test] #[test]
fn test_ntt_loop() -> Result<()> { fn test_ntt_loop() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 512;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 512;
let param = RingParam { q, n };
use rand::distributions::Distribution;
use rand::distributions::Uniform; use rand::distributions::Uniform;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let dist = Uniform::new(0_f64, Q as f64);
let dist = Uniform::new(0_f64, q as f64);
for _ in 0..100 {
let a: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
for _ in 0..1000 {
let a: Rq = Rq::rand(&mut rng, dist, &param);
let a_ntt = NTT::ntt(&a);
let a_intt = NTT::intt(&a_ntt);
assert_eq!(a, a_intt); assert_eq!(a, a_intt);
} }
Ok(()) Ok(())

+ 187
- 0
arith/src/ntt_fixedsize.rs

@ -0,0 +1,187 @@
//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in
//! https://eprint.iacr.org/2017/727.pdf, some notes at
//! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
use crate::zq::Zq;
#[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {}
impl<const Q: u64, const N: usize> NTT<Q, N> {
const N_INV: Zq<Q> = Zq(const_inv_mod::<Q>(N as u64));
// since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity
pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::<Q>(2 * N);
pub(crate) const ROOTS_OF_UNITY: [Zq<Q>; N] = roots_of_unity(Self::ROOT_OF_UNITY);
const ROOTS_OF_UNITY_INV: [Zq<Q>; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY);
}
impl<const Q: u64, const N: usize> NTT<Q, N> {
/// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn ntt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = N / 2;
let mut m = 1;
let mut r: [Zq<Q>; N] = a.clone();
while m < N {
let mut k = 0;
for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY[m + i];
for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t] * S;
r[j] = U + V;
r[j + t] = U - V;
}
k = k + 2 * t;
}
t /= 2;
m *= 2;
}
r
}
/// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn intt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = 1;
let mut m = N / 2;
let mut r: [Zq<Q>; N] = a.clone();
while m > 0 {
let mut k = 0;
for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY_INV[m + i];
for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t];
r[j] = U + V;
r[j + t] = (U - V) * S;
}
k += 2 * t;
}
t *= 2;
m /= 2;
}
for i in 0..N {
r[i] = r[i] * Self::N_INV;
}
r
}
}
/// computes a primitive N-th root of unity using the method described by Thomas
/// Pornin in https://crypto.stackexchange.com/a/63616
const fn primitive_root_of_unity<const Q: u64>(N: usize) -> u64 {
assert!(N.is_power_of_two());
assert!((Q - 1) % N as u64 == 0);
let n: u64 = N as u64;
let mut k = 1;
while k < Q {
// alternatively could get a random k at each iteration, if so, add the following if:
// `if k == 0 { continue; }`
let w = const_exp_mod::<Q>(k, (Q - 1) / n);
if const_exp_mod::<Q>(w, n / 2) != 1 {
return w; // w is a primitive N-th root of unity
}
k += 1;
}
panic!("No primitive root of unity");
}
const fn roots_of_unity<const Q: u64, const N: usize>(w: u64) -> [Zq<Q>; N] {
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut i = 0;
let log_n = N.ilog2();
while i < N {
// (return the roots in bit-reverset order)
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
r[i] = Zq(const_exp_mod::<Q>(w, j as u64));
i += 1;
}
r
}
const fn roots_of_unity_inv<const Q: u64, const N: usize>(v: [Zq<Q>; N]) -> [Zq<Q>; N] {
// assumes that the inputted roots are already in bit-reverset order
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut i = 0;
while i < N {
r[i] = Zq(const_inv_mod::<Q>(v[i].0));
i += 1;
}
r
}
/// returns x^k mod Q
const fn const_exp_mod<const Q: u64>(x: u64, k: u64) -> u64 {
// work on u128 to avoid overflow
let mut r = 1u128;
let mut x = x as u128;
let mut k = k as u128;
x = x % Q as u128;
// exponentiation by square strategy
while k > 0 {
if k % 2 == 1 {
r = (r * x) % Q as u128;
}
x = (x * x) % Q as u128;
k /= 2;
}
r as u64
}
/// returns x^-1 mod Q
const fn const_inv_mod<const Q: u64>(x: u64) -> u64 {
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
const_exp_mod::<Q>(x, Q - 2)
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use std::array;
#[test]
fn test_ntt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let a: [u64; N] = [1u64, 2, 3, 4];
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
dbg!(&a);
dbg!(&a_ntt);
dbg!(&a_intt);
dbg!(NTT::<Q, N>::ROOT_OF_UNITY);
dbg!(NTT::<Q, N>::ROOTS_OF_UNITY);
assert_eq!(a, a_intt);
Ok(())
}
#[test]
fn test_ntt_loop() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 512;
use rand::distributions::Distribution;
use rand::distributions::Uniform;
let mut rng = rand::thread_rng();
let dist = Uniform::new(0_f64, Q as f64);
for _ in 0..100 {
let a: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
assert_eq!(a, a_intt);
}
Ok(())
}
}

+ 13
- 9
arith/src/ring.rs

@ -3,6 +3,12 @@ use std::fmt::Debug;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RingParam {
pub q: u64, // TODO think if really needed or it's fine with coeffs[0].q
pub n: usize,
}
/// Represents a ring element. Currently implemented by ring_nq.rs#Rq and /// Represents a ring element. Currently implemented by ring_nq.rs#Rq and
/// ring_torus.rs#Tn. Is not a 'pure algebraic ring', but more a custom trait /// ring_torus.rs#Tn. Is not a 'pure algebraic ring', but more a custom trait
/// definition which includes methods like `mod_switch`. /// definition which includes methods like `mod_switch`.
@ -21,27 +27,25 @@ pub trait Ring:
+ PartialEq + PartialEq
+ Debug + Debug
+ Clone + Clone
+ Copy
// + Copy
+ Sum<<Self as Add>::Output> + Sum<<Self as Add>::Output>
+ Sum<<Self as Mul>::Output> + Sum<<Self as Mul>::Output>
{ {
/// C defines the coefficient type /// C defines the coefficient type
type C: Debug + Clone; type C: Debug + Clone;
const Q: u64;
const N: usize;
fn param(&self) -> RingParam;
fn coeffs(&self) -> Vec<Self::C>; fn coeffs(&self) -> Vec<Self::C>;
fn zero() -> Self;
fn zero(param: &RingParam) -> Self;
// note/wip/warning: dist (0,q) with f64, will output more '0=q' elements than other values // note/wip/warning: dist (0,q) with f64, will output more '0=q' elements than other values
fn rand(rng: impl Rng, dist: impl Distribution<f64>) -> Self;
fn rand(rng: impl Rng, dist: impl Distribution<f64>, param: &RingParam) -> Self;
fn from_vec(coeffs: Vec<Self::C>) -> Self;
fn from_vec(param: &RingParam, coeffs: Vec<Self::C>) -> Self;
fn decompose(&self, beta: u32, l: u32) -> Vec<Self>; fn decompose(&self, beta: u32, l: u32) -> Vec<Self>;
fn remodule<const P: u64>(&self) -> impl Ring;
fn mod_switch<const P: u64>(&self) -> impl Ring;
fn remodule(&self, p:u64) -> impl Ring;
fn mod_switch(&self, p:u64) -> impl Ring;
/// returns [ [(num/den) * self].round() ] mod q /// returns [ [(num/den) * self].round() ] mod q
/// ie. performs the multiplication and division over f64, and then it /// ie. performs the multiplication and division over f64, and then it

+ 189
- 162
arith/src/ring_n.rs

@ -1,45 +1,43 @@
//! Polynomial ring Z[X]/(X^N+1) //! Polynomial ring Z[X]/(X^N+1)
//! //!
use anyhow::Result;
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array;
use std::fmt; use std::fmt;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::Ring;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
// TODO rename to not have name conflicts with the Ring trait (R: Ring) // TODO rename to not have name conflicts with the Ring trait (R: Ring)
// PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1) // PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1)
#[derive(Clone, Copy)]
pub struct R<const N: usize>(pub [i64; N]);
// impl<const N: usize> Ring for R<N> {
impl<const N: usize> R<N> {
// type C = i64;
// const Q: u64 = i64::MAX as u64; // WIP
// const N: usize = N;
#[derive(Clone)]
pub struct R {
pub n: usize,
pub coeffs: Vec<i64>,
}
impl R {
pub fn coeffs(&self) -> Vec<i64> { pub fn coeffs(&self) -> Vec<i64> {
self.0.to_vec()
self.coeffs.clone()
} }
fn zero() -> Self {
let coeffs: [i64; N] = array::from_fn(|_| 0i64);
Self(coeffs)
fn zero(n: usize) -> Self {
Self {
n,
coeffs: vec![0i64; n],
}
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
// let coeffs: [i64; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist));
let coeffs: [i64; N] = array::from_fn(|_| dist.sample(&mut rng).round() as i64);
Self(coeffs)
// let coeffs: [C; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
// Self(coeffs)
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, n: usize) -> Self {
Self {
n,
coeffs: std::iter::repeat_with(|| dist.sample(&mut rng).round() as i64)
.take(n)
.collect(),
}
} }
pub fn from_vec(coeffs: Vec<i64>) -> Self {
pub fn from_vec(n: usize, coeffs: Vec<i64>) -> Self {
let mut p = coeffs; let mut p = coeffs;
modulus::<N>(&mut p);
Self(array::from_fn(|i| p[i]))
modulus(n, &mut p);
Self { n, coeffs: p }
} }
/* /*
@ -71,34 +69,38 @@ impl R {
*/ */
} }
impl<const Q: u64, const N: usize> From<crate::ring_nq::Rq<Q, N>> for R<N> {
fn from(rq: crate::ring_nq::Rq<Q, N>) -> Self {
Self::from_vec_u64(rq.coeffs().to_vec().iter().map(|e| e.0).collect())
impl From<crate::ring_nq::Rq> for R {
fn from(rq: crate::ring_nq::Rq) -> Self {
Self::from_vec_u64(
rq.param.n,
rq.coeffs().to_vec().iter().map(|e| e.v).collect(),
)
} }
} }
impl<const N: usize> R<N> {
// pub fn coeffs(&self) -> [i64; N] {
// self.0
// }
pub fn to_rq<const Q: u64>(self) -> crate::Rq<Q, N> {
crate::Rq::<Q, N>::from(self)
impl R {
pub fn to_rq(self, q: u64) -> crate::Rq {
crate::Rq::from((q, self))
} }
// this method is mostly for tests // this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
pub fn from_vec_u64(n: usize, coeffs: Vec<u64>) -> Self {
let coeffs_i64 = coeffs.iter().map(|c| *c as i64).collect(); let coeffs_i64 = coeffs.iter().map(|c| *c as i64).collect();
Self::from_vec(coeffs_i64)
Self::from_vec(n, coeffs_i64)
} }
pub fn from_vec_f64(coeffs: Vec<f64>) -> Self {
pub fn from_vec_f64(n: usize, coeffs: Vec<f64>) -> Self {
let coeffs_i64 = coeffs.iter().map(|c| c.round() as i64).collect(); let coeffs_i64 = coeffs.iter().map(|c| c.round() as i64).collect();
Self::from_vec(coeffs_i64)
Self::from_vec(n, coeffs_i64)
} }
pub fn new(coeffs: [i64; N]) -> Self {
Self(coeffs)
pub fn new(n: usize, coeffs: Vec<i64>) -> Self {
assert_eq!(n, coeffs.len());
Self { n, coeffs }
} }
pub fn mul_by_i64(&self, s: i64) -> Self { pub fn mul_by_i64(&self, s: i64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
Self {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| c_i * s).collect(),
}
} }
pub fn infinity_norm(&self) -> u64 { pub fn infinity_norm(&self) -> u64 {
@ -108,10 +110,10 @@ impl R {
.map(|x| x.abs() as u64) .map(|x| x.abs() as u64)
.fold(0, |a, b| a.max(b)) .fold(0, |a, b| a.max(b))
} }
pub fn mod_centered_q<const Q: u64>(&self) -> R<N> {
let q = Q as i64;
pub fn mod_centered_q(&self, q: u64) -> R {
let q = q as i64;
let r = self let r = self
.0
.coeffs
.iter() .iter()
.map(|v| { .map(|v| {
let mut res = v % q; let mut res = v % q;
@ -121,190 +123,216 @@ impl R {
res res
}) })
.collect::<Vec<i64>>(); .collect::<Vec<i64>>();
R::<N>::from_vec(r)
R::from_vec(self.n, r)
} }
} }
pub fn mul_div_round<const Q: u64, const N: usize>(
v: Vec<i64>,
num: u64,
den: u64,
) -> crate::Rq<Q, N> {
pub fn mul_div_round(q: u64, n: usize, v: Vec<i64>, num: u64, den: u64) -> crate::Rq {
// dbg!(&v); // dbg!(&v);
let r: Vec<f64> = v let r: Vec<f64> = v
.iter() .iter()
.map(|e| ((num as f64 * *e as f64) / den as f64).round()) .map(|e| ((num as f64 * *e as f64) / den as f64).round())
.collect(); .collect();
// dbg!(&r); // dbg!(&r);
crate::Rq::<Q, N>::from_vec_f64(r)
crate::Rq::from_vec_f64(&crate::ring::RingParam { q, n }, r)
} }
// TODO rename to make it clear that is not mod q, but mod X^N+1 // TODO rename to make it clear that is not mod q, but mod X^N+1
// apply mod (X^N+1) // apply mod (X^N+1)
pub fn modulus<const N: usize>(p: &mut Vec<i64>) {
if p.len() < N {
pub fn modulus(n: usize, p: &mut Vec<i64>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = 0; p[i] = 0;
} }
p.truncate(N);
p.truncate(n);
} }
pub fn modulus_i128<const N: usize>(p: &mut Vec<i128>) {
if p.len() < N {
pub fn modulus_i128(n: usize, p: &mut Vec<i128>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = 0; p[i] = 0;
} }
p.truncate(N);
p.truncate(n);
} }
impl<const N: usize> PartialEq for R<N> {
impl PartialEq for R {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0 == other.0
self.coeffs == other.coeffs && self.n == other.n
} }
} }
impl<const N: usize> Add<R<N>> for R<N> {
impl Add<R> for R {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
assert_eq!(self.n, rhs.n);
Self {
n: self.n,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> Add<&R<N>> for &R<N> {
type Output = R<N>;
fn add(self, rhs: &R<N>) -> Self::Output {
R(array::from_fn(|i| self.0[i] + rhs.0[i]))
impl Add<&R> for &R {
type Output = R;
fn add(self, rhs: &R) -> Self::Output {
assert_eq!(self.n, rhs.n);
R {
n: self.n,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> AddAssign for R<N> {
impl AddAssign for R {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] += rhs.0[i];
assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] += rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Sum<R<N>> for R<N> {
fn sum<I>(iter: I) -> Self
impl Sum<R> for R {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = R::<N>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const N: usize> Sub<R<N>> for R<N> {
impl Sub<R> for R {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
assert_eq!(self.n, rhs.n);
Self {
n: self.n,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> Sub<&R<N>> for &R<N> {
type Output = R<N>;
fn sub(self, rhs: &R<N>) -> Self::Output {
R(array::from_fn(|i| self.0[i] - rhs.0[i]))
impl Sub<&R> for &R {
type Output = R;
fn sub(self, rhs: &R) -> Self::Output {
assert_eq!(self.n, rhs.n);
R {
n: self.n,
coeffs: zip_eq(&self.coeffs, &rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> SubAssign for R<N> {
impl SubAssign for R {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] -= rhs.0[i];
assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] -= rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Mul<R<N>> for R<N> {
impl Mul<R> for R {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
naive_poly_mul(&self, &rhs) naive_poly_mul(&self, &rhs)
} }
} }
impl<const N: usize> Mul<&R<N>> for &R<N> {
type Output = R<N>;
impl Mul<&R> for &R {
type Output = R;
fn mul(self, rhs: &R<N>) -> Self::Output {
fn mul(self, rhs: &R) -> Self::Output {
naive_poly_mul(self, rhs) naive_poly_mul(self, rhs)
} }
} }
// TODO WIP // TODO WIP
pub fn naive_poly_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> R<N> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_poly_mul(poly1: &R, poly2: &R) -> R {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<i128> = poly1.coeffs.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.coeffs.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// apply mod (X^N + 1)) // apply mod (X^N + 1))
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect()) // R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
modulus_i128::<N>(&mut result);
modulus_i128(n, &mut result);
// dbg!(&result); // dbg!(&result);
// dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs()); // dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs());
let result_i64: Vec<i64> = result.iter().map(|c_i| *c_i as i64).collect();
let r = R::from_vec(n, result_i64);
// sanity check: check that there are no coeffs > i64_max // sanity check: check that there are no coeffs > i64_max
assert_eq!( assert_eq!(
result, result,
R::<N>(array::from_fn(|i| result[i] as i64))
.coeffs()
.iter()
.map(|c| *c as i128)
.collect::<Vec<_>>()
r.coeffs.iter().map(|c| *c as i128).collect::<Vec<_>>()
); );
R(array::from_fn(|i| result[i] as i64))
r
} }
pub fn naive_mul_2<const N: usize>(poly1: &Vec<i128>, poly2: &Vec<i128>) -> Vec<i128> {
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_mul_2(n: usize, poly1: &Vec<i128>, poly2: &Vec<i128>) -> Vec<i128> {
let mut result: Vec<i128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// apply mod (X^N + 1)) // apply mod (X^N + 1))
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect()) // R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
modulus_i128::<N>(&mut result);
modulus_i128(n, &mut result);
result result
} }
pub fn naive_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_mul(poly1: &R, poly2: &R) -> Vec<i64> {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<i128> = poly1.coeffs.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.coeffs.iter().map(|c| *c as i128).collect();
let mut result = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
result.iter().map(|c| *c as i64).collect() result.iter().map(|c| *c as i64).collect()
} }
pub fn naive_mul_TMP<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec<i64> {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<i128> = poly1.coeffs.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.coeffs.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// dbg!(&result); // dbg!(&result);
modulus_i128::<N>(&mut result);
modulus_i128(n, &mut result);
// for c_i in result.iter() { // for c_i in result.iter() {
// println!("---"); // println!("---");
// println!("{:?}", &c_i); // println!("{:?}", &c_i);
@ -316,8 +344,8 @@ pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec {
} }
// wip // wip
pub fn mod_centered_q<const Q: u64, const N: usize>(p: Vec<i128>) -> R<N> {
let q: i128 = Q as i128;
pub fn mod_centered_q(q: u64, n: usize, p: Vec<i128>) -> R {
let q: i128 = q as i128;
let r = p let r = p
.iter() .iter()
.map(|v| { .map(|v| {
@ -328,10 +356,10 @@ pub fn mod_centered_q(p: Vec) -> R {
res res
}) })
.collect::<Vec<i128>>(); .collect::<Vec<i128>>();
R::<N>::from_vec(r.iter().map(|v| *v as i64).collect::<Vec<i64>>())
R::from_vec(n, r.iter().map(|v| *v as i64).collect::<Vec<i64>>())
} }
impl<const N: usize> Mul<i64> for R<N> {
impl Mul<i64> for R {
type Output = Self; type Output = Self;
fn mul(self, s: i64) -> Self { fn mul(self, s: i64) -> Self {
@ -339,34 +367,38 @@ impl Mul for R {
} }
} }
// mul by u64 // mul by u64
impl<const N: usize> Mul<u64> for R<N> {
impl Mul<u64> for R {
type Output = Self; type Output = Self;
fn mul(self, s: u64) -> Self { fn mul(self, s: u64) -> Self {
self.mul_by_i64(s as i64) self.mul_by_i64(s as i64)
} }
} }
impl<const N: usize> Mul<&u64> for &R<N> {
type Output = R<N>;
impl Mul<&u64> for &R {
type Output = R;
fn mul(self, s: &u64) -> Self::Output { fn mul(self, s: &u64) -> Self::Output {
self.mul_by_i64(*s as i64) self.mul_by_i64(*s as i64)
} }
} }
impl<const N: usize> Neg for R<N> {
impl Neg for R {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Self(array::from_fn(|i| -self.0[i]))
// Self(array::from_fn(|i| -self.0[i]))
Self {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| -c_i).collect(),
}
} }
} }
impl<const N: usize> R<N> {
impl R {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut str = ""; let mut str = "";
let mut zero = true; let mut zero = true;
for (i, coeff) in self.0.iter().enumerate().rev() {
for (i, coeff) in self.coeffs.iter().enumerate().rev() {
if *coeff == 0 { if *coeff == 0 {
continue; continue;
} }
@ -395,18 +427,18 @@ impl R {
f.write_str(" mod Z")?; f.write_str(" mod Z")?;
f.write_str("/(X^")?; f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str(self.n.to_string().as_str())?;
f.write_str("+1)")?; f.write_str("+1)")?;
Ok(()) Ok(())
} }
} }
impl<const N: usize> fmt::Display for R<N> {
impl fmt::Display for R {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
} }
} }
impl<const N: usize> fmt::Debug for R<N> {
impl fmt::Debug for R {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
@ -420,38 +452,33 @@ mod tests {
#[test] #[test]
fn test_mul() -> Result<()> { fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 2;
let q: i64 = Q as i64;
let n: usize = 2;
let q: i64 = (2u64.pow(16) + 1) as i64;
// test vectors generated with SageMath // test vectors generated with SageMath
let a: [i64; N] = [q - 1, q - 1];
let b: [i64; N] = [q - 1, q - 1];
let c: [i64; N] = [0, 8589934592];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<i64> = vec![q - 1, q - 1];
let b: Vec<i64> = vec![q - 1, q - 1];
let c: Vec<i64> = vec![0, 8589934592];
test_mul_opt(n, a, b, c)?;
let a: [i64; N] = [1, q - 1];
let b: [i64; N] = [1, q - 1];
let c: [i64; N] = [-4294967295, 131072];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<i64> = vec![1, q - 1];
let b: Vec<i64> = vec![1, q - 1];
let c: Vec<i64> = vec![-4294967295, 131072];
test_mul_opt(n, a, b, c)?;
Ok(()) Ok(())
} }
fn test_mul_opt<const Q: u64, const N: usize>(
a: [i64; N],
b: [i64; N],
expected_c: [i64; N],
) -> Result<()> {
let mut a = R::new(a);
let mut b = R::new(b);
fn test_mul_opt(n: usize, a: Vec<i64>, b: Vec<i64>, expected_c: Vec<i64>) -> Result<()> {
let mut a = R::new(n, a);
let mut b = R::new(n, b);
dbg!(&a); dbg!(&a);
dbg!(&b); dbg!(&b);
let expected_c = R::new(expected_c);
let expected_c = R::new(n, expected_c);
let mut c = naive_mul(&mut a, &mut b); let mut c = naive_mul(&mut a, &mut b);
modulus::<N>(&mut c);
dbg!(R::<N>::from_vec(c.clone()));
assert_eq!(c, expected_c.0.to_vec());
modulus(n, &mut c);
dbg!(R::from_vec(n, c.clone()));
assert_eq!(c, expected_c.coeffs);
Ok(()) Ok(())
} }
} }

+ 301
- 226
arith/src/ring_nq.rs

@ -2,8 +2,8 @@
//! //!
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array;
use std::fmt; use std::fmt;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
@ -11,82 +11,91 @@ use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
use crate::ntt::NTT; use crate::ntt::NTT;
use crate::zq::{modulus_u64, Zq}; use crate::zq::{modulus_u64, Zq};
use crate::Ring;
use crate::{Ring, RingParam};
// NOTE: currently using fixed-size arrays, but pending to see if with
// real-world parameters the stack can keep up; if not will move everything to
// use Vec.
/// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1) /// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1)
/// The implementation assumes that q is prime. /// The implementation assumes that q is prime.
#[derive(Clone, Copy)]
pub struct Rq<const Q: u64, const N: usize> {
pub(crate) coeffs: [Zq<Q>; N],
#[derive(Clone)]
pub struct Rq {
pub param: RingParam,
pub(crate) coeffs: Vec<Zq>,
// evals are set when doig a PRxPR multiplication, so it can be reused in future // evals are set when doig a PRxPR multiplication, so it can be reused in future
// multiplications avoiding recomputing it // multiplications avoiding recomputing it
pub(crate) evals: Option<[Zq<Q>; N]>,
pub(crate) evals: Option<Vec<Zq>>,
} }
impl<const Q: u64, const N: usize> Ring for Rq<Q, N> {
type C = Zq<Q>;
const Q: u64 = Q;
const N: usize = N;
impl Ring for Rq {
type C = Zq;
fn param(&self) -> RingParam {
self.param
}
fn coeffs(&self) -> Vec<Self::C> { fn coeffs(&self) -> Vec<Self::C> {
self.coeffs.to_vec() self.coeffs.to_vec()
} }
fn zero() -> Self {
let coeffs = array::from_fn(|_| Zq::zero());
fn zero(param: &RingParam) -> Self {
Self { Self {
coeffs,
param: param.clone(),
coeffs: vec![Zq::zero(param.q); param.n],
evals: None, evals: None,
} }
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist));
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, param: &RingParam) -> Self {
Self { Self {
coeffs,
param: param.clone(),
coeffs: std::iter::repeat_with(|| Self::C::rand(&mut rng, &dist, param.q))
.take(param.n)
.collect(),
evals: None, evals: None,
} }
} }
fn from_vec(coeffs: Vec<Zq<Q>>) -> Self {
fn from_vec(param: &RingParam, coeffs: Vec<Zq>) -> Self {
let mut p = coeffs; let mut p = coeffs;
modulus::<Q, N>(&mut p);
let coeffs = array::from_fn(|i| p[i]);
modulus(param.q, param.n, &mut p);
Self { Self {
coeffs,
param: param.clone(),
coeffs: p,
evals: None, evals: None,
} }
} }
// returns the decomposition of each polynomial coefficient, such // returns the decomposition of each polynomial coefficient, such
// decomposition will be a vecotor of length N, containint N vectors of Zq
// decomposition will be a vector of length N, containing N vectors of Zq
fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
let elems: Vec<Vec<Zq<Q>>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect();
let elems: Vec<Vec<Zq>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect();
// transpose it // transpose it
let r: Vec<Vec<Zq<Q>>> = (0..elems[0].len())
let r: Vec<Vec<Zq>> = (0..elems[0].len())
.map(|i| (0..elems.len()).map(|j| elems[j][i]).collect()) .map(|i| (0..elems.len()).map(|j| elems[j][i]).collect())
.collect(); .collect();
// convert it to Rq<Q,N> // convert it to Rq<Q,N>
r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect()
r.iter()
.map(|a_i| Self::from_vec(&self.param, a_i.clone()))
.collect()
} }
// Warning: this method will behave differently depending on the values P and Q: // Warning: this method will behave differently depending on the values P and Q:
// if Q<P, it just 'renames' the modulus parameter to P // if Q<P, it just 'renames' the modulus parameter to P
// if Q>=P, it crops to mod P // if Q>=P, it crops to mod P
fn remodule<const P: u64>(&self) -> Rq<P, N> {
Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
fn remodule(&self, p: u64) -> Rq {
let param = RingParam {
q: p,
n: self.param.n,
};
Rq::from_vec_u64(&param, self.coeffs().iter().map(|m_i| m_i.v).collect())
} }
/// perform the mod switch operation from Q to Q', where Q2=Q' /// perform the mod switch operation from Q to Q', where Q2=Q'
// fn mod_switch<const P: u64, const M: usize>(&self) -> impl Ring {
fn mod_switch<const P: u64>(&self) -> Rq<P, N> {
// assert_eq!(N, M); // sanity check
Rq::<P, N> {
coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::<P>()),
fn mod_switch(&self, p: u64) -> Rq {
let param = RingParam {
q: p,
n: self.param.n,
};
Rq {
param,
coeffs: self.coeffs.iter().map(|c_i| c_i.mod_switch(p)).collect(),
evals: None, evals: None,
} }
} }
@ -98,105 +107,138 @@ impl Ring for Rq {
let r: Vec<f64> = self let r: Vec<f64> = self
.coeffs() .coeffs()
.iter() .iter()
.map(|e| ((num as f64 * e.0 as f64) / den as f64).round())
.map(|e| ((num as f64 * e.v as f64) / den as f64).round())
.collect(); .collect();
Rq::<Q, N>::from_vec_f64(r)
Rq::from_vec_f64(&self.param, r)
} }
} }
impl<const Q: u64, const N: usize> From<crate::ring_n::R<N>> for Rq<Q, N> {
fn from(r: crate::ring_n::R<N>) -> Self {
impl From<(u64, crate::ring_n::R)> for Rq {
fn from(qr: (u64, crate::ring_n::R)) -> Self {
let (q, r) = qr;
assert_eq!(r.n, r.coeffs.len());
Self::from_vec( Self::from_vec(
&RingParam { q, n: r.n },
r.coeffs() r.coeffs()
.iter() .iter()
.map(|e| Zq::<Q>::from_f64(*e as f64))
.map(|e| Zq::from_f64(q, *e as f64))
.collect(), .collect(),
) )
} }
} }
// apply mod (X^N+1) // apply mod (X^N+1)
pub fn modulus<const Q: u64, const N: usize>(p: &mut Vec<Zq<Q>>) {
if p.len() < N {
pub fn modulus(q: u64, n: usize, p: &mut Vec<Zq>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = Zq(0);
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = Zq::zero(q);
} }
p.truncate(N);
p.truncate(n);
} }
// PR stands for PolynomialRing
impl<const Q: u64, const N: usize> Rq<Q, N> {
pub fn coeffs(&self) -> [Zq<Q>; N] {
self.coeffs
impl Rq {
pub fn coeffs(&self) -> Vec<Zq> {
self.coeffs.clone()
} }
pub fn compute_evals(&mut self) { pub fn compute_evals(&mut self) {
self.evals = Some(NTT::<Q, N>::ntt(self.coeffs));
self.evals = Some(NTT::ntt(self).coeffs);
// TODO improve, ntt returns Rq but here just needs Vec<Zq>
} }
pub fn to_r(self) -> crate::R<N> {
crate::R::<N>::from(self)
pub fn to_r(self) -> crate::R {
crate::R::from(self)
} }
// TODO rm since it is implemented in Ring trait impl
// pub fn zero() -> Self {
// let coeffs = array::from_fn(|_| Zq::zero());
// Self {
// coeffs,
// evals: None,
// }
// }
// this method is mostly for tests // this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_u64(*c)).collect();
Self::from_vec(coeffs_mod_q)
pub fn from_vec_u64(param: &RingParam, coeffs: Vec<u64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_u64(param.q, *c)).collect();
Self::from_vec(param, coeffs_mod_q)
} }
pub fn from_vec_f64(coeffs: Vec<f64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c)).collect();
Self::from_vec(coeffs_mod_q)
pub fn from_vec_f64(param: &RingParam, coeffs: Vec<f64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_f64(param.q, *c)).collect();
Self::from_vec(param, coeffs_mod_q)
} }
pub fn from_vec_i64(coeffs: Vec<i64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c as f64)).collect();
Self::from_vec(coeffs_mod_q)
pub fn from_vec_i64(param: &RingParam, coeffs: Vec<i64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs
.iter()
.map(|c| Zq::from_f64(param.q, *c as f64))
.collect();
Self::from_vec(param, coeffs_mod_q)
} }
pub fn new(coeffs: [Zq<Q>; N], evals: Option<[Zq<Q>; N]>) -> Self {
Self { coeffs, evals }
pub fn new(param: &RingParam, coeffs: Vec<Zq>, evals: Option<Vec<Zq>>) -> Self {
Self {
param: *param,
coeffs,
evals,
}
} }
pub fn rand_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
pub fn rand_abs(
mut rng: impl Rng,
dist: impl Distribution<f64>,
param: &RingParam,
) -> Result<Self> {
Ok(Self { Ok(Self {
coeffs,
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng).abs()))
.take(param.n)
.collect(),
evals: None, evals: None,
}) })
} }
pub fn rand_f64_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
pub fn rand_f64_abs(
mut rng: impl Rng,
dist: impl Distribution<f64>,
param: &RingParam,
) -> Result<Self> {
Ok(Self { Ok(Self {
coeffs,
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng).abs()))
.take(param.n)
.collect(),
evals: None, evals: None,
}) })
} }
pub fn rand_f64(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
pub fn rand_f64(
mut rng: impl Rng,
dist: impl Distribution<f64>,
param: &RingParam,
) -> Result<Self> {
Ok(Self { Ok(Self {
coeffs,
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng)))
.take(param.n)
.collect(),
evals: None, evals: None,
}) })
} }
pub fn rand_u64(mut rng: impl Rng, dist: impl Distribution<u64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
pub fn rand_u64(
mut rng: impl Rng,
dist: impl Distribution<u64>,
param: &RingParam,
) -> Result<Self> {
Ok(Self { Ok(Self {
coeffs,
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_u64(param.q, dist.sample(&mut rng)))
.take(param.n)
.collect(),
evals: None, evals: None,
}) })
} }
// WIP. returns random v \in {0,1}. // TODO {-1, 0, 1} // WIP. returns random v \in {0,1}. // TODO {-1, 0, 1}
pub fn rand_bin(mut rng: impl Rng, dist: impl Distribution<bool>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
pub fn rand_bin(
mut rng: impl Rng,
dist: impl Distribution<bool>,
param: &RingParam,
) -> Result<Self> {
Ok(Rq { Ok(Rq {
coeffs,
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_bool(param.q, dist.sample(&mut rng)))
.take(param.n)
.collect(),
evals: None, evals: None,
}) })
} }
@ -208,36 +250,43 @@ impl Rq {
// } // }
// applies mod(T) to all coefficients of self // applies mod(T) to all coefficients of self
pub fn coeffs_mod<const T: u64>(&self) -> Self {
Rq::<Q, N>::from_vec_u64(
pub fn coeffs_mod(&self, param: &RingParam, t: u64) -> Self {
Rq::from_vec_u64(
param,
self.coeffs() self.coeffs()
.iter() .iter()
.map(|m_i| modulus_u64::<T>(m_i.0))
.map(|m_i| modulus_u64(t, m_i.v))
.collect(), .collect(),
) )
} }
// TODO review if needed, or if with this interface // TODO review if needed, or if with this interface
pub fn mul_by_matrix(&self, m: &Vec<Vec<Zq<Q>>>) -> Result<Vec<Zq<Q>>> {
pub fn mul_by_matrix(&self, m: &Vec<Vec<Zq>>) -> Result<Vec<Zq>> {
matrix_vec_product(m, &self.coeffs.to_vec()) matrix_vec_product(m, &self.coeffs.to_vec())
} }
pub fn mul_by_zq(&self, s: &Zq<Q>) -> Self {
pub fn mul_by_zq(&self, s: &Zq) -> Self {
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] * *s),
param: self.param,
coeffs: self.coeffs.iter().map(|c_i| *c_i * *s).collect(),
evals: None, evals: None,
} }
} }
pub fn mul_by_u64(&self, s: u64) -> Self { pub fn mul_by_u64(&self, s: u64) -> Self {
let s = Zq::from_u64(s);
let s = Zq::from_u64(self.param.q, s);
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] * s),
// coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
param: self.param,
coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
evals: None, evals: None,
} }
} }
pub fn mul_by_f64(&self, s: f64) -> Self { pub fn mul_by_f64(&self, s: f64) -> Self {
Self { Self {
coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)),
param: self.param,
coeffs: self
.coeffs
.iter()
.map(|c_i| Zq::from_f64(self.param.q, c_i.v as f64 * s))
.collect(),
evals: None, evals: None,
} }
} }
@ -251,9 +300,9 @@ impl Rq {
let r: Vec<f64> = self let r: Vec<f64> = self
.coeffs() .coeffs()
.iter() .iter()
.map(|e| (e.0 as f64 / s as f64).round())
.map(|e| (e.v as f64 / s as f64).round())
.collect(); .collect();
Rq::<Q, N>::from_vec_f64(r)
Rq::from_vec_f64(&self.param, r)
} }
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@ -261,19 +310,19 @@ impl Rq {
let mut str = ""; let mut str = "";
let mut zero = true; let mut zero = true;
for (i, coeff) in self.coeffs.iter().enumerate().rev() { for (i, coeff) in self.coeffs.iter().enumerate().rev() {
if coeff.0 == 0 {
if coeff.v == 0 {
continue; continue;
} }
zero = false; zero = false;
f.write_str(str)?; f.write_str(str)?;
if coeff.0 != 1 {
f.write_str(coeff.0.to_string().as_str())?;
if coeff.v != 1 {
f.write_str(coeff.v.to_string().as_str())?;
if i > 0 { if i > 0 {
f.write_str("*")?; f.write_str("*")?;
} }
} }
if coeff.0 == 1 && i == 0 {
f.write_str(coeff.0.to_string().as_str())?;
if coeff.v == 1 && i == 0 {
f.write_str(coeff.v.to_string().as_str())?;
} }
if i == 1 { if i == 1 {
f.write_str("x")?; f.write_str("x")?;
@ -288,9 +337,9 @@ impl Rq {
} }
f.write_str(" mod Z_")?; f.write_str(" mod Z_")?;
f.write_str(Q.to_string().as_str())?;
f.write_str(self.param.q.to_string().as_str())?;
f.write_str("/(X^")?; f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str(self.param.n.to_string().as_str())?;
f.write_str("+1)")?; f.write_str("+1)")?;
Ok(()) Ok(())
} }
@ -298,16 +347,20 @@ impl Rq {
pub fn infinity_norm(&self) -> u64 { pub fn infinity_norm(&self) -> u64 {
self.coeffs() self.coeffs()
.iter() .iter()
.map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
.map(|x| {
if x.v > (self.param.q / 2) {
self.param.q - x.v
} else {
x.v
}
})
.fold(0, |a, b| a.max(b)) .fold(0, |a, b| a.max(b))
} }
pub fn mod_centered_q(&self) -> crate::ring_n::R<N> {
self.to_r().mod_centered_q::<Q>()
pub fn mod_centered_q(&self) -> crate::ring_n::R {
self.clone().to_r().mod_centered_q(self.param.q)
} }
} }
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
// assert_eq!(m.len(), v.len());
pub fn matrix_vec_product(m: &Vec<Vec<Zq>>, v: &Vec<Zq>) -> Result<Vec<Zq>> {
if m.len() != m[0].len() { if m.len() != m[0].len() {
return Err(anyhow!("expected 'm' to be a square matrix")); return Err(anyhow!("expected 'm' to be a square matrix"));
} }
@ -319,6 +372,8 @@ pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) ->
)); ));
} }
assert_eq!(m[0][0].q, v[0].q); // TODO change to returning err
Ok(m.iter() Ok(m.iter()
.map(|row| { .map(|row| {
row.iter() row.iter()
@ -326,12 +381,15 @@ pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) ->
.map(|(&row_i, &v_i)| row_i * v_i) .map(|(&row_i, &v_i)| row_i * v_i)
.sum() .sum()
}) })
.collect::<Vec<Zq<Q>>>())
.collect::<Vec<Zq>>())
} }
pub fn transpose<const Q: u64>(m: &[Vec<Zq<Q>>]) -> Vec<Vec<Zq<Q>>> {
pub fn transpose(m: &[Vec<Zq>]) -> Vec<Vec<Zq>> {
assert!(m.len() > 0);
assert!(m[0].len() > 0);
let q = m[0][0].q;
// TODO case when m[0].len()=0 // TODO case when m[0].len()=0
// TODO non square matrix // TODO non square matrix
let mut r: Vec<Vec<Zq<Q>>> = vec![vec![Zq(0); m[0].len()]; m.len()];
let mut r: Vec<Vec<Zq>> = vec![vec![Zq::zero(q); m[0].len()]; m.len()];
for (i, m_row) in m.iter().enumerate() { for (i, m_row) in m.iter().enumerate() {
for (j, m_ij) in m_row.iter().enumerate() { for (j, m_ij) in m_row.iter().enumerate() {
r[j][i] = *m_ij; r[j][i] = *m_ij;
@ -340,205 +398,221 @@ pub fn transpose(m: &[Vec>]) -> Vec>> {
r r
} }
impl<const Q: u64, const N: usize> PartialEq for Rq<Q, N> {
impl PartialEq for Rq {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.coeffs == other.coeffs
self.coeffs == other.coeffs && self.param == other.param
} }
} }
impl<const Q: u64, const N: usize> Add<Rq<Q, N>> for Rq<Q, N> {
impl Add<Rq> for Rq {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
assert_eq!(self.param, rhs.param);
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
param: self.param,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
.collect(),
evals: None, evals: None,
} }
// Self {
// coeffs: self
// .coeffs
// .iter()
// .zip(rhs.coeffs)
// .map(|(a, b)| *a + b)
// .collect(),
// evals: None,
// }
// Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in +
} }
} }
impl<const Q: u64, const N: usize> Add<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Add<&Rq> for &Rq {
type Output = Rq;
fn add(self, rhs: &Rq<Q, N>) -> Self::Output {
fn add(self, rhs: &Rq) -> Self::Output {
assert_eq!(self.param, rhs.param);
Rq { Rq {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
param: self.param,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
.collect(),
evals: None, evals: None,
} }
} }
} }
impl<const Q: u64, const N: usize> AddAssign for Rq<Q, N> {
impl AddAssign for Rq {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..N {
debug_assert_eq!(self.param, rhs.param);
for i in 0..self.param.n {
self.coeffs[i] += rhs.coeffs[i]; self.coeffs[i] += rhs.coeffs[i];
} }
} }
} }
impl<const Q: u64, const N: usize> Sum<Rq<Q, N>> for Rq<Q, N> {
fn sum<I>(iter: I) -> Self
impl Sum<Rq> for Rq {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = Rq::<Q, N>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const Q: u64, const N: usize> Sub<Rq<Q, N>> for Rq<Q, N> {
impl Sub<Rq> for Rq {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
assert_eq!(self.param, rhs.param);
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
param: self.param,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
evals: None, evals: None,
} }
} }
} }
impl<const Q: u64, const N: usize> Sub<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Sub<&Rq> for &Rq {
type Output = Rq;
fn sub(self, rhs: &Rq<Q, N>) -> Self::Output {
fn sub(self, rhs: &Rq) -> Self::Output {
debug_assert_eq!(self.param, rhs.param);
Rq { Rq {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
param: self.param,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l - r)
.collect(),
evals: None, evals: None,
} }
} }
} }
impl<const Q: u64, const N: usize> SubAssign for Rq<Q, N> {
impl SubAssign for Rq {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
for i in 0..N {
debug_assert_eq!(self.param, rhs.param);
for i in 0..self.param.n {
self.coeffs[i] -= rhs.coeffs[i]; self.coeffs[i] -= rhs.coeffs[i];
} }
} }
} }
impl<const Q: u64, const N: usize> Mul<Rq<Q, N>> for Rq<Q, N> {
impl Mul<Rq> for Rq {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
mul(&self, &rhs) mul(&self, &rhs)
} }
} }
impl<const Q: u64, const N: usize> Mul<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&Rq> for &Rq {
type Output = Rq;
fn mul(self, rhs: &Rq<Q, N>) -> Self::Output {
fn mul(self, rhs: &Rq) -> Self::Output {
mul(self, rhs) mul(self, rhs)
} }
} }
// mul by Zq element // mul by Zq element
impl<const Q: u64, const N: usize> Mul<Zq<Q>> for Rq<Q, N> {
impl Mul<Zq> for Rq {
type Output = Self; type Output = Self;
fn mul(self, s: Zq<Q>) -> Self {
fn mul(self, s: Zq) -> Self {
self.mul_by_zq(&s) self.mul_by_zq(&s)
} }
} }
impl<const Q: u64, const N: usize> Mul<&Zq<Q>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&Zq> for &Rq {
type Output = Rq;
fn mul(self, s: &Zq<Q>) -> Self::Output {
fn mul(self, s: &Zq) -> Self::Output {
self.mul_by_zq(s) self.mul_by_zq(s)
} }
} }
// mul by u64 // mul by u64
impl<const Q: u64, const N: usize> Mul<u64> for Rq<Q, N> {
impl Mul<u64> for Rq {
type Output = Self; type Output = Self;
fn mul(self, s: u64) -> Self { fn mul(self, s: u64) -> Self {
self.mul_by_u64(s) self.mul_by_u64(s)
} }
} }
impl<const Q: u64, const N: usize> Mul<&u64> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&u64> for &Rq {
type Output = Rq;
fn mul(self, s: &u64) -> Self::Output { fn mul(self, s: &u64) -> Self::Output {
self.mul_by_u64(*s) self.mul_by_u64(*s)
} }
} }
// mul by f64 // mul by f64
impl<const Q: u64, const N: usize> Mul<f64> for Rq<Q, N> {
impl Mul<f64> for Rq {
type Output = Self; type Output = Self;
fn mul(self, s: f64) -> Self { fn mul(self, s: f64) -> Self {
self.mul_by_f64(s) self.mul_by_f64(s)
} }
} }
impl<const Q: u64, const N: usize> Mul<&f64> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&f64> for &Rq {
type Output = Rq;
fn mul(self, s: &f64) -> Self::Output { fn mul(self, s: &f64) -> Self::Output {
self.mul_by_f64(*s) self.mul_by_f64(*s)
} }
} }
impl<const Q: u64, const N: usize> Neg for Rq<Q, N> {
impl Neg for Rq {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Self { Self {
coeffs: array::from_fn(|i| -self.coeffs[i]),
param: self.param,
coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(),
evals: None, evals: None,
} }
} }
} }
// note: this assumes that Q is prime // note: this assumes that Q is prime
fn mul_mut<const Q: u64, const N: usize>(lhs: &mut Rq<Q, N>, rhs: &mut Rq<Q, N>) -> Rq<Q, N> {
fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq {
assert_eq!(lhs.param, rhs.param);
// reuse evaluations if already computed // reuse evaluations if already computed
if !lhs.evals.is_some() { if !lhs.evals.is_some() {
lhs.evals = Some(NTT::<Q, N>::ntt(lhs.coeffs));
lhs.evals = Some(NTT::ntt(lhs).coeffs);
}; };
if !rhs.evals.is_some() { if !rhs.evals.is_some() {
rhs.evals = Some(NTT::<Q, N>::ntt(rhs.coeffs));
rhs.evals = Some(NTT::ntt(rhs).coeffs);
}; };
let lhs_evals = lhs.evals.unwrap();
let rhs_evals = rhs.evals.unwrap();
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
let lhs_evals = lhs.evals.clone().unwrap();
let rhs_evals = rhs.evals.clone().unwrap();
let c_ntt: Rq = Rq::from_vec(
&lhs.param,
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
);
let c = NTT::intt(&c_ntt);
Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs))
} }
// note: this assumes that Q is prime // note: this assumes that Q is prime
// TODO impl karatsuba for non-prime Q
fn mul<const Q: u64, const N: usize>(lhs: &Rq<Q, N>, rhs: &Rq<Q, N>) -> Rq<Q, N> {
// TODO impl karatsuba for non-prime Q. Alternatively check NTT with RNS trick.
fn mul(lhs: &Rq, rhs: &Rq) -> Rq {
assert_eq!(lhs.param, rhs.param);
// reuse evaluations if already computed // reuse evaluations if already computed
let lhs_evals = if lhs.evals.is_some() {
lhs.evals.unwrap()
let lhs_evals: Vec<Zq> = if lhs.evals.is_some() {
lhs.evals.clone().unwrap()
} else { } else {
NTT::<Q, N>::ntt(lhs.coeffs)
NTT::ntt(lhs).coeffs
}; };
let rhs_evals = if rhs.evals.is_some() {
rhs.evals.unwrap()
let rhs_evals: Vec<Zq> = if rhs.evals.is_some() {
rhs.evals.clone().unwrap()
} else { } else {
NTT::<Q, N>::ntt(rhs.coeffs)
NTT::ntt(rhs).coeffs
}; };
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
let c_ntt: Rq = Rq::from_vec(
&lhs.param,
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
);
let c = NTT::intt(&c_ntt);
Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs))
} }
impl<const Q: u64, const N: usize> fmt::Display for Rq<Q, N> {
impl fmt::Display for Rq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
} }
} }
impl<const Q: u64, const N: usize> fmt::Debug for Rq<Q, N> {
impl fmt::Debug for Rq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
@ -552,31 +626,30 @@ mod tests {
#[test] #[test]
fn test_polynomial_ring() { fn test_polynomial_ring() {
// the test values used are generated with SageMath // the test values used are generated with SageMath
const Q: u64 = 7;
const N: usize = 3;
let param = RingParam { q: 7, n: 3 };
// p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1) // p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1)
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
let p = Rq::from_vec_u64(&param, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with coefficients bigger than Q // try with coefficients bigger than Q
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, Q + 2, 3, 4, 5]);
let p = Rq::from_vec_u64(&param, vec![0u64, 1, param.q + 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with other ring // try with other ring
let p = Rq::<7, 4>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
let p = Rq::from_vec_u64(&RingParam { q: 7, n: 4 }, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)"); assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)");
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 0, 0, 0, 4, 5]);
let p = Rq::from_vec_u64(&param, vec![0u64, 0, 0, 0, 4, 5]);
assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)");
let p = Rq::<Q, N>::from_vec_u64(vec![5u64, 4, 5, 2, 1, 0]);
let p = Rq::from_vec_u64(&param, vec![5u64, 4, 5, 2, 1, 0]);
assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
let a = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
let a = Rq::from_vec_u64(&param, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
let b = Rq::<Q, N>::from_vec_u64(vec![5u64, 4, 3, 2, 1, 0]);
let b = Rq::from_vec_u64(&param, vec![5u64, 4, 3, 2, 1, 0]);
assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)"); assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
// add // add
@ -593,34 +666,37 @@ mod tests {
#[test] #[test]
fn test_mul() -> Result<()> { fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let param = RingParam {
q: 2u64.pow(16) + 1,
n: 4,
};
let a: [u64; N] = [1u64, 2, 3, 4];
let b: [u64; N] = [1u64, 2, 3, 4];
let c: [u64; N] = [65513, 65517, 65531, 20];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<u64> = vec![1u64, 2, 3, 4];
let b: Vec<u64> = vec![1u64, 2, 3, 4];
let c: Vec<u64> = vec![65513, 65517, 65531, 20];
test_mul_opt(&param, a, b, c)?;
let a: [u64; N] = [0u64, 0, 0, 2];
let b: [u64; N] = [0u64, 0, 0, 2];
let c: [u64; N] = [0u64, 0, 65533, 0];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<u64> = vec![0u64, 0, 0, 2];
let b: Vec<u64> = vec![0u64, 0, 0, 2];
let c: Vec<u64> = vec![0u64, 0, 65533, 0];
test_mul_opt(&param, a, b, c)?;
// TODO more testvectors // TODO more testvectors
Ok(()) Ok(())
} }
fn test_mul_opt<const Q: u64, const N: usize>(
a: [u64; N],
b: [u64; N],
expected_c: [u64; N],
fn test_mul_opt(
param: &RingParam,
a: Vec<u64>,
b: Vec<u64>,
expected_c: Vec<u64>,
) -> Result<()> { ) -> Result<()> {
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::new(a, None);
let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::new(b, None);
let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::new(expected_c, None);
assert_eq!(a.len(), param.n);
assert_eq!(b.len(), param.n);
let mut a = Rq::from_vec_u64(&param, a);
let mut b = Rq::from_vec_u64(&param, b);
let expected_c = Rq::from_vec_u64(&param, expected_c);
let c = mul_mut(&mut a, &mut b); let c = mul_mut(&mut a, &mut b);
assert_eq!(c, expected_c); assert_eq!(c, expected_c);
@ -629,26 +705,25 @@ mod tests {
#[test] #[test]
fn test_rq_decompose() -> Result<()> { fn test_rq_decompose() -> Result<()> {
const Q: u64 = 16;
const N: usize = 4;
let param = RingParam { q: 16, n: 4 };
let beta = 4; let beta = 4;
let l = 2; let l = 2;
let a = Rq::<Q, N>::from_vec_u64(vec![7u64, 14, 3, 6]);
let a = Rq::from_vec_u64(&param, vec![7u64, 14, 3, 6]);
let d = a.decompose(beta, l); let d = a.decompose(beta, l);
assert_eq!( assert_eq!(
d[0].coeffs().to_vec(),
d[0].coeffs(),
vec![1u64, 3, 0, 1] vec![1u64, 3, 0, 1]
.iter() .iter()
.map(|e| Zq::<Q>::from_u64(*e))
.map(|e| Zq::from_u64(param.q, *e))
.collect::<Vec<_>>() .collect::<Vec<_>>()
); );
assert_eq!( assert_eq!(
d[1].coeffs().to_vec(),
d[1].coeffs(),
vec![3u64, 2, 3, 2] vec![3u64, 2, 3, 2]
.iter() .iter()
.map(|e| Zq::<Q>::from_u64(*e))
.map(|e| Zq::from_u64(param.q, *e))
.collect::<Vec<_>>() .collect::<Vec<_>>()
); );
Ok(()) Ok(())

+ 178
- 99
arith/src/ring_torus.rs

@ -7,63 +7,94 @@
//! u64, we fit it into the `Ring` trait (from ring.rs) so that we can compose //! u64, we fit it into the `Ring` trait (from ring.rs) so that we can compose
//! the 𝕋_<N,q> implementation with the other objects from the code. //! the 𝕋_<N,q> implementation with the other objects from the code.
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
use crate::{ring::Ring, torus::T64, Rq, Zq};
use crate::{
ring::{Ring, RingParam},
torus::T64,
Rq, Zq,
};
/// 𝕋_<N,Q>[X] = 𝕋<Q>[X]/(X^N +1), polynomials modulo X^N+1 with coefficients in /// 𝕋_<N,Q>[X] = 𝕋<Q>[X]/(X^N +1), polynomials modulo X^N+1 with coefficients in
/// 𝕋, where Q=2^64. /// 𝕋, where Q=2^64.
#[derive(Clone, Copy, Debug)]
pub struct Tn<const N: usize>(pub [T64; N]);
#[derive(Clone, Debug)]
pub struct Tn {
pub param: RingParam,
pub coeffs: Vec<T64>,
}
impl<const N: usize> Ring for Tn<N> {
impl Ring for Tn {
type C = T64; type C = T64;
const Q: u64 = u64::MAX; // WIP
const N: usize = N;
fn param(&self) -> RingParam {
RingParam {
q: u64::MAX,
n: self.param.n,
}
}
fn coeffs(&self) -> Vec<T64> { fn coeffs(&self) -> Vec<T64> {
self.0.to_vec()
self.coeffs.to_vec()
} }
fn zero() -> Self {
Self(array::from_fn(|_| T64::zero()))
fn zero(param: &RingParam) -> Self {
Self {
param: *param,
coeffs: vec![T64::zero(param); param.n],
}
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
Self(array::from_fn(|_| T64::rand(&mut rng, &dist)))
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, param: &RingParam) -> Self {
Self {
param: *param,
coeffs: std::iter::repeat_with(|| T64::rand(&mut rng, &dist, &param))
.take(param.n)
.collect(),
}
} }
fn from_vec(coeffs: Vec<Self::C>) -> Self {
fn from_vec(param: &RingParam, coeffs: Vec<Self::C>) -> Self {
let mut p = coeffs; let mut p = coeffs;
modulus::<N>(&mut p);
Self(array::from_fn(|i| p[i]))
modulus(param, &mut p);
Self {
param: *param,
coeffs: p,
}
} }
fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
let elems: Vec<Vec<T64>> = self.0.iter().map(|r| r.decompose(beta, l)).collect();
let elems: Vec<Vec<T64>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect();
// transpose it // transpose it
let r: Vec<Vec<T64>> = (0..elems[0].len()) let r: Vec<Vec<T64>> = (0..elems[0].len())
.map(|i| (0..elems.len()).map(|j| elems[j][i]).collect()) .map(|i| (0..elems.len()).map(|j| elems[j][i]).collect())
.collect(); .collect();
// convert it to Tn<N>
r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect()
// convert it to Tn
r.iter()
.map(|a_i| Self::from_vec(&self.param, a_i.clone()))
.collect()
} }
fn remodule<const P: u64>(&self) -> Tn<N> {
fn remodule(&self, p: u64) -> Tn {
todo!() todo!()
// Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect()) // Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
} }
// fn mod_switch<const P: u64>(&self) -> impl Ring { // fn mod_switch<const P: u64>(&self) -> impl Ring {
fn mod_switch<const P: u64>(&self) -> Rq<P, N> {
fn mod_switch(&self, p: u64) -> Rq {
// unimplemented!() // unimplemented!()
// TODO WIP // TODO WIP
let coeffs = array::from_fn(|i| Zq::<P>::from_u64(self.0[i].mod_switch::<P>().0));
Rq::<P, N> {
let coeffs = self
.coeffs
.iter()
.map(|c_i| Zq::from_u64(p, c_i.mod_switch(p).0))
.collect();
Rq {
param: RingParam {
q: p,
n: self.param.n,
},
coeffs, coeffs,
evals: None, evals: None,
} }
@ -78,175 +109,220 @@ impl Ring for Tn {
.iter() .iter()
.map(|e| T64(((num as f64 * e.0 as f64) / den as f64).round() as u64)) .map(|e| T64(((num as f64 * e.0 as f64) / den as f64).round() as u64))
.collect(); .collect();
Self::from_vec(r)
Self::from_vec(&self.param, r)
} }
} }
impl<const N: usize> Tn<N> {
impl Tn {
// multiply self by X^-h // multiply self by X^-h
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
let h = h % N;
assert!(h < N);
let c = self.0;
let n = self.param.n;
let h = h % n;
assert!(h < n);
let c = &self.coeffs;
// c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1] // c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1]
// let r: Vec<T64> = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat(); // let r: Vec<T64> = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat();
let r: Vec<T64> = c[h..N]
let r: Vec<T64> = c[h..n]
.iter() .iter()
.copied() .copied()
.chain(c[0..h].iter().map(|&x| -x)) .chain(c[0..h].iter().map(|&x| -x))
.collect(); .collect();
Self::from_vec(r)
Self::from_vec(&self.param, r)
} }
pub fn from_vec_u64(v: Vec<u64>) -> Self {
pub fn from_vec_u64(param: &RingParam, v: Vec<u64>) -> Self {
let coeffs = v.iter().map(|c| T64(*c)).collect(); let coeffs = v.iter().map(|c| T64(*c)).collect();
Self::from_vec(coeffs)
Self::from_vec(param, coeffs)
} }
} }
// apply mod (X^N+1) // apply mod (X^N+1)
pub fn modulus<const N: usize>(p: &mut Vec<T64>) {
if p.len() < N {
pub fn modulus(param: &RingParam, p: &mut Vec<T64>) {
let n = param.n;
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = T64::zero();
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = T64::zero(param);
} }
p.truncate(N);
p.truncate(n);
} }
impl<const N: usize> Add<Tn<N>> for Tn<N> {
impl Add<Tn> for Tn {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
assert_eq!(self.param, rhs.param);
Self {
param: self.param,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> Add<&Tn<N>> for &Tn<N> {
type Output = Tn<N>;
fn add(self, rhs: &Tn<N>) -> Self::Output {
Tn(array::from_fn(|i| self.0[i] + rhs.0[i]))
impl Add<&Tn> for &Tn {
type Output = Tn;
fn add(self, rhs: &Tn) -> Self::Output {
assert_eq!(self.param, rhs.param);
Tn {
param: self.param,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> AddAssign for Tn<N> {
impl AddAssign for Tn {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] += rhs.0[i];
assert_eq!(self.param, rhs.param);
for i in 0..self.param.n {
self.coeffs[i] += rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Sum<Tn<N>> for Tn<N> {
fn sum<I>(iter: I) -> Self
impl Sum<Tn> for Tn {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = Tn::<N>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const N: usize> Sub<Tn<N>> for Tn<N> {
impl Sub<Tn> for Tn {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
assert_eq!(self.param, rhs.param);
Self {
param: self.param,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> Sub<&Tn<N>> for &Tn<N> {
type Output = Tn<N>;
fn sub(self, rhs: &Tn<N>) -> Self::Output {
Tn(array::from_fn(|i| self.0[i] - rhs.0[i]))
impl Sub<&Tn> for &Tn {
type Output = Tn;
fn sub(self, rhs: &Tn) -> Self::Output {
assert_eq!(self.param, rhs.param);
Tn {
param: self.param,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> SubAssign for Tn<N> {
impl SubAssign for Tn {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] -= rhs.0[i];
assert_eq!(self.param, rhs.param);
for i in 0..self.param.n {
self.coeffs[i] -= rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Neg for Tn<N> {
impl Neg for Tn {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Tn(array::from_fn(|i| -self.0[i]))
Self {
param: self.param,
coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(),
}
} }
} }
impl<const N: usize> PartialEq for Tn<N> {
impl PartialEq for Tn {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0 == other.0
self.coeffs == other.coeffs && self.param == other.param
} }
} }
impl<const N: usize> Mul<Tn<N>> for Tn<N> {
impl Mul<Tn> for Tn {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
naive_poly_mul(&self, &rhs) naive_poly_mul(&self, &rhs)
} }
} }
impl<const N: usize> Mul<&Tn<N>> for &Tn<N> {
type Output = Tn<N>;
impl Mul<&Tn> for &Tn {
type Output = Tn;
fn mul(self, rhs: &Tn<N>) -> Self::Output {
fn mul(self, rhs: &Tn) -> Self::Output {
naive_poly_mul(self, rhs) naive_poly_mul(self, rhs)
} }
} }
fn naive_poly_mul<const N: usize>(poly1: &Tn<N>, poly2: &Tn<N>) -> Tn<N> {
let poly1: Vec<u128> = poly1.0.iter().map(|c| c.0 as u128).collect();
let poly2: Vec<u128> = poly2.0.iter().map(|c| c.0 as u128).collect();
let mut result: Vec<u128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
fn naive_poly_mul(poly1: &Tn, poly2: &Tn) -> Tn {
assert_eq!(poly1.param, poly2.param);
let n = poly1.param.n;
let param = poly1.param;
let poly1: Vec<u128> = poly1.coeffs.iter().map(|c| c.0 as u128).collect();
let poly2: Vec<u128> = poly2.coeffs.iter().map(|c| c.0 as u128).collect();
let mut result: Vec<u128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// apply mod (X^N + 1))
modulus_u128::<N>(&mut result);
// apply mod (X^n + 1))
modulus_u128(n, &mut result);
Tn(array::from_fn(|i| T64(result[i] as u64)))
Tn {
param,
coeffs: result.iter().map(|r_i| T64(*r_i as u64)).collect(),
}
} }
fn modulus_u128<const N: usize>(p: &mut Vec<u128>) {
if p.len() < N {
fn modulus_u128(n: usize, p: &mut Vec<u128>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
// p[i - N] = p[i - N].clone() - p[i].clone();
p[i - N] = p[i - N].wrapping_sub(p[i]);
for i in n..p.len() {
// p[i - n] = p[i - n].clone() - p[i].clone();
p[i - n] = p[i - n].wrapping_sub(p[i]);
p[i] = 0; p[i] = 0;
} }
p.truncate(N);
p.truncate(n);
} }
impl<const N: usize> Mul<T64> for Tn<N> {
impl Mul<T64> for Tn {
type Output = Self; type Output = Self;
fn mul(self, s: T64) -> Self { fn mul(self, s: T64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
Self {
param: self.param,
coeffs: self.coeffs.iter().map(|c_i| *c_i * s).collect(),
}
} }
} }
// mul by u64 // mul by u64
impl<const N: usize> Mul<u64> for Tn<N> {
impl Mul<u64> for Tn {
type Output = Self; type Output = Self;
fn mul(self, s: u64) -> Self { fn mul(self, s: u64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
Tn {
param: self.param,
coeffs: self.coeffs.iter().map(|c_i| *c_i * s).collect(),
}
} }
} }
impl<const N: usize> Mul<&u64> for &Tn<N> {
type Output = Tn<N>;
impl Mul<&u64> for &Tn {
type Output = Tn;
fn mul(self, s: &u64) -> Self::Output { fn mul(self, s: &u64) -> Self::Output {
Tn::<N>(array::from_fn(|i| self.0[i] * *s))
Tn {
param: self.param,
coeffs: self.coeffs.iter().map(|c_i| c_i * s).collect(),
}
} }
} }
@ -256,8 +332,9 @@ mod tests {
#[test] #[test]
fn test_left_rotate() { fn test_left_rotate() {
const N: usize = 4;
let f = Tn::<N>::from_vec(
let param = RingParam { q: u64::MAX, n: 4 };
let f = Tn::from_vec(
&param,
vec![2i64, 3, -4, -1] vec![2i64, 3, -4, -1]
.iter() .iter()
.map(|c| T64(*c as u64)) .map(|c| T64(*c as u64))
@ -267,7 +344,8 @@ mod tests {
// expect f*x^-3 == -1 -2x -3x^2 +4x^3 // expect f*x^-3 == -1 -2x -3x^2 +4x^3
assert_eq!( assert_eq!(
f.left_rotate(3), f.left_rotate(3),
Tn::<N>::from_vec(
Tn::from_vec(
&param,
vec![-1i64, -2, -3, 4] vec![-1i64, -2, -3, 4]
.iter() .iter()
.map(|c| T64(*c as u64)) .map(|c| T64(*c as u64))
@ -277,7 +355,8 @@ mod tests {
// expect f*x^-1 == 3 -4x -1x^2 -2x^3 // expect f*x^-1 == 3 -4x -1x^2 -2x^3
assert_eq!( assert_eq!(
f.left_rotate(1), f.left_rotate(1),
Tn::<N>::from_vec(
Tn::from_vec(
&param,
vec![3i64, -4, -1, -2] vec![3i64, -4, -1, -2]
.iter() .iter()
.map(|c| T64(*c as u64)) .map(|c| T64(*c as u64))

+ 23
- 15
arith/src/torus.rs

@ -4,7 +4,7 @@ use std::{
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
}; };
use crate::ring::Ring;
use crate::ring::{Ring, RingParam};
/// Let 𝕋 = ℝ/ℤ, where 𝕋 is a ℤ-module, with homogeneous external product. /// Let 𝕋 = ℝ/ℤ, where 𝕋 is a ℤ-module, with homogeneous external product.
/// Let 𝕋q /// Let 𝕋q
@ -16,20 +16,24 @@ pub struct T64(pub u64);
// `Tn<1>`. // `Tn<1>`.
impl Ring for T64 { impl Ring for T64 {
type C = T64; type C = T64;
const Q: u64 = u64::MAX; // WIP
const N: usize = 1;
fn param(&self) -> RingParam {
RingParam {
q: u64::MAX, // WIP
n: 1,
}
}
fn coeffs(&self) -> Vec<T64> { fn coeffs(&self) -> Vec<T64> {
vec![self.clone()] vec![self.clone()]
} }
fn zero() -> Self {
fn zero(_: &RingParam) -> Self {
Self(0u64) Self(0u64)
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, _: &RingParam) -> Self {
let r: f64 = dist.sample(&mut rng); let r: f64 = dist.sample(&mut rng);
Self(r.round() as u64) Self(r.round() as u64)
} }
fn from_vec(coeffs: Vec<Self::C>) -> Self {
fn from_vec(_n: &RingParam, coeffs: Vec<Self::C>) -> Self {
assert_eq!(coeffs.len(), 1); assert_eq!(coeffs.len(), 1);
coeffs[0] coeffs[0]
} }
@ -37,27 +41,27 @@ impl Ring for T64 {
// TODO rm beta & l from inputs, make it always beta=2,l=64. // TODO rm beta & l from inputs, make it always beta=2,l=64.
/// Note: only beta=2 and l=64 is supported. /// Note: only beta=2 and l=64 is supported.
fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
assert_eq!(beta, 2u32); // only beta=2 supported
// assert_eq!(l, 64u32); // only l=64 supported
assert_eq!(beta, 2u32, "only beta=2 supported");
// assert_eq!(l, 64u32, "only l=64 supported");
// (0..64) // (0..64)
(0..l)
(0..l as u64)
.rev() .rev()
.map(|i| T64(((self.0 >> i) & 1) as u64))
.map(|i| T64((self.0 >> i) & 1))
.collect() .collect()
} }
fn remodule<const P: u64>(&self) -> T64 {
fn remodule(&self, p: u64) -> T64 {
todo!() todo!()
} }
// modulus switch from Q to Q2: self * Q2/Q // modulus switch from Q to Q2: self * Q2/Q
fn mod_switch<const Q2: u64>(&self) -> T64 {
fn mod_switch(&self, q2: u64) -> T64 {
// for the moment we assume Q|Q2, since Q=2^64, check that Q2 is a power // for the moment we assume Q|Q2, since Q=2^64, check that Q2 is a power
// of two: // of two:
assert!(Q2.is_power_of_two());
assert!(q2.is_power_of_two());
// since Q=2^64, dividing Q2/Q is equivalent to dividing 2^log2(Q2)/2^64 // since Q=2^64, dividing Q2/Q is equivalent to dividing 2^log2(Q2)/2^64
// which would be like right-shifting 64-log2(Q2). // which would be like right-shifting 64-log2(Q2).
let log2_q2 = 63 - Q2.leading_zeros();
let log2_q2 = 63 - q2.leading_zeros();
T64(self.0 >> (64 - log2_q2)) T64(self.0 >> (64 - log2_q2))
} }
@ -173,9 +177,13 @@ mod tests {
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(recompose(d), T64(u64::MAX - 1)); assert_eq!(recompose(d), T64(u64::MAX - 1));
let param = RingParam {
q: u64::MAX, // WIP
n: 1,
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let x = T64::rand(&mut rng, Standard);
let x = T64::rand(&mut rng, Standard, &param);
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(recompose(d), x); assert_eq!(recompose(d), x);
} }

+ 81
- 48
arith/src/tuple_ring.rs

@ -1,40 +1,46 @@
//! This file implements the struct for an Tuple of Ring Rq elements and its //! This file implements the struct for an Tuple of Ring Rq elements and its
//! operations, which are performed element-wise. //! operations, which are performed element-wise.
use anyhow::Result;
use itertools::zip_eq; use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use rand_distr::{Normal, Uniform};
use std::iter::Sum;
use std::{
array,
ops::{Add, Mul, Neg, Sub},
};
use std::ops::{Add, Mul, Neg, Sub};
use crate::Ring;
use crate::{Ring, RingParam};
/// Tuple of K Ring (Rq) elements. We use Vec<R> to allocate it in the heap, /// Tuple of K Ring (Rq) elements. We use Vec<R> to allocate it in the heap,
/// since if using a fixed-size array it would overflow the stack. /// since if using a fixed-size array it would overflow the stack.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TR<R: Ring, const K: usize>(pub Vec<R>);
pub struct TR<R: Ring> {
pub k: usize,
pub r: Vec<R>,
}
// TODO rm pub from Vec<R>, so that TR can not be created from a Vec with // TODO rm pub from Vec<R>, so that TR can not be created from a Vec with
// invalid length, since it has to be created using the `new` method. // invalid length, since it has to be created using the `new` method.
impl<R: Ring, const K: usize> TR<R, K> {
pub fn new(v: Vec<R>) -> Self {
assert_eq!(v.len(), K);
Self(v)
impl<R: Ring> TR<R> {
pub fn new(k: usize, r: Vec<R>) -> Self {
assert_eq!(r.len(), k);
Self { k, r }
} }
pub fn zero() -> Self {
Self((0..K).into_iter().map(|_| R::zero()).collect())
pub fn zero(k: usize, r_param: &RingParam) -> Self {
Self {
k,
r: (0..k).into_iter().map(|_| R::zero(r_param)).collect(),
}
} }
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
Self(
(0..K)
pub fn rand(
mut rng: impl Rng,
dist: impl Distribution<f64>,
k: usize,
r_param: &RingParam,
) -> Self {
Self {
k,
r: (0..k)
.into_iter() .into_iter()
.map(|_| R::rand(&mut rng, &dist))
.map(|_| R::rand(&mut rng, &dist, r_param))
.collect(), .collect(),
)
}
} }
// returns the decomposition of each polynomial element // returns the decomposition of each polynomial element
pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
@ -43,64 +49,85 @@ impl TR {
} }
} }
impl<const K: usize> TR<crate::torus::T64, K> {
pub fn mod_switch<const Q2: u64>(&self) -> TR<crate::torus::T64, K> {
TR(self.0.iter().map(|c_i| c_i.mod_switch::<Q2>()).collect())
impl TR<crate::torus::T64> {
pub fn mod_switch(&self, q2: u64) -> TR<crate::torus::T64> {
TR::<crate::torus::T64> {
k: self.k,
r: self.r.iter().map(|c_i| c_i.mod_switch(q2)).collect(),
}
} }
// pub fn mod_switch(&self, Q2: u64) -> TR<crate::torus::T64, K> { // pub fn mod_switch(&self, Q2: u64) -> TR<crate::torus::T64, K> {
// TR(self.0.iter().map(|c_i| c_i.mod_switch(Q2)).collect()) // TR(self.0.iter().map(|c_i| c_i.mod_switch(Q2)).collect())
// } // }
} }
impl<const N: usize, const K: usize> TR<crate::ring_torus::Tn<N>, K> {
impl TR<crate::ring_torus::Tn> {
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
TR(self.0.iter().map(|c_i| c_i.left_rotate(h)).collect())
TR {
k: self.k,
r: self.r.iter().map(|c_i| c_i.left_rotate(h)).collect(),
}
} }
} }
impl<R: Ring, const K: usize> TR<R, K> {
impl<R: Ring> TR<R> {
pub fn iter(&self) -> std::slice::Iter<R> { pub fn iter(&self) -> std::slice::Iter<R> {
self.0.iter()
self.r.iter()
} }
} }
impl<R: Ring, const K: usize> Add<TR<R, K>> for TR<R, K> {
impl<R: Ring> Add<TR<R>> for TR<R> {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
Self(
zip_eq(self.0, other.0)
debug_assert_eq!(self.k, other.k);
Self {
k: self.k,
r: zip_eq(self.r, other.r)
.map(|(s, o)| s + o) .map(|(s, o)| s + o)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
)
}
} }
} }
impl<R: Ring, const K: usize> Sub<TR<R, K>> for TR<R, K> {
impl<R: Ring> Sub<TR<R>> for TR<R> {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
Self(zip_eq(self.0, other.0).map(|(s, o)| s - o).collect())
debug_assert_eq!(self.k, other.k);
Self {
k: self.k,
r: zip_eq(self.r, other.r).map(|(s, o)| s - o).collect(),
}
} }
} }
impl<R: Ring, const K: usize> Neg for TR<R, K> {
impl<R: Ring> Neg for TR<R> {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Self(self.0.iter().map(|&e_i| -e_i).collect())
Self {
k: self.k,
r: self.r.iter().map(|e_i| -e_i.clone()).collect(),
}
} }
} }
/// for (TR,TR), the Mul operation is defined as the dot product: /// for (TR,TR), the Mul operation is defined as the dot product:
/// for A, B \in R^k, result = Σ A_i * B_i \in R /// for A, B \in R^k, result = Σ A_i * B_i \in R
impl<R: Ring, const K: usize> Mul<TR<R, K>> for TR<R, K> {
impl<R: Ring> Mul<TR<R>> for TR<R> {
type Output = R; type Output = R;
fn mul(self, other: Self) -> R { fn mul(self, other: Self) -> R {
zip_eq(self.0, other.0).map(|(s, o)| s * o).sum()
debug_assert_eq!(self.k, other.k);
zip_eq(self.r, other.r).map(|(s, o)| s * o).sum()
} }
} }
impl<R: Ring, const K: usize> Mul<&TR<R, K>> for &TR<R, K> {
impl<R: Ring> Mul<&TR<R>> for &TR<R> {
type Output = R; type Output = R;
fn mul(self, other: &TR<R, K>) -> R {
zip_eq(self.0.clone(), other.0.clone())
fn mul(self, other: &TR<R>) -> R {
debug_assert_eq!(self.k, other.k);
zip_eq(self.r.clone(), other.r.clone())
.map(|(s, o)| s * o) .map(|(s, o)| s * o)
.sum() .sum()
} }
@ -108,15 +135,21 @@ impl Mul<&TR> for &TR {
/// for (TR, R), the Mul operation is defined as each element of TR is /// for (TR, R), the Mul operation is defined as each element of TR is
/// multiplied by R /// multiplied by R
impl<R: Ring, const K: usize> Mul<R> for TR<R, K> {
type Output = TR<R, K>;
fn mul(self, other: R) -> TR<R, K> {
Self(self.0.iter().map(|s| s.clone() * other.clone()).collect())
impl<R: Ring> Mul<R> for TR<R> {
type Output = TR<R>;
fn mul(self, other: R) -> TR<R> {
Self {
k: self.k,
r: self.r.iter().map(|s| s.clone() * other.clone()).collect(),
}
} }
} }
impl<R: Ring, const K: usize> Mul<&R> for &TR<R, K> {
type Output = TR<R, K>;
fn mul(self, other: &R) -> TR<R, K> {
TR::<R, K>(self.0.iter().map(|s| s.clone() * other.clone()).collect())
impl<R: Ring> Mul<&R> for &TR<R> {
type Output = TR<R>;
fn mul(self, other: &R) -> TR<R> {
TR::<R> {
k: self.k,
r: self.r.iter().map(|s| s.clone() * other.clone()).collect(),
}
} }
} }

+ 169
- 124
arith/src/zq.rs

@ -4,41 +4,39 @@ use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
/// Z_q, integers modulus q, not necessarily prime /// Z_q, integers modulus q, not necessarily prime
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
pub struct Zq<const Q: u64>(pub u64);
// WIP
// impl<const Q: u64> From<Vec<u64>> for Vec<Zq<Q>> {
// fn from(v: Vec<u64>) -> Self {
// v.into_iter().map(Zq::new).collect()
// }
// }
pub struct Zq {
pub q: u64,
pub v: u64,
}
pub(crate) fn modulus_u64<const Q: u64>(e: u64) -> u64 {
(e % Q + Q) % Q
pub(crate) fn modulus_u64(q: u64, e: u64) -> u64 {
(e % q + q) % q
} }
impl<const Q: u64> Zq<Q> {
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
impl Zq {
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, q: u64) -> Self {
// TODO WIP // TODO WIP
let r: f64 = dist.sample(&mut rng); let r: f64 = dist.sample(&mut rng);
Self::from_f64(r)
// Self::from_u64(r.round() as u64)
}
pub fn from_u64(e: u64) -> Self {
if e >= Q {
// (e % Q + Q) % Q
return Zq(modulus_u64::<Q>(e));
// return Zq(e % Q);
Self::from_f64(q, r)
}
pub fn from_u64(q: u64, v: u64) -> Self {
if v >= q {
// (v % Q + Q) % Q
return Zq {
q,
v: modulus_u64(q, v),
};
// return Zq(v % Q);
} }
Zq(e)
Zq { q, v }
} }
pub fn from_f64(e: f64) -> Self {
pub fn from_f64(q: u64, e: f64) -> Self {
// WIP method // WIP method
let e: i64 = e.round() as i64; let e: i64 = e.round() as i64;
let q = Q as i64;
if e < 0 || e >= q {
return Zq(((e % q + q) % q) as u64);
let q_i64 = q as i64;
if e < 0 || e >= q_i64 {
return Zq::from_u64(q, ((e % q_i64 + q_i64) % q_i64) as u64);
} }
Zq(e as u64)
Zq { q, v: e as u64 }
// if e < 0 { // if e < 0 {
// // dbg!(&e); // // dbg!(&e);
@ -50,15 +48,18 @@ impl Zq {
// } // }
// Zq(e as u64) // Zq(e as u64)
} }
pub fn from_bool(b: bool) -> Self {
pub fn from_bool(q: u64, b: bool) -> Self {
if b { if b {
Zq(1)
Zq { q, v: 1 }
} else { } else {
Zq(0)
Zq { q, v: 0 }
} }
} }
pub fn zero() -> Self {
Self(0u64)
pub fn zero(q: u64) -> Self {
Self { q, v: 0u64 }
}
pub fn one(q: u64) -> Self {
Self { q, v: 1u64 }
} }
pub fn square(self) -> Self { pub fn square(self) -> Self {
self * self self * self
@ -66,18 +67,21 @@ impl Zq {
// modular exponentiation // modular exponentiation
pub fn exp(self, e: Self) -> Self { pub fn exp(self, e: Self) -> Self {
// mul-square approach // mul-square approach
let mut res = Self(1);
let mut res = Self::one(self.q);
let mut rem = e.clone(); let mut rem = e.clone();
let mut exp = self; let mut exp = self;
// for rem != Self(0) { // for rem != Self(0) {
while rem != Self(0) {
while rem != Self::zero(self.q) {
// if odd // if odd
// TODO use a more readible expression
if 1 - ((rem.0 & 1) << 1) as i64 == -1 {
// TODO use a more readeable expression
if 1 - ((rem.v & 1) << 1) as i64 == -1 {
res = res * exp; res = res * exp;
} }
exp = exp.square(); exp = exp.square();
rem = Self(rem.0 >> 1);
rem = Self {
q: self.q,
v: rem.v >> 1,
};
} }
res res
} }
@ -89,9 +93,9 @@ impl Zq {
// let a = self.0; // let a = self.0;
// let q = Q; // let q = Q;
let mut t = 0; let mut t = 0;
let mut r = Q;
let mut r = self.q;
let mut new_t = 0; let mut new_t = 0;
let mut new_r = self.0.clone();
let mut new_r = self.v.clone();
while new_r != 0 { while new_r != 0 {
let q = r / new_r; let q = r / new_r;
@ -104,16 +108,16 @@ impl Zq {
// if t < 0 { // if t < 0 {
// t = t + q; // t = t + q;
// } // }
return Zq::from_u64(t);
return Zq::from_u64(self.q, t);
} }
pub fn inv(self) -> Zq<Q> {
let (g, x, _) = Self::egcd(self.0 as i128, Q as i128);
pub fn inv(self) -> Zq {
let (g, x, _) = Self::egcd(self.v as i128, self.q as i128);
if g != 1 { if g != 1 {
// None // None
panic!("E"); panic!("E");
} else { } else {
let q = Q as i128;
Zq(((x % q + q) % q) as u64) // TODO maybe just Zq::new(x)
let q = self.q as i128;
Zq::from_u64(self.q, ((x % q + q) % q) as u64) // TODO maybe just Zq::new(x)
} }
} }
fn egcd(a: i128, b: i128) -> (i128, i128, i128) { fn egcd(a: i128, b: i128) -> (i128, i128, i128) {
@ -126,8 +130,11 @@ impl Zq {
} }
/// perform the mod switch operation from Q to Q', where Q2=Q' /// perform the mod switch operation from Q to Q', where Q2=Q'
pub fn mod_switch<const Q2: u64>(&self) -> Zq<Q2> {
Zq::<Q2>::from_u64(((self.0 as f64 * Q2 as f64) / Q as f64).round() as u64)
pub fn mod_switch(&self, q2: u64) -> Zq {
Zq::from_u64(
q2,
((self.v as f64 * q2 as f64) / self.q as f64).round() as u64,
)
} }
pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
@ -138,19 +145,25 @@ impl Zq {
} }
} }
pub fn decompose_base_beta(&self, beta: u32, l: u32) -> Vec<Self> { pub fn decompose_base_beta(&self, beta: u32, l: u32) -> Vec<Self> {
let mut rem: u64 = self.0;
let mut rem: u64 = self.v;
// next if is for cases in which beta does not divide Q (concretely // next if is for cases in which beta does not divide Q (concretely
// beta^l!=Q). round to the nearest multiple of q/beta^l // beta^l!=Q). round to the nearest multiple of q/beta^l
if rem >= beta.pow(l) as u64 { if rem >= beta.pow(l) as u64 {
// rem = Q - 1 - (Q / beta as u64); // floor // rem = Q - 1 - (Q / beta as u64); // floor
return vec![Zq(beta as u64 - 1); l as usize];
return vec![
Zq {
q: self.q,
v: beta as u64 - 1
};
l as usize
];
} }
let mut x: Vec<Self> = vec![]; let mut x: Vec<Self> = vec![];
for i in 1..l + 1 { for i in 1..l + 1 {
let den = Q / beta.pow(i) as u64;
let den = self.q / beta.pow(i) as u64;
let x_i = rem / den; // division between u64 already does floor let x_i = rem / den; // division between u64 already does floor
x.push(Self::from_u64(x_i));
x.push(Self::from_u64(self.q, x_i));
if x_i != 0 { if x_i != 0 {
rem = rem % den; rem = rem % den;
} }
@ -161,15 +174,15 @@ impl Zq {
pub fn decompose_base2(&self, l: u32) -> Vec<Self> { pub fn decompose_base2(&self, l: u32) -> Vec<Self> {
// next if is for cases in which beta does not divide Q (concretely // next if is for cases in which beta does not divide Q (concretely
// beta^l!=Q). round to the nearest multiple of q/beta^l // beta^l!=Q). round to the nearest multiple of q/beta^l
if self.0 >= 1 << l as u64 {
if self.v >= 1 << l as u64 {
// rem = Q - 1 - (Q / beta as u64); // floor // rem = Q - 1 - (Q / beta as u64); // floor
// (where beta=2) // (where beta=2)
return vec![Zq(1); l as usize];
return vec![Zq::one(self.q); l as usize];
} }
(0..l) (0..l)
.rev() .rev()
.map(|i| Self(((self.0 >> i) & 1) as u64))
.map(|i| Self::from_u64(self.q, ((self.v >> i) & 1) as u64))
.collect() .collect()
// naive ver: // naive ver:
@ -194,114 +207,143 @@ impl Zq {
} }
} }
impl<const Q: u64> Zq<Q> {
impl Zq {
fn r#mod(self) -> Self { fn r#mod(self) -> Self {
if self.0 >= Q {
return Zq(self.0 % Q);
if self.v >= self.q {
return Zq::from_u64(self.q, self.v % self.q);
} }
self self
} }
} }
impl<const Q: u64> Add<Zq<Q>> for Zq<Q> {
impl Add<Zq> for Zq {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self::Output { fn add(self, rhs: Self) -> Self::Output {
let mut r = self.0 + rhs.0;
if r >= Q {
r -= Q;
assert_eq!(self.q, rhs.q);
let mut v = self.v + rhs.v;
if v >= self.q {
v -= self.q;
} }
Zq(r)
Zq { q: self.q, v }
} }
} }
impl<const Q: u64> Add<&Zq<Q>> for &Zq<Q> {
type Output = Zq<Q>;
impl Add<&Zq> for &Zq {
type Output = Zq;
fn add(self, rhs: &Zq<Q>) -> Self::Output {
let mut r = self.0 + rhs.0;
if r >= Q {
r -= Q;
fn add(self, rhs: &Zq) -> Self::Output {
assert_eq!(self.q, rhs.q);
let mut v = self.v + rhs.v;
if v >= self.q {
v -= self.q;
} }
Zq(r)
Zq { q: self.q, v }
} }
} }
impl<const Q: u64> AddAssign<Zq<Q>> for Zq<Q> {
impl AddAssign<Zq> for Zq {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs *self = *self + rhs
} }
} }
impl<const Q: u64> std::iter::Sum for Zq<Q> {
fn sum<I>(iter: I) -> Self
impl std::iter::Sum for Zq {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
iter.fold(Zq(0), |acc, x| acc + x)
let first: Zq = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const Q: u64> Sub<Zq<Q>> for Zq<Q> {
impl Sub<Zq> for Zq {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Zq<Q> {
if self.0 >= rhs.0 {
Zq(self.0 - rhs.0)
fn sub(self, rhs: Self) -> Zq {
assert_eq!(self.q, rhs.q);
if self.v >= rhs.v {
Zq {
q: self.q,
v: self.v - rhs.v,
}
} else { } else {
Zq((Q + self.0) - rhs.0)
Zq {
q: self.q,
v: (self.q + self.v) - rhs.v,
}
} }
} }
} }
impl<const Q: u64> Sub<&Zq<Q>> for &Zq<Q> {
type Output = Zq<Q>;
impl Sub<&Zq> for &Zq {
type Output = Zq;
fn sub(self, rhs: &Zq) -> Self::Output {
assert_eq!(self.q, rhs.q);
fn sub(self, rhs: &Zq<Q>) -> Self::Output {
if self.0 >= rhs.0 {
Zq(self.0 - rhs.0)
if self.q >= rhs.q {
Zq {
q: self.q,
v: self.v - rhs.v,
}
} else { } else {
Zq((Q + self.0) - rhs.0)
Zq {
q: self.q,
v: (self.q + self.v) - rhs.v,
}
} }
} }
} }
impl<const Q: u64> SubAssign<Zq<Q>> for Zq<Q> {
impl SubAssign<Zq> for Zq {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs *self = *self - rhs
} }
} }
impl<const Q: u64> Neg for Zq<Q> {
impl Neg for Zq {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
if self.0 == 0 {
if self.v == 0 {
return self; return self;
} }
Zq(Q - self.0)
Zq {
q: self.q,
v: self.q - self.v,
}
} }
} }
impl<const Q: u64> Mul<Zq<Q>> for Zq<Q> {
impl Mul<Zq> for Zq {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Zq<Q> {
fn mul(self, rhs: Self) -> Zq {
assert_eq!(self.q, rhs.q);
// TODO non-naive way // TODO non-naive way
Zq(((self.0 as u128 * rhs.0 as u128) % Q as u128) as u64)
Zq {
q: self.q,
v: ((self.v as u128 * rhs.v as u128) % self.q as u128) as u64,
}
// Zq((self.0 * rhs.0) % Q) // Zq((self.0 * rhs.0) % Q)
} }
} }
impl<const Q: u64> Div<Zq<Q>> for Zq<Q> {
impl Div<Zq> for Zq {
type Output = Self; type Output = Self;
fn div(self, rhs: Self) -> Zq<Q> {
fn div(self, rhs: Self) -> Zq {
// TODO non-naive way // TODO non-naive way
// Zq((self.0 / rhs.0) % Q) // Zq((self.0 / rhs.0) % Q)
self * rhs.inv() self * rhs.inv()
} }
} }
impl<const Q: u64> fmt::Display for Zq<Q> {
impl fmt::Display for Zq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
write!(f, "{}", self.v)
} }
} }
impl<const Q: u64> fmt::Debug for Zq<Q> {
impl fmt::Debug for Zq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
write!(f, "{}", self.v)
} }
} }
@ -312,80 +354,83 @@ mod tests {
#[test] #[test]
fn exp() { fn exp() {
const Q: u64 = 1021;
let a = Zq::<Q>(3);
let b = Zq::<Q>(3);
assert_eq!(a.exp(b), Zq(27));
const q: u64 = 1021;
let a = Zq::from_u64(q, 3);
let b = Zq::from_u64(q, 3);
assert_eq!(a.exp(b), Zq::from_u64(q, 27));
let a = Zq::<Q>(1000);
let b = Zq::<Q>(3);
assert_eq!(a.exp(b), Zq(949));
let a = Zq::from_u64(q, 1000);
let b = Zq::from_u64(q, 3);
assert_eq!(a.exp(b), Zq::from_u64(q, 949));
} }
#[test] #[test]
fn neg() { fn neg() {
const Q: u64 = 1021;
let a = Zq::<Q>::from_f64(101.0);
let b = Zq::<Q>::from_f64(-1.0);
let q: u64 = 1021;
let a = Zq::from_f64(q, 101.0);
let b = Zq::from_f64(q, -1.0);
assert_eq!(-a, a * b); assert_eq!(-a, a * b);
} }
fn recompose<const Q: u64>(beta: u32, l: u32, d: Vec<Zq<Q>>) -> Zq<Q> {
fn recompose(q: u64, beta: u32, l: u32, d: Vec<Zq>) -> Zq {
let mut x = 0u64; let mut x = 0u64;
for i in 0..l { for i in 0..l {
x += d[i as usize].0 * Q / beta.pow(i + 1) as u64;
x += d[i as usize].v * q / beta.pow(i + 1) as u64;
} }
Zq::from_u64(x)
Zq::from_u64(q, x)
} }
#[test] #[test]
fn test_decompose() { fn test_decompose() {
const Q1: u64 = 16;
let q1: u64 = 16;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 4; let l: u32 = 4;
let x = Zq::<Q1>::from_u64(9);
let x = Zq::from_u64(q1, 9);
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(recompose::<Q1>(beta, l, d), x);
assert_eq!(recompose(q1, beta, l, d), x);
const Q: u64 = 5u64.pow(3);
let q: u64 = 5u64.pow(3);
let beta: u32 = 5; let beta: u32 = 5;
let l: u32 = 3; let l: u32 = 3;
let dist = Uniform::new(0_u64, Q);
let dist = Uniform::new(0_u64, q);
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let x = Zq::<Q>::from_u64(dist.sample(&mut rng));
let x = Zq::from_u64(q, dist.sample(&mut rng));
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q>(beta, l, d), x);
assert_eq!(recompose(q, beta, l, d), x);
} }
} }
#[test] #[test]
fn test_decompose_approx() { fn test_decompose_approx() {
const Q: u64 = 2u64.pow(4) + 1;
let q: u64 = 2u64.pow(4) + 1;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 4; let l: u32 = 4;
let x = Zq::<Q>::from_u64(16); // in q, but bigger than beta^l
let x = Zq::from_u64(q, 16); // in q, but bigger than beta^l
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q>(beta, l, d), Zq(15));
assert_eq!(recompose(q, beta, l, d), Zq::from_u64(q, 15));
const Q2: u64 = 5u64.pow(3) + 1;
let q2: u64 = 5u64.pow(3) + 1;
let beta: u32 = 5; let beta: u32 = 5;
let l: u32 = 3; let l: u32 = 3;
let x = Zq::<Q2>::from_u64(125); // in q, but bigger than beta^l
let x = Zq::from_u64(q2, 125); // in q, but bigger than beta^l
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q2>(beta, l, d), Zq(124));
assert_eq!(recompose(q2, beta, l, d), Zq::from_u64(q2, 124));
const Q3: u64 = 2u64.pow(16) + 1;
let q3: u64 = 2u64.pow(16) + 1;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let x = Zq::<Q3>::from_u64(Q3 - 1); // in q, but bigger than beta^l
let x = Zq::from_u64(q3, q3 - 1); // in q, but bigger than beta^l
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q3>(beta, l, d), Zq(beta.pow(l) as u64 - 1));
assert_eq!(
recompose(q3, beta, l, d),
Zq::from_u64(q3, beta.pow(l) as u64 - 1)
);
} }
} }

+ 343
- 302
bfv/src/lib.rs

@ -10,44 +10,61 @@ use rand::Rng;
use rand_distr::{Normal, Uniform}; use rand_distr::{Normal, Uniform};
use std::ops; use std::ops;
use arith::{Ring, Rq, R};
use arith::{Ring, RingParam, Rq, R};
// error deviation for the Gaussian(Normal) distribution // error deviation for the Gaussian(Normal) distribution
// sigma=3.2 from: https://eprint.iacr.org/2022/162.pdf page 5 // sigma=3.2 from: https://eprint.iacr.org/2022/162.pdf page 5
const ERR_SIGMA: f64 = 3.2; const ERR_SIGMA: f64 = 3.2;
#[derive(Clone, Copy, Debug)]
pub struct Param {
ring: RingParam,
t: u64,
p: u64,
}
impl Param {
// returns the plaintext param
pub fn pt(&self) -> RingParam {
RingParam {
q: self.t,
n: self.ring.n,
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SecretKey<const Q: u64, const N: usize>(Rq<Q, N>);
pub struct SecretKey(Rq);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PublicKey<const Q: u64, const N: usize>(Rq<Q, N>, Rq<Q, N>);
pub struct PublicKey(Rq, Rq);
/// Relinearization key /// Relinearization key
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RLK<const PQ: u64, const N: usize>(Rq<PQ, N>, Rq<PQ, N>);
pub struct RLK(Rq, Rq);
// RLWE ciphertext // RLWE ciphertext
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RLWE<const Q: u64, const N: usize>(Rq<Q, N>, Rq<Q, N>);
pub struct RLWE(Rq, Rq);
impl<const Q: u64, const N: usize> RLWE<Q, N> {
impl RLWE {
fn add(lhs: Self, rhs: Self) -> Self { fn add(lhs: Self, rhs: Self) -> Self {
RLWE::<Q, N>(lhs.0 + rhs.0, lhs.1 + rhs.1)
RLWE(lhs.0 + rhs.0, lhs.1 + rhs.1)
} }
pub fn remodule<const P: u64>(&self) -> RLWE<P, N> {
let x = self.0.remodule::<P>();
let y = self.1.remodule::<P>();
RLWE::<P, N>(x, y)
pub fn remodule(&self, p: u64) -> RLWE {
let x = self.0.remodule(p);
let y = self.1.remodule(p);
RLWE(x, y)
} }
fn tensor<const PQ: u64, const T: u64>(a: &Self, b: &Self) -> (Rq<Q, N>, Rq<Q, N>, Rq<Q, N>) {
fn tensor(t: u64, a: &Self, b: &Self) -> (Rq, Rq, Rq) {
let (q, n) = (a.0.param.q, a.0.param.n);
// expand Q->PQ // TODO rm // expand Q->PQ // TODO rm
// get the coefficients in Z, ie. interpret a,b \in R (instead of R_q) // get the coefficients in Z, ie. interpret a,b \in R (instead of R_q)
let a0: R<N> = a.0.to_r();
let a1: R<N> = a.1.to_r();
let b0: R<N> = b.0.to_r();
let b1: R<N> = b.1.to_r();
let a0: R = a.0.clone().to_r(); // TODO rm clone()
let a1: R = a.1.clone().to_r();
let b0: R = b.0.clone().to_r();
let b1: R = b.1.clone().to_r();
// tensor (\in R) (2021-204 p.9) // tensor (\in R) (2021-204 p.9)
// NOTE: here can use *, but at first versions want to make it explicit // NOTE: here can use *, but at first versions want to make it explicit
@ -60,44 +77,47 @@ impl RLWE {
let c2: Vec<i64> = naive_mul(&a1, &b1); let c2: Vec<i64> = naive_mul(&a1, &b1);
// scale down, then reduce module Q, so result is \in R_q // scale down, then reduce module Q, so result is \in R_q
let c0: Rq<Q, N> = arith::ring_n::mul_div_round::<Q, N>(c0, T, Q);
let c1: Rq<Q, N> = arith::ring_n::mul_div_round::<Q, N>(c1, T, Q);
let c2: Rq<Q, N> = arith::ring_n::mul_div_round::<Q, N>(c2, T, Q);
let c0: Rq = arith::ring_n::mul_div_round(q, n, c0, t, q);
let c1: Rq = arith::ring_n::mul_div_round(q, n, c1, t, q);
let c2: Rq = arith::ring_n::mul_div_round(q, n, c2, t, q);
(c0, c1, c2) (c0, c1, c2)
} }
/// ciphertext multiplication /// ciphertext multiplication
fn mul<const PQ: u64, const T: u64>(rlk: &RLK<PQ, N>, a: &Self, b: &Self) -> Self {
let (c0, c1, c2) = Self::tensor::<PQ, T>(a, b);
BFV::<Q, N, T>::relinearize_204::<PQ>(&rlk, &c0, &c1, &c2)
fn mul(t: u64, rlk: &RLK, a: &Self, b: &Self) -> Self {
let (c0, c1, c2) = Self::tensor(t, a, b);
BFV::relinearize_204(&rlk, &c0, &c1, &c2)
} }
} }
// naive mul in the ring Rq, reusing the ring_n::naive_mul and then applying mod(X^N +1) // naive mul in the ring Rq, reusing the ring_n::naive_mul and then applying mod(X^N +1)
fn tmp_naive_mul<const Q: u64, const N: usize>(a: Rq<Q, N>, b: Rq<Q, N>) -> Rq<Q, N> {
Rq::<Q, N>::from_vec_i64(arith::ring_n::naive_mul(&a.to_r(), &b.to_r()))
fn tmp_naive_mul(a: Rq, b: Rq) -> Rq {
Rq::from_vec_i64(
&a.param.clone(),
arith::ring_n::naive_mul(&a.to_r(), &b.to_r()),
)
} }
impl<const Q: u64, const N: usize> ops::Add<RLWE<Q, N>> for RLWE<Q, N> {
impl ops::Add<RLWE> for RLWE {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
Self::add(self, rhs) Self::add(self, rhs)
} }
} }
impl<const Q: u64, const N: usize, const T: u64> ops::Add<&Rq<T, N>> for &RLWE<Q, N> {
type Output = RLWE<Q, N>;
fn add(self, rhs: &Rq<T, N>) -> Self::Output {
BFV::<Q, N, T>::add_const(self, rhs)
impl ops::Add<&Rq> for &RLWE {
type Output = RLWE;
fn add(self, rhs: &Rq) -> Self::Output {
BFV::add_const(self, rhs)
} }
} }
pub struct BFV<const Q: u64, const N: usize, const T: u64> {}
pub struct BFV {}
impl<const Q: u64, const N: usize, const T: u64> BFV<Q, N, T> {
const DELTA: u64 = Q / T; // floor
impl BFV {
// const DELTA: u64 = Q / T; // floor
/// generate a new key pair (privK, pubK) /// generate a new key pair (privK, pubK)
pub fn new_key(mut rng: impl Rng) -> Result<(SecretKey<Q, N>, PublicKey<Q, N>)> {
pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
// WIP: review probabilities // WIP: review probabilities
// let Xi_key = Uniform::new(-1_f64, 1_f64); // let Xi_key = Uniform::new(-1_f64, 1_f64);
@ -105,114 +125,135 @@ impl BFV {
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
// secret key // secret key
// let mut s = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::<Q, N>::rand_u64(&mut rng, Xi_key)?;
// let mut s = Rq::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::rand_u64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already // since s is going to be multiplied by other Rq elements, already
// compute its NTT // compute its NTT
s.compute_evals(); s.compute_evals();
// pk = (-a * s + e, a) // pk = (-a * s + e, a)
let a = Rq::<Q, N>::rand_u64(&mut rng, Uniform::new(0_u64, Q))?;
let e = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let pk: PublicKey<Q, N> = PublicKey((&(-a) * &s) + e, a.clone());
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, param.ring.q), &param.ring)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let pk: PublicKey = PublicKey(&(&(-a.clone()) * &s) + &e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk)) Ok((SecretKey(s), pk))
} }
pub fn encrypt(mut rng: impl Rng, pk: &PublicKey<Q, N>, m: &Rq<T, N>) -> Result<RLWE<Q, N>> {
// note: m is modulus t
pub fn encrypt(mut rng: impl Rng, param: &Param, pk: &PublicKey, m: &Rq) -> Result<RLWE> {
// assert param & inputs
debug_assert_eq!(param.ring, pk.0.param);
debug_assert_eq!(param.t, m.param.q);
debug_assert_eq!(param.ring.n, m.param.n);
let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_key = Uniform::new(-1_f64, 1_f64);
// let Xi_key = Uniform::new(0_u64, 2_u64); // let Xi_key = Uniform::new(0_u64, 2_u64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let u = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
// let u = Rq::<Q, N>::rand_u64(&mut rng, Xi_key)?;
let e_1 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let e_2 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let u = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// let u = Rq::rand_u64(&mut rng, Xi_key)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_2 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
// migrate m's coeffs to the bigger modulus Q (from T) // migrate m's coeffs to the bigger modulus Q (from T)
let m = m.remodule::<Q>();
let c0 = &pk.0 * &u + e_1 + m * Self::DELTA;
let m = m.remodule(param.ring.q);
let c0 = &pk.0 * &u + e_1 + m * (param.ring.q / param.t); // floor(q/t)=DELTA
let c1 = &pk.1 * &u + e_2; let c1 = &pk.1 * &u + e_2;
Ok(RLWE::<Q, N>(c0, c1))
Ok(RLWE(c0, c1))
} }
pub fn decrypt(sk: &SecretKey<Q, N>, c: &RLWE<Q, N>) -> Rq<T, N> {
let cs = c.0 + c.1 * sk.0; // done in mod q
pub fn decrypt(param: &Param, sk: &SecretKey, c: &RLWE) -> Rq {
debug_assert_eq!(param.ring, sk.0.param);
debug_assert_eq!(param.ring.q, c.0.param.q);
debug_assert_eq!(param.ring.n, c.0.param.n);
let cs: Rq = &c.0 + &(&c.1 * &sk.0); // done in mod q
// same but with naive_mul: // same but with naive_mul:
// let c1s = arith::ring_n::naive_mul(&c.1.to_r(), &sk.0.to_r()); // let c1s = arith::ring_n::naive_mul(&c.1.to_r(), &sk.0.to_r());
// let c1s = Rq::<Q, N>::from_vec_i64(c1s);
// let c1s = Rq::from_vec_i64(c1s);
// let cs = c.0 + c1s; // let cs = c.0 + c1s;
let r: Rq<Q, N> = cs.mul_div_round(T, Q);
r.remodule::<T>()
let r: Rq = cs.mul_div_round(param.t, param.ring.q);
r.remodule(param.t)
} }
fn add_const(c: &RLWE<Q, N>, m: &Rq<T, N>) -> RLWE<Q, N> {
fn add_const(c: &RLWE, m: &Rq) -> RLWE {
let q = c.0.param.q;
let t = m.param.q;
// assuming T<Q, move m from Zq<T> to Zq<Q> // assuming T<Q, move m from Zq<T> to Zq<Q>
let m = m.remodule::<Q>();
RLWE::<Q, N>(c.0 + m * Self::DELTA, c.1)
let m = m.remodule(c.0.param.q);
// TODO rm clones
RLWE(c.0.clone() + m * (q / t), c.1.clone()) // floor(q/t)=DELTA
} }
fn mul_const<const PQ: u64>(rlk: &RLK<PQ, N>, c: &RLWE<Q, N>, m: &Rq<T, N>) -> RLWE<Q, N> {
fn mul_const(rlk: &RLK, c: &RLWE, m: &Rq) -> RLWE {
// let pq = rlk.0.q;
let q = c.0.param.q;
let t = m.param.q;
// assuming T<Q, move m from Zq<T> to Zq<Q> // assuming T<Q, move m from Zq<T> to Zq<Q>
let m = m.remodule::<Q>();
let m = m.remodule(q);
// encrypt m*Delta without noise, and then perform normal ciphertext multiplication // encrypt m*Delta without noise, and then perform normal ciphertext multiplication
let md = RLWE::<Q, N>(m * Self::DELTA, Rq::zero());
RLWE::<Q, N>::mul::<PQ, T>(&rlk, &c, &md)
let md = RLWE(m * (q / t), Rq::zero(&c.0.param)); // floor(q/t)=DELTA
RLWE::mul(t, &rlk, &c, &md)
} }
fn rlk_key<const PQ: u64>(mut rng: impl Rng, s: &SecretKey<Q, N>) -> Result<RLK<PQ, N>> {
fn rlk_key(mut rng: impl Rng, param: &Param, s: &SecretKey) -> Result<RLK> {
let pq = param.p * param.ring.q;
let rlk_param = RingParam {
q: pq,
n: param.ring.n,
};
// TODO review using Xi' instead of Xi // TODO review using Xi' instead of Xi
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
// let Xi_err = Normal::new(0_f64, 0.0)?; // let Xi_err = Normal::new(0_f64, 0.0)?;
let s = s.0.remodule::<PQ>();
let a = Rq::<PQ, N>::rand_u64(&mut rng, Uniform::new(0_u64, PQ))?;
let e = Rq::<PQ, N>::rand_f64(&mut rng, Xi_err)?;
let P = PQ / Q;
let s = s.0.remodule(pq);
let a = Rq::rand_u64(&mut rng, Uniform::new(0_u64, pq), &rlk_param)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &rlk_param)?;
// let rlk: RLK<PQ, N> = RLK::<PQ, N>(-(&a * &s + e) + (s * s) * P, a.clone()); // let rlk: RLK<PQ, N> = RLK::<PQ, N>(-(&a * &s + e) + (s * s) * P, a.clone());
let rlk: RLK<PQ, N> = RLK::<PQ, N>(
-(tmp_naive_mul(a, s) + e) + tmp_naive_mul(s, s) * P,
// TODO rm clones
let rlk: RLK = RLK(
-(tmp_naive_mul(a.clone(), s.clone()) + e)
+ tmp_naive_mul(s.clone(), s.clone()) * param.p,
a.clone(), a.clone(),
); );
Ok(rlk) Ok(rlk)
} }
fn relinearize<const PQ: u64>(
rlk: &RLK<PQ, N>,
c0: &Rq<Q, N>,
c1: &Rq<Q, N>,
c2: &Rq<Q, N>,
) -> RLWE<Q, N> {
let P = PQ / Q;
fn relinearize(rlk: &RLK, c0: &Rq, c1: &Rq, c2: &Rq) -> RLWE {
let pq = rlk.0.param.q;
let param = c0.param;
let q = param.q;
let p = pq / q;
let c2rlk0: Vec<f64> = (c2.to_r() * rlk.0.to_r())
let c2rlk0: Vec<f64> = (c2.clone().to_r() * rlk.0.clone().to_r())
.coeffs() .coeffs()
.iter() .iter()
.map(|e| (*e as f64 / P as f64).round())
.map(|e| (*e as f64 / p as f64).round())
.collect(); .collect();
let c2rlk1: Vec<f64> = (c2.to_r() * rlk.1.to_r())
let c2rlk1: Vec<f64> = (c2.clone().to_r() * rlk.1.clone().to_r()) // TODO rm clones
.coeffs() .coeffs()
.iter() .iter()
.map(|e| (*e as f64 / P as f64).round())
.map(|e| (*e as f64 / p as f64).round())
.collect(); .collect();
let r0 = Rq::<Q, N>::from_vec_f64(c2rlk0);
let r1 = Rq::<Q, N>::from_vec_f64(c2rlk1);
let r0 = Rq::from_vec_f64(&param, c2rlk0);
let r1 = Rq::from_vec_f64(&param, c2rlk1);
let res = RLWE::<Q, N>(c0 + &r0, c1 + &r1);
let res = RLWE(c0 + &r0, c1 + &r1);
res res
} }
fn relinearize_204<const PQ: u64>(
rlk: &RLK<PQ, N>,
c0: &Rq<Q, N>,
c1: &Rq<Q, N>,
c2: &Rq<Q, N>,
) -> RLWE<Q, N> {
let P = PQ / Q;
fn relinearize_204(rlk: &RLK, c0: &Rq, c1: &Rq, c2: &Rq) -> RLWE {
let pq = rlk.0.param.q;
let q = c0.param.q;
let p = pq / q;
let n = c0.param.n;
// TODO (in debug) check that all Ns match
// let c2rlk0: Rq<PQ, N> = c2.remodule::<PQ>() * rlk.0.remodule::<PQ>(); // let c2rlk0: Rq<PQ, N> = c2.remodule::<PQ>() * rlk.0.remodule::<PQ>();
// let c2rlk1: Rq<PQ, N> = c2.remodule::<PQ>() * rlk.1.remodule::<PQ>(); // let c2rlk1: Rq<PQ, N> = c2.remodule::<PQ>() * rlk.1.remodule::<PQ>();
@ -220,12 +261,12 @@ impl BFV {
// let r1: Rq<Q, N> = c2rlk1.mul_div_round(1, P).remodule::<Q>(); // let r1: Rq<Q, N> = c2rlk1.mul_div_round(1, P).remodule::<Q>();
use arith::ring_n::naive_mul; use arith::ring_n::naive_mul;
let c2rlk0: Vec<i64> = naive_mul(&c2.to_r(), &rlk.0.to_r());
let c2rlk1: Vec<i64> = naive_mul(&c2.to_r(), &rlk.1.to_r());
let r0: Rq<Q, N> = arith::ring_n::mul_div_round::<Q, N>(c2rlk0, 1, P);
let r1: Rq<Q, N> = arith::ring_n::mul_div_round::<Q, N>(c2rlk1, 1, P);
let c2rlk0: Vec<i64> = naive_mul(&c2.clone().to_r(), &rlk.0.clone().to_r()); // TODO rm clones
let c2rlk1: Vec<i64> = naive_mul(&c2.clone().to_r(), &rlk.1.clone().to_r());
let r0: Rq = arith::ring_n::mul_div_round(q, n, c2rlk0, 1, p);
let r1: Rq = arith::ring_n::mul_div_round(q, n, c2rlk1, 1, p);
let res = RLWE::<Q, N>(c0 + &r0, c1 + &r1);
let res = RLWE(c0 + &r0, c1 + &r1);
res res
} }
} }
@ -239,21 +280,25 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 512;
const T: u64 = 32; // plaintext modulus
type S = BFV<Q, N, T>;
let param = Param {
ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 512,
},
t: 32, // plaintext modulus
p: 0, // unused in this test
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..100 { for _ in 0..100 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c = S::encrypt(&mut rng, &pk, &m)?;
let m_recovered = S::decrypt(&sk, &c);
let c = BFV::encrypt(&mut rng, &param, &pk, &m)?;
let m_recovered = BFV::decrypt(&param, &sk, &c);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@ -263,26 +308,30 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 128;
const T: u64 = 32; // plaintext modulus
type S = BFV<Q, N, T>;
let param = Param {
ring: RingParam {
q: 2u64.pow(16) + 1, // q prime, and 2^q + 1 shape
n: 128,
},
t: 32, // plaintext modulus
p: 0, // unused in this test
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..100 { for _ in 0..100 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, T);
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = S::encrypt(&mut rng, &pk, &m1)?;
let c2 = S::encrypt(&mut rng, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let m3_recovered = S::decrypt(&sk, &c3);
let m3_recovered = BFV::decrypt(&param, &sk, &c3);
assert_eq!(m1 + m2, m3_recovered); assert_eq!(m1 + m2, m3_recovered);
} }
@ -292,211 +341,208 @@ mod tests {
#[test] #[test]
fn test_constant_add_mul() -> Result<()> { fn test_constant_add_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 8; // plaintext modulus
type S = BFV<Q, N, T>;
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let param = Param {
ring: RingParam { q, n: 16 },
t: 8, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let msg_dist = Uniform::new(0_u64, T);
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2_const = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let c1 = S::encrypt(&mut rng, &pk, &m1)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2_const = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c3_add = &c1 + &m2_const; let c3_add = &c1 + &m2_const;
let m3_add_recovered = S::decrypt(&sk, &c3_add);
assert_eq!(m1 + m2_const, m3_add_recovered);
let m3_add_recovered = BFV::decrypt(&param, &sk, &c3_add);
assert_eq!(&m1 + &m2_const, m3_add_recovered);
// test multiplication of a ciphertext by a constant // test multiplication of a ciphertext by a constant
const P: u64 = Q * Q;
const PQ: u64 = P * Q;
let rlk = BFV::<Q, N, T>::rlk_key::<PQ>(&mut rng, &sk)?;
let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c3_mul = S::mul_const(&rlk, &c1, &m2_const);
let c3_mul = BFV::mul_const(&rlk, &c1, &m2_const);
let m3_mul_recovered = S::decrypt(&sk, &c3_mul);
let m3_mul_recovered = BFV::decrypt(&param, &sk, &c3_mul);
assert_eq!( assert_eq!(
(m1.to_r() * m2_const.to_r()).to_rq::<T>().coeffs(),
(m1.to_r() * m2_const.to_r()).to_rq(param.t).coeffs(),
m3_mul_recovered.coeffs() m3_mul_recovered.coeffs()
); );
Ok(()) Ok(())
} }
// TMP WIP
#[test]
#[ignore]
fn test_params() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
const N: usize = 32;
const T: u64 = 8; // plaintext modulus
const P: u64 = Q * Q;
const PQ: u64 = P * Q;
const DELTA: u64 = Q / T; // floor
let mut rng = rand::thread_rng();
let Xi_key = Uniform::new(0_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let s = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let e = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let u = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let e_0 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let e_1 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let m = Rq::<Q, N>::rand_u64(&mut rng, Uniform::new(0_u64, T))?;
// v_fresh
let v: Rq<Q, N> = u * e + e_1 * s + e_0;
let q: f64 = Q as f64;
let t: f64 = T as f64;
let n: f64 = N as f64;
let delta: f64 = DELTA as f64;
// r_t(q)/t should be equal to q/t-Δ
assert_eq!(
// r_t(q)/t, where r_t(q)=q mod t
(q % t) / t,
// Δt/Q = q - r_t(Q)/Q, so r_t(Q)=q - Δt
(q / t) - delta
);
let rt: f64 = (q % t) / t;
dbg!(&rt);
dbg!(v.infinity_norm());
let bound: f64 = (q / (2_f64 * t)) - (rt / 2_f64);
dbg!(bound);
assert!((v.infinity_norm() as f64) < bound);
let max_v_infnorm = bound - 1.0;
// addition noise
let v_add: Rq<Q, N> = v + v + u * rt;
let v_add: Rq<Q, N> = v_add + v_add + u * rt;
assert!((v_add.infinity_norm() as f64) < bound);
// multiplication noise
let (_, pk) = BFV::<Q, N, T>::new_key(&mut rng)?;
let c = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m.remodule::<T>())?;
let b_key: f64 = 1_f64;
// ef: expansion factor
let ef: f64 = 2.0 * n.sqrt();
let bound: f64 = ((ef * t) / 2.0)
* ((2.0 * max_v_infnorm * max_v_infnorm) / q
+ (4.0 + ef * b_key) * (max_v_infnorm + max_v_infnorm)
+ rt * (ef * b_key + 5.0))
+ (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0;
dbg!(&bound);
let k: Vec<f64> = (c.0 + c.1 * s - m * delta - v)
.coeffs()
.iter()
.map(|e_i| e_i.0 as f64 / q)
.collect();
let k = Rq::<Q, N>::from_vec_f64(k);
let v_tensor_0 = (v * v)
.coeffs()
.iter()
.map(|e_i| (e_i.0 as f64 * t) / q)
.collect::<Vec<f64>>();
let v_tensor_0 = Rq::<Q, N>::from_vec_f64(v_tensor_0);
let v_tensor_1 = ((m * v) + (m * v))
.coeffs()
.iter()
.map(|e_i| (e_i.0 as f64 * t * delta) / q)
.collect::<Vec<f64>>();
let v_tensor_1 = Rq::<Q, N>::from_vec_f64(v_tensor_1);
let v_tensor_2: Rq<Q, N> = (v * k + v * k) * t;
let rm: f64 = (ef * t) / 2.0;
let rm: Rq<Q, N> = Rq::<Q, N>::from_vec_f64(vec![rm; N]);
let v_tensor_3: Rq<Q, N> = (m * k
+ m * k
+ rm
+ Rq::from_vec_f64(
((m * m) * DELTA)
.coeffs()
.iter()
.map(|e_i| e_i.0 as f64 / q)
.collect::<Vec<f64>>(),
))
* rt;
let v_tensor = v_tensor_0 + v_tensor_1 + v_tensor_2 - v_tensor_3;
let v_r = (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0;
let v_mult_norm = v_tensor.infinity_norm() as f64 + v_r;
dbg!(&v_mult_norm);
dbg!(&bound);
assert!(v_mult_norm < bound);
// let m1 = Rq::<T, N>::zero();
// let m2 = Rq::<T, N>::zero();
// let (_, pk) = BFV::<Q, N, T>::new_key(&mut rng)?;
// let c1 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m1)?;
// let c2 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m2)?;
// let (c_a, c_b, c_c) = RLWE::<Q, N>::tensor::<PQ, T>(&c1, &c2);
// dbg!(&c_a.infinity_norm());
// dbg!(&c_b.infinity_norm());
// dbg!(&c_c.infinity_norm());
// assert!((c_a.infinity_norm() as f64) < bound);
// assert!((c_b.infinity_norm() as f64) < bound);
// assert!((c_c.infinity_norm() as f64) < bound);
// WIP
Ok(())
}
/*
// TMP WIP
#[test]
#[ignore]
fn test_param() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
const N: usize = 32;
const T: u64 = 8; // plaintext modulus
const P: u64 = Q * Q;
const PQ: u64 = P * Q;
const DELTA: u64 = Q / T; // floor
let mut rng = rand::thread_rng();
let Xi_key = Uniform::new(0_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let s = Rq::rand_f64(&mut rng, Xi_key)?;
let e = Rq::rand_f64(&mut rng, Xi_err)?;
let u = Rq::rand_f64(&mut rng, Xi_key)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err)?;
let m = Rq::rand_u64(&mut rng, Uniform::new(0_u64, T))?;
// v_fresh
let v: Rq<Q, N> = u * e + e_1 * s + e_0;
let q: f64 = Q as f64;
let t: f64 = T as f64;
let n: f64 = N as f64;
let delta: f64 = DELTA as f64;
// r_t(q)/t should be equal to q/t-Δ
assert_eq!(
// r_t(q)/t, where r_t(q)=q mod t
(q % t) / t,
// Δt/Q = q - r_t(Q)/Q, so r_t(Q)=q - Δt
(q / t) - delta
);
let rt: f64 = (q % t) / t;
dbg!(&rt);
dbg!(v.infinity_norm());
let bound: f64 = (q / (2_f64 * t)) - (rt / 2_f64);
dbg!(bound);
assert!((v.infinity_norm() as f64) < bound);
let max_v_infnorm = bound - 1.0;
// addition noise
let v_add: Rq<Q, N> = v + v + u * rt;
let v_add: Rq<Q, N> = v_add + v_add + u * rt;
assert!((v_add.infinity_norm() as f64) < bound);
// multiplication noise
let (_, pk) = BFV::<Q, N, T>::new_key(&mut rng)?;
let c = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m.remodule::<T>())?;
let b_key: f64 = 1_f64;
// ef: expansion factor
let ef: f64 = 2.0 * n.sqrt();
let bound: f64 = ((ef * t) / 2.0)
* ((2.0 * max_v_infnorm * max_v_infnorm) / q
+ (4.0 + ef * b_key) * (max_v_infnorm + max_v_infnorm)
+ rt * (ef * b_key + 5.0))
+ (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0;
dbg!(&bound);
let k: Vec<f64> = (c.0 + c.1 * s - m * delta - v)
.coeffs()
.iter()
.map(|e_i| e_i.0 as f64 / q)
.collect();
let k = Rq::from_vec_f64(k);
let v_tensor_0 = (v * v)
.coeffs()
.iter()
.map(|e_i| (e_i.0 as f64 * t) / q)
.collect::<Vec<f64>>();
let v_tensor_0 = Rq::from_vec_f64(v_tensor_0);
let v_tensor_1 = ((m * v) + (m * v))
.coeffs()
.iter()
.map(|e_i| (e_i.0 as f64 * t * delta) / q)
.collect::<Vec<f64>>();
let v_tensor_1 = Rq::from_vec_f64(v_tensor_1);
let v_tensor_2: Rq<Q, N> = (v * k + v * k) * t;
let rm: f64 = (ef * t) / 2.0;
let rm: Rq<Q, N> = Rq::from_vec_f64(vec![rm; N]);
let v_tensor_3: Rq<Q, N> = (m * k
+ m * k
+ rm
+ Rq::from_vec_f64(
((m * m) * DELTA)
.coeffs()
.iter()
.map(|e_i| e_i.0 as f64 / q)
.collect::<Vec<f64>>(),
))
* rt;
let v_tensor = v_tensor_0 + v_tensor_1 + v_tensor_2 - v_tensor_3;
let v_r = (1.0 + ef * b_key + ef * ef * b_key * b_key) / 2.0;
let v_mult_norm = v_tensor.infinity_norm() as f64 + v_r;
dbg!(&v_mult_norm);
dbg!(&bound);
assert!(v_mult_norm < bound);
// let m1 = Rq::<T, N>::zero();
// let m2 = Rq::<T, N>::zero();
// let (_, pk) = BFV::<Q, N, T>::new_key(&mut rng)?;
// let c1 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m1)?;
// let c2 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m2)?;
// let (c_a, c_b, c_c) = RLWE::tensor::<PQ, T>(&c1, &c2);
// dbg!(&c_a.infinity_norm());
// dbg!(&c_b.infinity_norm());
// dbg!(&c_c.infinity_norm());
// assert!((c_a.infinity_norm() as f64) < bound);
// assert!((c_b.infinity_norm() as f64) < bound);
// assert!((c_c.infinity_norm() as f64) < bound);
// WIP
Ok(())
}
*/
#[test] #[test]
fn test_tensor() -> Result<()> { fn test_tensor() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
const N: usize = 16;
const T: u64 = 2; // plaintext modulus
// const P: u64 = Q;
const P: u64 = Q * Q;
const PQ: u64 = P * Q;
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let param = Param {
ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 { for _ in 0..1_000 {
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_tensor_opt::<Q, N, T, PQ>(&mut rng, m1, m2)?;
test_tensor_opt(&mut rng, &param, m1, m2)?;
} }
Ok(()) Ok(())
} }
fn test_tensor_opt<const Q: u64, const N: usize, const T: u64, const PQ: u64>(
mut rng: impl Rng,
m1: Rq<T, N>,
m2: Rq<T, N>,
) -> Result<()> {
let (sk, pk) = BFV::<Q, N, T>::new_key(&mut rng)?;
fn test_tensor_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let c1 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m1)?;
let c2 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let (c_a, c_b, c_c) = RLWE::<Q, N>::tensor::<PQ, T>(&c1, &c2);
// let (c_a, c_b, c_c) = RLWE::<Q, N>::tensor_new::<PQ, T>(&c1, &c2);
let (c_a, c_b, c_c) = RLWE::tensor(param.t, &c1, &c2);
// let (c_a, c_b, c_c) = RLWE::tensor_new::<PQ, T>(&c1, &c2);
// decrypt non-relinearized mul result // decrypt non-relinearized mul result
let m3: Rq<Q, N> = c_a + c_b * sk.0 + c_c * sk.0 * sk.0;
let m3: Rq = c_a + &c_b * &sk.0 + &c_c * &(&sk.0 * &sk.0);
// let m3: Rq<Q, N> = c_a // let m3: Rq<Q, N> = c_a
// + Rq::<Q, N>::from_vec_i64(arith::ring_n::naive_mul(&c_b.to_r(), &sk.0.to_r()))
// + Rq::<Q, N>::from_vec_i64(arith::ring_n::naive_mul(
// + Rq::from_vec_i64(arith::ring_n::naive_mul(&c_b.to_r(), &sk.0.to_r()))
// + Rq::from_vec_i64(arith::ring_n::naive_mul(
// &c_c.to_r(), // &c_c.to_r(),
// &R::<N>::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())), // &R::<N>::from_vec(arith::ring_n::naive_mul(&sk.0.to_r(), &sk.0.to_r())),
// )); // ));
let m3: Rq<Q, N> = m3.mul_div_round(T, Q); // descale
let m3 = m3.remodule::<T>();
let m3: Rq = m3.mul_div_round(param.t, param.ring.q); // descale
let m3 = m3.remodule(param.t);
let naive = (m1.to_r() * m2.to_r()).to_rq::<T>();
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!( assert_eq!(
m3.coeffs().to_vec(), m3.coeffs().to_vec(),
naive.coeffs().to_vec(), naive.coeffs().to_vec(),
@ -510,44 +556,39 @@ mod tests {
#[test] #[test]
fn test_mul_relin() -> Result<()> { fn test_mul_relin() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 2; // plaintext modulus
type S = BFV<Q, N, T>;
const P: u64 = Q * Q;
const PQ: u64 = P * Q;
let q: u64 = 2u64.pow(16) + 1; // q prime, and 2^q + 1 shape
let param = Param {
ring: RingParam { q, n: 16 },
t: 2, // plaintext modulus
p: q * q,
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..1_000 { for _ in 0..1_000 {
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
test_mul_relin_opt::<Q, N, T, PQ>(&mut rng, m1, m2)?;
test_mul_relin_opt(&mut rng, &param, m1, m2)?;
} }
Ok(()) Ok(())
} }
fn test_mul_relin_opt<const Q: u64, const N: usize, const T: u64, const PQ: u64>(
mut rng: impl Rng,
m1: Rq<T, N>,
m2: Rq<T, N>,
) -> Result<()> {
let (sk, pk) = BFV::<Q, N, T>::new_key(&mut rng)?;
fn test_mul_relin_opt(mut rng: impl Rng, param: &Param, m1: Rq, m2: Rq) -> Result<()> {
let (sk, pk) = BFV::new_key(&mut rng, &param)?;
let rlk = BFV::<Q, N, T>::rlk_key::<PQ>(&mut rng, &sk)?;
let rlk = BFV::rlk_key(&mut rng, &param, &sk)?;
let c1 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m1)?;
let c2 = BFV::<Q, N, T>::encrypt(&mut rng, &pk, &m2)?;
let c1 = BFV::encrypt(&mut rng, &param, &pk, &m1)?;
let c2 = BFV::encrypt(&mut rng, &param, &pk, &m2)?;
let c3 = RLWE::<Q, N>::mul::<PQ, T>(&rlk, &c1, &c2); // uses relinearize internally
let c3 = RLWE::mul(param.t, &rlk, &c1, &c2); // uses relinearize internally
let m3 = BFV::<Q, N, T>::decrypt(&sk, &c3);
let m3 = BFV::decrypt(&param, &sk, &c3);
let naive = (m1.to_r() * m2.to_r()).to_rq::<T>();
let naive = (m1.clone().to_r() * m2.clone().to_r()).to_rq(param.t); // TODO rm clones
assert_eq!( assert_eq!(
m3.coeffs().to_vec(), m3.coeffs().to_vec(),
naive.coeffs().to_vec(), naive.coeffs().to_vec(),

+ 17
- 14
ckks/src/encoder.rs

@ -1,14 +1,15 @@
use anyhow::Result; use anyhow::Result;
use arith::{Matrix, Ring, Rq, C, R};
use arith::{Matrix, Rq, C, R};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SecretKey<const Q: u64, const N: usize>(Rq<Q, N>);
pub struct SecretKey(Rq);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PublicKey<const Q: u64, const N: usize>(Rq<Q, N>, Rq<Q, N>);
pub struct PublicKey(Rq, Rq);
pub struct Encoder<const Q: u64, const N: usize> {
pub struct Encoder {
n: usize,
scale_factor: C<f64>, // Δ (delta) scale_factor: C<f64>, // Δ (delta)
primitive: C<f64>, primitive: C<f64>,
basis: Matrix<C<f64>>, basis: Matrix<C<f64>>,
@ -34,13 +35,14 @@ fn vandermonde(n: usize, w: C) -> Matrix> {
} }
Matrix::<C<f64>>(v) Matrix::<C<f64>>(v)
} }
impl<const Q: u64, const N: usize> Encoder<Q, N> {
pub fn new(scale_factor: C<f64>) -> Self {
let primitive: C<f64> = primitive_root_of_unity(2 * N);
let basis = vandermonde(N, primitive);
impl Encoder {
pub fn new(n: usize, scale_factor: C<f64>) -> Self {
let primitive: C<f64> = primitive_root_of_unity(2 * n);
let basis = vandermonde(n, primitive);
let basis_t = basis.transpose(); let basis_t = basis.transpose();
Self { Self {
n,
scale_factor, scale_factor,
primitive, primitive,
basis, basis,
@ -52,7 +54,7 @@ impl Encoder {
/// from $\mathbb{C}^{N/2} \longrightarrow \mathbb{Z_q}[X]/(X^N +1) = R$ /// from $\mathbb{C}^{N/2} \longrightarrow \mathbb{Z_q}[X]/(X^N +1) = R$
// TODO use alg.1 from 2018-1043, // TODO use alg.1 from 2018-1043,
// or as in 2018-1073: $f(x) = 1N (U^T.conj() m + U^T m.conj())$ // or as in 2018-1073: $f(x) = 1N (U^T.conj() m + U^T m.conj())$
pub fn encode(&self, z: &[C<f64>]) -> Result<R<N>> {
pub fn encode(&self, z: &[C<f64>]) -> Result<R> {
// $pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H}$ // $pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H}$
let expanded = self.pi_inv(z); let expanded = self.pi_inv(z);
@ -93,10 +95,10 @@ impl Encoder {
// TMP: naive round, maybe do gaussian // TMP: naive round, maybe do gaussian
let coeffs = r.iter().map(|e| e.re.round() as i64).collect::<Vec<i64>>(); let coeffs = r.iter().map(|e| e.re.round() as i64).collect::<Vec<i64>>();
Ok(R::from_vec(coeffs))
Ok(R::from_vec(self.n, coeffs))
} }
pub fn decode(&self, p: &R<N>) -> Result<Vec<C<f64>>> {
pub fn decode(&self, p: &R) -> Result<Vec<C<f64>>> {
let p: Vec<C<f64>> = p let p: Vec<C<f64>> = p
.coeffs() .coeffs()
.iter() .iter()
@ -110,7 +112,7 @@ impl Encoder {
/// pi: \mathbb{H} \longrightarrow \mathbb{C}^{N/2} /// pi: \mathbb{H} \longrightarrow \mathbb{C}^{N/2}
fn pi(&self, z: &[C<f64>]) -> Vec<C<f64>> { fn pi(&self, z: &[C<f64>]) -> Vec<C<f64>> {
z[..N / 2].to_vec()
z[..self.n / 2].to_vec()
} }
/// pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H} /// pi^{-1}: \mathbb{C}^{N/2} \longrightarrow \mathbb{H}
fn pi_inv(&self, z: &[C<f64>]) -> Vec<C<f64>> { fn pi_inv(&self, z: &[C<f64>]) -> Vec<C<f64>> {
@ -154,6 +156,7 @@ mod tests {
fn test_encode_decode() -> Result<()> { fn test_encode_decode() -> Result<()> {
const Q: u64 = 1024; const Q: u64 = 1024;
const N: usize = 32; const N: usize = 32;
let n: usize = 32;
let T = 128; // WIP let T = 128; // WIP
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
@ -166,9 +169,9 @@ mod tests {
.collect(); .collect();
let delta = C::<f64>::new(64.0, 0.0); // delta = scaling factor let delta = C::<f64>::new(64.0, 0.0); // delta = scaling factor
let encoder = Encoder::<Q, N>::new(delta);
let encoder = Encoder::new(n, delta);
let m: R<N> = encoder.encode(&z)?; // polynomial (encoded vec) \in R
let m: R = encoder.encode(&z)?; // polynomial (encoded vec) \in R
let z_decoded = encoder.decode(&m)?; let z_decoded = encoder.decode(&m)?;

+ 98
- 79
ckks/src/lib.rs

@ -5,7 +5,7 @@
#![allow(clippy::upper_case_acronyms)] #![allow(clippy::upper_case_acronyms)]
#![allow(dead_code)] // TMP #![allow(dead_code)] // TMP
use arith::{Rq, C, R};
use arith::{RingParam, Rq, C, R};
use anyhow::Result; use anyhow::Result;
use rand::Rng; use rand::Rng;
@ -18,35 +18,47 @@ pub use encoder::Encoder;
// sigma=3.2 from: https://eprint.iacr.org/2016/421.pdf page 17 // sigma=3.2 from: https://eprint.iacr.org/2016/421.pdf page 17
const ERR_SIGMA: f64 = 3.2; const ERR_SIGMA: f64 = 3.2;
#[derive(Clone, Copy, Debug)]
pub struct Param {
ring: RingParam,
t: u64,
}
#[derive(Debug)] #[derive(Debug)]
pub struct PublicKey<const Q: u64, const N: usize>(Rq<Q, N>, Rq<Q, N>);
pub struct PublicKey(Rq, Rq);
pub struct SecretKey<const Q: u64, const N: usize>(Rq<Q, N>);
pub struct SecretKey(Rq);
pub struct CKKS<const Q: u64, const N: usize> {
encoder: Encoder<Q, N>,
pub struct CKKS {
param: Param,
encoder: Encoder,
} }
impl<const Q: u64, const N: usize> CKKS<Q, N> {
pub fn new(delta: C<f64>) -> Self {
let encoder = Encoder::<Q, N>::new(delta);
Self { encoder }
impl CKKS {
pub fn new(param: &Param, delta: C<f64>) -> Self {
let encoder = Encoder::new(param.ring.n, delta);
Self {
param: param.clone(),
encoder,
}
} }
/// generate a new key pair (privK, pubK) /// generate a new key pair (privK, pubK)
pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey<Q, N>, PublicKey<Q, N>)> {
pub fn new_key(&self, mut rng: impl Rng) -> Result<(SecretKey, PublicKey)> {
let param = &self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let e = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let mut s = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let mut s = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
// since s is going to be multiplied by other Rq elements, already // since s is going to be multiplied by other Rq elements, already
// compute its NTT // compute its NTT
s.compute_evals(); s.compute_evals();
let a = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let a = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
let pk: PublicKey<Q, N> = PublicKey((&(-a) * &s) + e, a.clone());
let pk: PublicKey = PublicKey((&(-a.clone()) * &s) + e, a.clone()); // TODO rm clones
Ok((SecretKey(s), pk)) Ok((SecretKey(s), pk))
} }
@ -54,64 +66,54 @@ impl CKKS {
fn encrypt( fn encrypt(
&self, // TODO maybe rm? &self, // TODO maybe rm?
mut rng: impl Rng, mut rng: impl Rng,
pk: &PublicKey<Q, N>,
m: &R<N>,
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
pk: &PublicKey,
m: &R,
) -> Result<(Rq, Rq)> {
let param = self.param;
let Xi_key = Uniform::new(-1_f64, 1_f64); let Xi_key = Uniform::new(-1_f64, 1_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?; let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let e_0 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let e_1 = Rq::<Q, N>::rand_f64(&mut rng, Xi_err)?;
let e_0 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let e_1 = Rq::rand_f64(&mut rng, Xi_err, &param.ring)?;
let v = Rq::<Q, N>::rand_f64(&mut rng, Xi_key)?;
let v = Rq::rand_f64(&mut rng, Xi_key, &param.ring)?;
let m: Rq<Q, N> = Rq::<Q, N>::from(*m);
// let m: Rq = Rq::from(*m);
let m: Rq = m.clone().to_rq(param.ring.q); // TODO rm clone
Ok((m + e_0 + v * pk.0.clone(), v * pk.1.clone() + e_1))
Ok((m + e_0 + &v * &pk.0.clone(), &v * &pk.1 + e_1))
} }
fn decrypt( fn decrypt(
&self, // TODO maybe rm? &self, // TODO maybe rm?
sk: &SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<R<N>> {
let m = c.0.clone() + c.1 * sk.0;
sk: &SecretKey,
c: (Rq, Rq),
) -> Result<R> {
let m = c.0.clone() + &c.1 * &sk.0;
Ok(m.mod_centered_q()) Ok(m.mod_centered_q())
} }
pub fn encode_and_encrypt( pub fn encode_and_encrypt(
&self, &self,
mut rng: impl Rng, mut rng: impl Rng,
pk: &PublicKey<Q, N>,
pk: &PublicKey,
z: &[C<f64>], z: &[C<f64>],
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
let m: R<N> = self.encoder.encode(&z)?; // polynomial (encoded vec) \in R
) -> Result<(Rq, Rq)> {
let m: R = self.encoder.encode(&z)?; // polynomial (encoded vec) \in R
self.encrypt(&mut rng, pk, &m) self.encrypt(&mut rng, pk, &m)
} }
pub fn decrypt_and_decode(
&self,
sk: SecretKey<Q, N>,
c: (Rq<Q, N>, Rq<Q, N>),
) -> Result<Vec<C<f64>>> {
pub fn decrypt_and_decode(&self, sk: SecretKey, c: (Rq, Rq)) -> Result<Vec<C<f64>>> {
let d = self.decrypt(&sk, c)?; let d = self.decrypt(&sk, c)?;
self.encoder.decode(&d) self.encoder.decode(&d)
} }
pub fn add(
&self,
c0: &(Rq<Q, N>, Rq<Q, N>),
c1: &(Rq<Q, N>, Rq<Q, N>),
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
pub fn add(&self, c0: &(Rq, Rq), c1: &(Rq, Rq)) -> Result<(Rq, Rq)> {
Ok((&c0.0 + &c1.0, &c0.1 + &c1.1)) Ok((&c0.0 + &c1.0, &c0.1 + &c1.1))
} }
pub fn sub(
&self,
c0: &(Rq<Q, N>, Rq<Q, N>),
c1: &(Rq<Q, N>, Rq<Q, N>),
) -> Result<(Rq<Q, N>, Rq<Q, N>)> {
pub fn sub(&self, c0: &(Rq, Rq), c1: &(Rq, Rq)) -> Result<(Rq, Rq)> {
Ok((&c0.0 - &c1.0, &c0.1 + &c1.1)) Ok((&c0.0 - &c1.0, &c0.1 + &c1.1))
} }
} }
@ -122,21 +124,26 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 32;
const T: u64 = 50;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 32;
let t: u64 = 50;
let param = Param {
ring: RingParam { q, n },
t,
};
let scale_factor_u64 = 512_u64; // delta let scale_factor_u64 = 512_u64; // delta
let scale_factor = C::<f64>::new(scale_factor_u64 as f64, 0.0); // delta let scale_factor = C::<f64>::new(scale_factor_u64 as f64, 0.0); // delta
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
let m_raw: R<N> = Rq::<Q, N>::rand_f64(&mut rng, Uniform::new(0_f64, T as f64))?.to_r();
let m = m_raw * scale_factor_u64;
let m_raw: R =
Rq::rand_f64(&mut rng, Uniform::new(0_f64, t as f64), &param.ring)?.to_r();
let m = &m_raw * &scale_factor_u64;
let ct = ckks.encrypt(&mut rng, &pk, &m)?; let ct = ckks.encrypt(&mut rng, &pk, &m)?;
let m_decrypted = ckks.decrypt(&sk, ct)?; let m_decrypted = ckks.decrypt(&sk, ct)?;
@ -146,8 +153,8 @@ mod tests {
.iter() .iter()
.map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64) .map(|e| (*e as f64 / (scale_factor_u64 as f64)).round() as u64)
.collect(); .collect();
let m_decrypted = Rq::<Q, N>::from_vec_u64(m_decrypted);
assert_eq!(m_decrypted, Rq::<Q, N>::from(m_raw));
let m_decrypted = Rq::from_vec_u64(&param.ring, m_decrypted);
assert_eq!(m_decrypted, m_raw.to_rq(q));
} }
Ok(()) Ok(())
@ -155,21 +162,25 @@ mod tests {
#[test] #[test]
fn test_encode_encrypt_decrypt_decode() -> Result<()> { fn test_encode_encrypt_decrypt_decode() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 8;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 8;
let param = Param {
ring: RingParam { q, n },
t,
};
let scale_factor = C::<f64>::new(512.0, 0.0); // delta let scale_factor = C::<f64>::new(512.0, 0.0); // delta
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
let z: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
.take(n / 2)
.collect(); .collect();
let m: R<N> = ckks.encoder.encode(&z)?;
let m: R = ckks.encoder.encode(&z)?;
println!("{}", m); println!("{}", m);
// sanity check // sanity check
@ -200,26 +211,30 @@ mod tests {
#[test] #[test]
fn test_add() -> Result<()> { fn test_add() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 8;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 8;
let param = Param {
ring: RingParam { q, n },
t,
};
let scale_factor = C::<f64>::new(1024.0, 0.0); // delta let scale_factor = C::<f64>::new(1024.0, 0.0); // delta
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
.take(n / 2)
.collect(); .collect();
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
.take(n / 2)
.collect(); .collect();
let m0: R<N> = ckks.encoder.encode(&z0)?;
let m1: R<N> = ckks.encoder.encode(&z1)?;
let m0: R = ckks.encoder.encode(&z0)?;
let m1: R = ckks.encoder.encode(&z1)?;
let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?; let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?;
let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?; let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?;
@ -243,26 +258,30 @@ mod tests {
#[test] #[test]
fn test_sub() -> Result<()> { fn test_sub() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 8;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 16;
let t: u64 = 2;
let param = Param {
ring: RingParam { q, n },
t,
};
let scale_factor = C::<f64>::new(1024.0, 0.0); // delta let scale_factor = C::<f64>::new(1024.0, 0.0); // delta
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let ckks = CKKS::<Q, N>::new(scale_factor);
let ckks = CKKS::new(&param, scale_factor);
let (sk, pk) = ckks.new_key(&mut rng)?; let (sk, pk) = ckks.new_key(&mut rng)?;
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
let z0: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
.take(n / 2)
.collect(); .collect();
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, T))
.take(N / 2)
let z1: Vec<C<f64>> = std::iter::repeat_with(|| C::<f64>::rand(&mut rng, t))
.take(n / 2)
.collect(); .collect();
let m0: R<N> = ckks.encoder.encode(&z0)?;
let m1: R<N> = ckks.encoder.encode(&z1)?;
let m0: R = ckks.encoder.encode(&z0)?;
let m1: R = ckks.encoder.encode(&z1)?;
let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?; let ct0 = ckks.encrypt(&mut rng, &pk, &m0)?;
let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?; let ct1 = ckks.encrypt(&mut rng, &pk, &m1)?;

+ 51
- 34
gfhe/src/glev.rs

@ -1,28 +1,33 @@
use anyhow::Result; use anyhow::Result;
use itertools::zip_eq; use itertools::zip_eq;
use rand::Rng; use rand::Rng;
use rand_distr::{Normal, Uniform};
use std::ops::{Add, Mul};
use std::ops::Mul;
use arith::{Ring, TR};
use arith::Ring;
use crate::glwe::{PublicKey, SecretKey, GLWE};
use crate::glwe::{Param, PublicKey, SecretKey, GLWE};
// l GLWEs // l GLWEs
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GLev<R: Ring, const K: usize>(pub(crate) Vec<GLWE<R, K>>);
pub struct GLev<R: Ring>(pub(crate) Vec<GLWE<R>>);
impl<R: Ring, const K: usize> GLev<R, K> {
impl<R: Ring> GLev<R> {
pub fn encrypt( pub fn encrypt(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
pk: &PublicKey<R, K>,
pk: &PublicKey<R>,
m: &R, m: &R,
) -> Result<Self> { ) -> Result<Self> {
let glev: Vec<GLWE<R, K>> = (0..l)
let glev: Vec<GLWE<R>> = (0..l)
.map(|i| { .map(|i| {
GLWE::<R, K>::encrypt(&mut rng, pk, &(*m * (R::Q / beta.pow(i as u32) as u64)))
GLWE::<R>::encrypt(
&mut rng,
param,
pk,
&(m.clone() * (param.ring.q / beta.pow(i as u32) as u64)),
)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -30,38 +35,46 @@ impl GLev {
} }
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<R, K>,
sk: &SecretKey<R>,
m: &R, m: &R,
// delta: u64,
) -> Result<Self> { ) -> Result<Self> {
let glev: Vec<GLWE<R, K>> = (1..l + 1)
let glev: Vec<GLWE<R>> = (1..l + 1)
.map(|i| { .map(|i| {
GLWE::<R, K>::encrypt_s(&mut rng, sk, &(*m * (R::Q / beta.pow(i as u32) as u64)))
GLWE::<R>::encrypt_s(
&mut rng,
param,
sk,
&(m.clone() * (param.ring.q / beta.pow(i as u32) as u64)), // TODO rm clone
)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(Self(glev)) Ok(Self(glev))
} }
pub fn decrypt<const T: u64>(&self, sk: &SecretKey<R, K>, beta: u32) -> R {
pub fn decrypt(&self, param: &Param, sk: &SecretKey<R>, beta: u32) -> R {
let pt = self.0[1].decrypt(sk); let pt = self.0[1].decrypt(sk);
pt.mul_div_round(beta as u64, R::Q)
pt.mul_div_round(beta as u64, param.ring.q)
} }
} }
// dot product between a GLev and Vec<R>. // dot product between a GLev and Vec<R>.
// Used for operating decompositions with KSK_i. // Used for operating decompositions with KSK_i.
// GLev * Vec<R> --> GLWE // GLev * Vec<R> --> GLWE
impl<R: Ring, const K: usize> Mul<Vec<R>> for GLev<R, K> {
type Output = GLWE<R, K>;
fn mul(self, v: Vec<R>) -> GLWE<R, K> {
impl<R: Ring> Mul<Vec<R>> for GLev<R> {
type Output = GLWE<R>;
fn mul(self, v: Vec<R>) -> GLWE<R> {
debug_assert_eq!(self.0.len(), v.len());
// TODO debug_assert_eq of param
// l times GLWES // l times GLWES
let glwes: Vec<GLWE<R, K>> = self.0;
let glwes: Vec<GLWE<R>> = self.0;
// l iterations // l iterations
let r: GLWE<R, K> = zip_eq(v, glwes).map(|(v_i, glwe_i)| glwe_i * v_i).sum();
let r: GLWE<R> = zip_eq(v, glwes).map(|(v_i, glwe_i)| glwe_i * v_i).sum();
r r
} }
} }
@ -72,33 +85,37 @@ mod tests {
use rand::distributions::Uniform; use rand::distributions::Uniform;
use super::*; use super::*;
use arith::Rq;
use arith::{RingParam, Rq};
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 128;
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
type S = GLev<Rq<Q, N>, K>;
let param = Param {
err_sigma: crate::glwe::ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 128,
},
k: 16,
t: 2, // plaintext modulus
};
type S = GLev<Rq>;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
// let delta: u64 = Q / T; // floored
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = GLWE::<Rq<Q, N>, K>::new_key(&mut rng)?;
let (sk, pk) = GLWE::<Rq>::new_key(&mut rng, &param)?;
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m: Rq<Q, N> = m.remodule::<Q>();
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m: Rq = m.remodule(param.ring.q);
let c = S::encrypt(&mut rng, beta, l, &pk, &m)?;
let m_recovered = c.decrypt::<T>(&sk, beta);
let c = S::encrypt(&mut rng, &param, beta, l, &pk, &m)?;
let m_recovered = c.decrypt(&param, &sk, beta);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
} }
Ok(()) Ok(())

+ 318
- 199
gfhe/src/glwe.rs

@ -8,79 +8,128 @@ use rand_distr::{Normal, Uniform};
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub}; use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Zq, TR};
use arith::{Ring, RingParam, Rq, Zq, TR};
use crate::glev::GLev; use crate::glev::GLev;
// const ERR_SIGMA: f64 = 3.2;
const ERR_SIGMA: f64 = 0.0; // TODO WIP
// error deviation for the Gaussian(Normal) distribution
// sigma=3.2 from: https://eprint.iacr.org/2022/162.pdf page 5
pub(crate) const ERR_SIGMA: f64 = 3.2;
#[derive(Clone, Copy, Debug)]
pub struct Param {
pub err_sigma: f64,
pub ring: RingParam,
pub k: usize,
pub t: u64,
}
impl Param {
/// returns the plaintext param
pub fn pt(&self) -> RingParam {
// TODO think if maybe return a new truct "PtParam" to differenciate
// between the ciphertexxt (RingParam) and the plaintext param. Maybe it
// can be just a wrapper on top of RingParam.
RingParam {
q: self.t,
n: self.ring.n,
}
}
/// returns the LWE param for the given GLWE (self), that is, it uses k=K*N
/// as the length for the secret key. This follows [2018-421] where
/// TLWE sk: s \in B^n , where n=K*N
/// TRLWE sk: s \in B_N[X]^K
pub fn lwe(&self) -> Self {
Self {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: self.ring.q,
n: 1,
},
k: self.k * self.ring.n,
t: self.t,
}
}
}
/// GLWE implemented over the `Ring` trait, so that it can be also instantiated /// GLWE implemented over the `Ring` trait, so that it can be also instantiated
/// over the Torus polynomials 𝕋_<N,q>[X] = 𝕋_q[X]/ (X^N+1). /// over the Torus polynomials 𝕋_<N,q>[X] = 𝕋_q[X]/ (X^N+1).
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GLWE<R: Ring, const K: usize>(pub TR<R, K>, pub R);
pub struct GLWE<R: Ring>(pub TR<R>, pub R);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SecretKey<R: Ring, const K: usize>(pub TR<R, K>);
pub struct SecretKey<R: Ring>(pub TR<R>);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PublicKey<R: Ring, const K: usize>(pub R, pub TR<R, K>);
pub struct PublicKey<R: Ring>(pub R, pub TR<R>);
// K GLevs, each KSK_i=l GLWEs // K GLevs, each KSK_i=l GLWEs
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KSK<R: Ring, const K: usize>(Vec<GLev<R, K>>);
pub struct KSK<R: Ring>(Vec<GLev<R>>);
impl<R: Ring, const K: usize> GLWE<R, K> {
pub fn zero() -> Self {
Self(TR::zero(), R::zero())
impl<R: Ring> GLWE<R> {
pub fn zero(k: usize, param: &RingParam) -> Self {
Self(TR::zero(k, &param), R::zero(&param))
} }
pub fn from_plaintext(p: R) -> Self {
Self(TR::zero(), p)
pub fn from_plaintext(k: usize, param: &RingParam, p: R) -> Self {
Self(TR::zero(k, &param), p)
} }
pub fn new_key(mut rng: impl Rng) -> Result<(SecretKey<R, K>, PublicKey<R, K>)> {
pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey<R>, PublicKey<R>)> {
let Xi_key = Uniform::new(0_f64, 2_f64); let Xi_key = Uniform::new(0_f64, 2_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let s: TR<R, K> = TR::rand(&mut rng, Xi_key);
let a: TR<R, K> = TR::rand(&mut rng, Uniform::new(0_f64, R::Q as f64));
let e = R::rand(&mut rng, Xi_err);
let pk: PublicKey<R, K> = PublicKey((&a * &s) + e, a);
let Xi_err = Normal::new(0_f64, param.err_sigma)?;
let s: TR<R> = TR::rand(&mut rng, Xi_key, param.k, &param.ring);
let a: TR<R> = TR::rand(
&mut rng,
Uniform::new(0_f64, param.ring.q as f64),
param.k,
&param.ring,
);
let e = R::rand(&mut rng, Xi_err, &param.ring);
let pk: PublicKey<R> = PublicKey((&a * &s) + e, a);
Ok((SecretKey(s), pk)) Ok((SecretKey(s), pk))
} }
pub fn pk_from_sk(mut rng: impl Rng, sk: SecretKey<R, K>) -> Result<PublicKey<R, K>> {
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let a: TR<R, K> = TR::rand(&mut rng, Uniform::new(0_f64, R::Q as f64));
let e = R::rand(&mut rng, Xi_err);
let pk: PublicKey<R, K> = PublicKey((&a * &sk.0) + e, a);
pub fn pk_from_sk(mut rng: impl Rng, param: &Param, sk: SecretKey<R>) -> Result<PublicKey<R>> {
let Xi_err = Normal::new(0_f64, param.err_sigma)?;
let a: TR<R> = TR::rand(
&mut rng,
Uniform::new(0_f64, param.ring.q as f64),
param.k,
&param.ring,
);
let e = R::rand(&mut rng, Xi_err, &param.ring);
let pk: PublicKey<R> = PublicKey((&a * &sk.0) + e, a);
Ok(pk) Ok(pk)
} }
pub fn new_ksk( pub fn new_ksk(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<R, K>,
new_sk: &SecretKey<R, K>,
) -> Result<KSK<R, K>> {
let r: Vec<GLev<R, K>> = (0..K)
sk: &SecretKey<R>,
new_sk: &SecretKey<R>,
) -> Result<KSK<R>> {
debug_assert_eq!(param.k, sk.0.k);
let k = sk.0.k;
let r: Vec<GLev<R>> = (0..k)
.into_iter() .into_iter()
.map(|i| .map(|i|
// treat sk_i as the msg being encrypted // treat sk_i as the msg being encrypted
GLev::<R, K>::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0 .0[i]))
GLev::<R>::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0 .r[i]))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(KSK(r)) Ok(KSK(r))
} }
pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK<R, K>) -> Self {
let (a, b): (TR<R, K>, R) = (self.0.clone(), self.1);
pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK<R>) -> Self {
let (a, b): (TR<R>, R) = (self.0.clone(), self.1.clone()); // TODO rm clones
let lhs: GLWE<R, K> = GLWE(TR::zero(), b);
let lhs: GLWE<R> = GLWE(TR::zero(param.k, &param.ring), b);
// K iterations, ksk.0 contains K times GLev // K iterations, ksk.0 contains K times GLev
let rhs: GLWE<R, K> = zip_eq(a.0, ksk.0.clone())
let rhs: GLWE<R> = zip_eq(a.r, ksk.0.clone())
.map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product
.sum(); .sum();
@ -90,121 +139,141 @@ impl GLWE {
// encrypts with the given SecretKey (instead of PublicKey) // encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
sk: &SecretKey<R, K>,
param: &Param,
sk: &SecretKey<R>,
m: &R, // already scaled m: &R, // already scaled
) -> Result<Self> { ) -> Result<Self> {
let Xi_key = Uniform::new(0_f64, 2_f64); let Xi_key = Uniform::new(0_f64, 2_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let Xi_err = Normal::new(0_f64, param.err_sigma)?;
let a: TR<R, K> = TR::rand(&mut rng, Xi_key);
let e = R::rand(&mut rng, Xi_err);
let a: TR<R> = TR::rand(&mut rng, Xi_key, param.k, &param.ring);
let e = R::rand(&mut rng, Xi_err, &param.ring);
let b: R = (&a * &sk.0) + *m + e;
let b: R = (&a * &sk.0) + m.clone() + e; // TODO rm clone
Ok(Self(a, b)) Ok(Self(a, b))
} }
pub fn encrypt( pub fn encrypt(
mut rng: impl Rng, mut rng: impl Rng,
pk: &PublicKey<R, K>,
param: &Param,
pk: &PublicKey<R>,
m: &R, // already scaled m: &R, // already scaled
) -> Result<Self> { ) -> Result<Self> {
let Xi_key = Uniform::new(0_f64, 2_f64); let Xi_key = Uniform::new(0_f64, 2_f64);
let Xi_err = Normal::new(0_f64, ERR_SIGMA)?;
let Xi_err = Normal::new(0_f64, param.err_sigma)?;
let u: R = R::rand(&mut rng, Xi_key);
let u: R = R::rand(&mut rng, Xi_key, &param.ring);
let e0 = R::rand(&mut rng, Xi_err);
let e1 = TR::<R, K>::rand(&mut rng, Xi_err);
let e0 = R::rand(&mut rng, Xi_err, &param.ring);
let e1 = TR::<R>::rand(&mut rng, Xi_err, param.k, &param.ring);
let b: R = pk.0.clone() * u.clone() + *m + e0;
let d: TR<R, K> = &pk.1 * &u + e1;
let b: R = pk.0.clone() * u.clone() + m.clone() + e0; // TODO rm clones
let d: TR<R> = &pk.1 * &u + e1;
Ok(Self(d, b)) Ok(Self(d, b))
} }
// returns m' not downscaled // returns m' not downscaled
pub fn decrypt(&self, sk: &SecretKey<R, K>) -> R {
let (d, b): (TR<R, K>, R) = (self.0.clone(), self.1);
pub fn decrypt(&self, sk: &SecretKey<R>) -> R {
let (d, b): (TR<R>, R) = (self.0.clone(), self.1.clone());
let p: R = b - &d * &sk.0; let p: R = b - &d * &sk.0;
p p
} }
} }
// Methods for when Ring=Rq<Q,N> // Methods for when Ring=Rq<Q,N>
impl<const Q: u64, const N: usize, const K: usize> GLWE<Rq<Q, N>, K> {
impl GLWE<Rq> {
// scale up // scale up
pub fn encode<const T: u64>(m: &Rq<T, N>) -> Rq<Q, N> {
let m = m.remodule::<Q>();
let delta = Q / T; // floored
pub fn encode(param: &Param, m: &Rq) -> Rq {
debug_assert_eq!(param.t, m.param.q);
let m = m.remodule(param.ring.q);
let delta = param.ring.q / param.t; // floored
m * delta m * delta
} }
// scale down // scale down
pub fn decode<const T: u64>(m: &Rq<Q, N>) -> Rq<T, N> {
let r = m.mul_div_round(T, Q);
let r: Rq<T, N> = r.remodule::<T>();
pub fn decode(param: &Param, m: &Rq) -> Rq {
let r = m.mul_div_round(param.t, param.ring.q);
let r: Rq = r.remodule(param.t);
r r
} }
pub fn mod_switch<const P: u64>(&self) -> GLWE<Rq<P, N>, K> {
let a: TR<Rq<P, N>, K> = TR(self
.0
.0
.iter()
.map(|r| r.mod_switch::<P>())
.collect::<Vec<_>>());
let b: Rq<P, N> = self.1.mod_switch::<P>();
pub fn mod_switch(&self, p: u64) -> GLWE<Rq> {
let a: TR<Rq> = TR {
k: self.0.k,
r: self.0.r.iter().map(|r| r.mod_switch(p)).collect::<Vec<_>>(),
};
let b: Rq = self.1.mod_switch(p);
GLWE(a, b) GLWE(a, b)
} }
} }
impl<R: Ring, const K: usize> Add<GLWE<R, K>> for GLWE<R, K> {
impl<R: Ring> Add<GLWE<R>> for GLWE<R> {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
let a: TR<R, K> = self.0 + other.0;
debug_assert_eq!(self.0.k, other.0.k);
debug_assert_eq!(self.1.param(), other.1.param());
let a: TR<R> = self.0 + other.0;
let b: R = self.1 + other.1; let b: R = self.1 + other.1;
Self(a, b) Self(a, b)
} }
} }
impl<R: Ring, const K: usize> Add<R> for GLWE<R, K> {
impl<R: Ring> Add<R> for GLWE<R> {
type Output = Self; type Output = Self;
fn add(self, plaintext: R) -> Self { fn add(self, plaintext: R) -> Self {
let a: TR<R, K> = self.0;
debug_assert_eq!(self.1.param(), plaintext.param());
let a: TR<R> = self.0;
let b: R = self.1 + plaintext; let b: R = self.1 + plaintext;
Self(a, b) Self(a, b)
} }
} }
impl<R: Ring, const K: usize> AddAssign for GLWE<R, K> {
impl<R: Ring> AddAssign for GLWE<R> {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..K {
self.0 .0[i] = self.0 .0[i].clone() + rhs.0 .0[i].clone();
debug_assert_eq!(self.0.k, rhs.0.k);
debug_assert_eq!(self.1.param(), rhs.1.param());
let k = self.0.k;
for i in 0..k {
self.0.r[i] = self.0.r[i].clone() + rhs.0.r[i].clone();
} }
self.1 = self.1.clone() + rhs.1.clone(); self.1 = self.1.clone() + rhs.1.clone();
} }
} }
impl<R: Ring, const K: usize> Sum<GLWE<R, K>> for GLWE<R, K> {
fn sum<I>(iter: I) -> Self
impl<R: Ring> Sum<GLWE<R>> for GLWE<R> {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = GLWE::<R, K>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
} }
} }
impl<R: Ring, const K: usize> Sub<GLWE<R, K>> for GLWE<R, K> {
impl<R: Ring> Sub<GLWE<R>> for GLWE<R> {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
let a: TR<R, K> = self.0 - other.0;
debug_assert_eq!(self.0.k, other.0.k);
debug_assert_eq!(self.1.param(), other.1.param());
let a: TR<R> = self.0 - other.0;
let b: R = self.1 - other.1; let b: R = self.1 - other.1;
Self(a, b) Self(a, b)
} }
} }
impl<R: Ring, const K: usize> Mul<R> for GLWE<R, K> {
impl<R: Ring> Mul<R> for GLWE<R> {
type Output = Self; type Output = Self;
fn mul(self, plaintext: R) -> Self { fn mul(self, plaintext: R) -> Self {
let a: TR<R, K> = TR(self.0 .0.iter().map(|r_i| *r_i * plaintext).collect());
debug_assert_eq!(self.1.param(), plaintext.param());
let a: TR<R> = TR {
k: self.0.k,
r: self
.0
.r
.iter()
.map(|r_i| r_i.clone() * plaintext.clone())
.collect(),
};
let b: R = self.1 * plaintext; let b: R = self.1 * plaintext;
Self(a, b) Self(a, b)
} }
@ -255,77 +324,90 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 128;
const T: u64 = 32; // plaintext modulus
const K: usize = 16;
type S = GLWE<Rq<Q, N>, K>;
fn test_encrypt_decrypt_ring_nq() -> Result<()> {
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 128,
},
k: 16,
t: 32, // plaintext modulus
};
type S = GLWE<Rq>;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?; // msg
// let m: Rq<Q, N> = m.remodule::<Q>();
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?; // msg
let p = S::encode(&param, &m); // plaintext
let p = S::encode::<T>(&m); // plaintext
let c = S::encrypt(&mut rng, &pk, &p)?; // ciphertext
let c = S::encrypt(&mut rng, &param, &pk, &p)?; // ciphertext
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = S::decode(&param, &p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
// same but using encrypt_s (with sk instead of pk)) // same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = S::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = S::decode(&param, &p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
} }
Ok(()) Ok(())
} }
use arith::{Tn, T64}; use arith::{Tn, T64};
use std::array;
pub fn t_encode<const P: u64>(m: &Rq<P, 4>) -> Tn<4> {
let delta = u64::MAX / P; // floored
pub fn t_encode(param: &RingParam, m: &Rq) -> Tn {
let p = m.param.q; // plaintext space
let delta = u64::MAX / p; // floored
let coeffs = m.coeffs(); let coeffs = m.coeffs();
Tn(array::from_fn(|i| T64(coeffs[i].0 * delta)))
Tn {
param: *param,
coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(),
}
} }
pub fn t_decode<const P: u64>(p: &Tn<4>) -> Rq<P, 4> {
let p = p.mul_div_round(P, u64::MAX);
Rq::<P, 4>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn t_decode(param: &Param, pt: &Tn) -> Rq {
let pt = pt.mul_div_round(param.t, u64::MAX);
Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect())
} }
#[test] #[test]
fn test_encrypt_decrypt_torus() -> Result<()> { fn test_encrypt_decrypt_torus() -> Result<()> {
const N: usize = 128;
const T: u64 = 32; // plaintext modulus
const K: usize = 16;
type S = GLWE<Tn<4>, K>;
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: u64::MAX,
n: 128,
},
k: 16,
t: 32, // plaintext modulus
};
type S = GLWE<Tn>;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_f64, T as f64);
let msg_dist = Uniform::new(0_f64, param.t as f64);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let m = Rq::<T, 4>::rand(&mut rng, msg_dist); // msg
let m = Rq::rand(&mut rng, msg_dist, &param.pt()); // msg
let p = t_encode::<T>(&m); // plaintext
let c = S::encrypt(&mut rng, &pk, &p)?; // ciphertext
let p = t_encode(&param.ring, &m); // plaintext
let c = S::encrypt(&mut rng, &param, &pk, &p)?; // ciphertext
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = t_decode::<T>(&p_recovered);
let m_recovered = t_decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk)) // same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = S::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = t_decode::<T>(&p_recovered);
let m_recovered = t_decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@ -335,32 +417,37 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 128;
const T: u64 = 20;
const K: usize = 16;
type S = GLWE<Rq<Q, N>, K>;
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 128,
},
k: 16,
t: 20, // plaintext modulus
};
type S = GLWE<Rq>;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Rq<Q, N> = S::encode::<T>(&m1); // plaintext
let p2: Rq<Q, N> = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Rq = S::encode(&param, &m1); // plaintext
let p2: Rq = S::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
let c1 = S::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = S::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@ -368,31 +455,36 @@ mod tests {
#[test] #[test]
fn test_add_plaintext() -> Result<()> { fn test_add_plaintext() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 128;
const T: u64 = 32;
const K: usize = 16;
type S = GLWE<Rq<Q, N>, K>;
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 128,
},
k: 16,
t: 32, // plaintext modulus
};
type S = GLWE<Rq>;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Rq<Q, N> = S::encode::<T>(&m1); // plaintext
let p2: Rq<Q, N> = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Rq = S::encode(&param, &m1); // plaintext
let p2: Rq = S::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = S::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2; let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = S::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@ -400,30 +492,35 @@ mod tests {
#[test] #[test]
fn test_mul_plaintext() -> Result<()> { fn test_mul_plaintext() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 16;
const T: u64 = 4;
const K: usize = 16;
type S = GLWE<Rq<Q, N>, K>;
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 16,
},
k: 16,
t: 4, // plaintext modulus
};
type S = GLWE<Rq>;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Rq<Q, N> = S::encode::<T>(&m1); // plaintext
let p2 = m2.remodule::<Q>(); // notice we don't encode (scale by delta)
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Rq = S::encode(&param, &m1); // plaintext
let p2 = m2.remodule(param.ring.q); // notice we don't encode (scale by delta)
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = S::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2; let c3 = c1 * p2;
let p3_recovered: Rq<Q, N> = c3.decrypt(&sk);
let m3_recovered: Rq<T, N> = S::decode::<T>(&p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered);
let p3_recovered: Rq = c3.decrypt(&sk);
let m3_recovered: Rq = S::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
} }
Ok(()) Ok(())
@ -431,33 +528,50 @@ mod tests {
#[test] #[test]
fn test_mod_switch() -> Result<()> { fn test_mod_switch() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const P: u64 = 2u64.pow(8) + 1;
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 8,
},
k: 16,
t: 4, // plaintext modulus, must be a prime or power of a prime
};
let new_q: u64 = 2u64.pow(8) + 1;
// note: wip, Q and P chosen so that P/Q is an integer // note: wip, Q and P chosen so that P/Q is an integer
const N: usize = 8;
const T: u64 = 4; // plaintext modulus, must be a prime or power of a prime
const K: usize = 16;
type S = GLWE<Rq<Q, N>, K>;
type S = GLWE<Rq>;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p = S::encode::<T>(&m);
let c = S::encrypt(&mut rng, &pk, &p)?;
let p = S::encode(&param, &m);
let c = S::encrypt(&mut rng, &param, &pk, &p)?;
let c2: GLWE<Rq<P, N>, K> = c.mod_switch::<P>();
let sk2: SecretKey<Rq<P, N>, K> =
SecretKey(TR(sk.0 .0.iter().map(|s_i| s_i.remodule::<P>()).collect()));
let c2: GLWE<Rq> = c.mod_switch(new_q);
assert_eq!(c2.1.param.q, new_q);
let sk2: SecretKey<Rq> = SecretKey(TR {
k: param.k,
r: sk.0.r.iter().map(|s_i| s_i.remodule(new_q)).collect(),
});
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = GLWE::<Rq<P, N>, K>::decode::<T>(&p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
let new_param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: new_q,
n: param.ring.n,
},
k: param.k,
t: param.t,
};
let m_recovered = GLWE::<Rq>::decode(&new_param, &p_recovered);
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@ -465,40 +579,45 @@ mod tests {
#[test] #[test]
fn test_key_switch() -> Result<()> { fn test_key_switch() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 128;
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
type S = GLWE<Rq<Q, N>, K>;
let param = Param {
err_sigma: ERR_SIGMA,
ring: RingParam {
q: 2u64.pow(16) + 1,
n: 128,
},
k: 16,
t: 2,
};
type S = GLWE<Rq>;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (sk, pk) = S::new_key(&mut rng)?;
let (sk2, _) = S::new_key(&mut rng)?;
let (sk, pk) = S::new_key(&mut rng, &param)?;
let (sk2, _) = S::new_key(&mut rng, &param)?;
// ksk to switch from sk to sk2 // ksk to switch from sk to sk2
let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?;
let ksk = S::new_ksk(&mut rng, &param, beta, l, &sk, &sk2)?;
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p = S::encode::<T>(&m); // plaintext
//
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p = S::encode(&param, &m); // plaintext
//
let c = S::encrypt_s(&mut rng, &param, &sk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c2 = c.key_switch(&param, beta, l, &ksk);
// decrypt with the 2nd secret key // decrypt with the 2nd secret key
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
let m_recovered = S::decode(&param, &p_recovered);
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
// do the same but now encrypting with pk // do the same but now encrypting with pk
let c = S::encrypt(&mut rng, &pk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c = S::encrypt(&mut rng, &param, &pk, &p)?;
let c2 = c.key_switch(&param, beta, l, &ksk);
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = S::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
Ok(()) Ok(())

+ 2
- 0
tfhe/src/lib.rs

@ -10,3 +10,5 @@ pub mod tglwe;
pub mod tgsw; pub mod tgsw;
pub mod tlev; pub mod tlev;
pub mod tlwe; pub mod tlwe;
pub(crate) const ERR_SIGMA: f64 = 3.2;

+ 75
- 58
tfhe/src/tggsw.rs

@ -4,53 +4,57 @@ use rand::Rng;
use std::array; use std::array;
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use crate::tglwe::{PublicKey, SecretKey, TGLWE}; use crate::tglwe::{PublicKey, SecretKey, TGLWE};
use gfhe::glwe::GLWE;
use gfhe::glwe::{Param, GLWE};
/// vector of length K+1 = ([K * TGLev], [1 * TGLev]) /// vector of length K+1 = ([K * TGLev], [1 * TGLev])
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGGSW<const N: usize, const K: usize>(pub(crate) Vec<TGLev<N, K>>, TGLev<N, K>);
pub struct TGGSW(pub(crate) Vec<TGLev>, TGLev);
impl<const N: usize, const K: usize> TGGSW<N, K> {
impl TGGSW {
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<N, K>,
m: &Tn<N>,
sk: &SecretKey,
m: &Tn,
) -> Result<Self> { ) -> Result<Self> {
let a: Vec<TGLev<N, K>> = (0..K)
.map(|i| TGLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m)))
debug_assert_eq!(sk.0 .0.k, param.k);
let a: Vec<TGLev> = (0..param.k)
.map(|i| TGLev::encrypt_s(&mut rng, param, beta, l, sk, &(&-sk.0 .0.r[i].clone() * m)))
// TODO rm clone
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let b: TGLev<N, K> = TGLev::encrypt_s(&mut rng, beta, l, sk, m)?;
let b: TGLev = TGLev::encrypt_s(&mut rng, &param, beta, l, sk, m)?;
Ok(Self(a, b)) Ok(Self(a, b))
} }
pub fn decrypt(&self, sk: &SecretKey<N, K>, beta: u32) -> Tn<N> {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn {
self.1.decrypt(sk, beta) self.1.decrypt(sk, beta)
} }
pub fn cmux(bit: Self, ct1: TGLWE<N, K>, ct2: TGLWE<N, K>) -> TGLWE<N, K> {
pub fn cmux(bit: Self, ct1: TGLWE, ct2: TGLWE) -> TGLWE {
ct1.clone() + (bit * (ct2 - ct1)) ct1.clone() + (bit * (ct2 - ct1))
} }
} }
/// External product TGGSW x TGLWE
impl<const N: usize, const K: usize> Mul<TGLWE<N, K>> for TGGSW<N, K> {
type Output = TGLWE<N, K>;
/// External product tggsw x tglwe
impl Mul<TGLWE> for TGGSW {
type Output = TGLWE;
fn mul(self, tglwe: TGLWE<N, K>) -> TGLWE<N, K> {
fn mul(self, tglwe: TGLWE) -> TGLWE {
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; // TODO wip let l: u32 = 64; // TODO wip
let tglwe_ab: Vec<Tn<N>> = [tglwe.0 .0 .0.clone(), vec![tglwe.0 .1]].concat();
let tglwe_ab: Vec<Tn> = [tglwe.0 .0.r.clone(), vec![tglwe.0 .1]].concat();
let tgsw_ab: Vec<TGLev<N, K>> = [self.0.clone(), vec![self.1]].concat();
let tgsw_ab: Vec<TGLev> = [self.0.clone(), vec![self.1]].concat();
assert_eq!(tgsw_ab.len(), tglwe_ab.len()); assert_eq!(tgsw_ab.len(), tglwe_ab.len());
let r: TGLWE<N, K> = zip_eq(tgsw_ab, tglwe_ab)
let r: TGLWE = zip_eq(tgsw_ab, tglwe_ab)
.map(|(tlev_i, tglwe_i)| tlev_i * tglwe_i.decompose(beta, l)) .map(|(tlev_i, tglwe_i)| tlev_i * tglwe_i.decompose(beta, l))
.sum(); .sum();
r r
@ -58,26 +62,36 @@ impl Mul> for TGGSW {
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGLev<const N: usize, const K: usize>(pub(crate) Vec<TGLWE<N, K>>);
pub struct TGLev(pub(crate) Vec<TGLWE>);
impl TGLev {
pub fn encode(param: &Param, m: &Rq) -> Tn {
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
impl<const N: usize, const K: usize> TGLev<N, K> {
pub fn encode<const T: u64>(m: &Rq<T, N>) -> Tn<N> {
let coeffs = m.coeffs();
Tn(array::from_fn(|i| T64(coeffs[i].0)))
Tn {
param: param.ring,
coeffs: m.coeffs().iter().map(|c_i| T64(c_i.v)).collect(),
}
} }
pub fn decode<const T: u64>(p: &Tn<N>) -> Rq<T, N> {
Rq::<T, N>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &Tn) -> Rq {
Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
} }
pub fn encrypt( pub fn encrypt(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
pk: &PublicKey<N, K>,
m: &Tn<N>,
pk: &PublicKey,
m: &Tn,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TGLWE<N, K>> = (1..l + 1)
let tlev: Vec<TGLWE> = (1..l + 1)
.map(|i| { .map(|i| {
TGLWE::<N, K>::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64)))
TGLWE::encrypt(
&mut rng,
&param,
pk,
&(m * &(u64::MAX / beta.pow(i as u32) as u64)),
)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -85,35 +99,36 @@ impl TGLev {
} }
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
_beta: u32, // TODO rm, and make beta=2 always _beta: u32, // TODO rm, and make beta=2 always
l: u32, l: u32,
sk: &SecretKey<N, K>,
m: &Tn<N>,
sk: &SecretKey,
m: &Tn,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TGLWE<N, K>> = (1..l as u64 + 1)
let tlev: Vec<TGLWE> = (1..l as u64 + 1)
.map(|i| { .map(|i| {
let aux = if i < 64 { let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i))
m * &(u64::MAX / (1u64 << i))
} else { } else {
// 1<<64 would overflow, and anyways we're dividing u64::MAX // 1<<64 would overflow, and anyways we're dividing u64::MAX
// by it, which would be equal to 1 // by it, which would be equal to 1
*m
m.clone() // TODO rm clone
}; };
TGLWE::<N, K>::encrypt_s(&mut rng, sk, &aux)
TGLWE::encrypt_s(&mut rng, &param, sk, &aux)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(Self(tlev)) Ok(Self(tlev))
} }
pub fn decrypt(&self, sk: &SecretKey<N, K>, beta: u32) -> Tn<N> {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> Tn {
let pt = self.0[0].decrypt(sk); let pt = self.0[0].decrypt(sk);
pt.mul_div_round(beta as u64, u64::MAX) pt.mul_div_round(beta as u64, u64::MAX)
} }
} }
impl<const N: usize, const K: usize> TGLev<N, K> {
pub fn iter(&self) -> std::slice::Iter<TGLWE<N, K>> {
impl TGLev {
pub fn iter(&self) -> std::slice::Iter<TGLWE> {
self.0.iter() self.0.iter()
} }
} }
@ -121,14 +136,14 @@ impl TGLev {
// dot product between a TGLev and Vec<Tn<N>>, usually Vec<Tn<N>> comes from a // dot product between a TGLev and Vec<Tn<N>>, usually Vec<Tn<N>> comes from a
// decomposition of Tn<N> // decomposition of Tn<N>
// TGLev * Vec<Tn<N>> --> TGLWE // TGLev * Vec<Tn<N>> --> TGLWE
impl<const N: usize, const K: usize> Mul<Vec<Tn<N>>> for TGLev<N, K> {
type Output = TGLWE<N, K>;
fn mul(self, v: Vec<Tn<N>>) -> Self::Output {
impl Mul<Vec<Tn>> for TGLev {
type Output = TGLWE;
fn mul(self, v: Vec<Tn>) -> Self::Output {
assert_eq!(self.0.len(), v.len()); assert_eq!(self.0.len(), v.len());
// l TGLWES // l TGLWES
let tlwes: Vec<TGLWE<N, K>> = self.0;
let r: TGLWE<N, K> = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
let tlwes: Vec<TGLWE> = self.0;
let r: TGLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
r r
} }
} }
@ -141,38 +156,40 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_external_product() -> Result<()> { fn test_external_product() -> Result<()> {
const T: u64 = 16; // plaintext modulus
const K: usize = 4;
const N: usize = 64;
const KN: usize = K * N;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 64 },
k: 4,
t: 16, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = TGLev::<N, K>::encode::<T>(&m1);
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLev::encode(&param, &m1);
let m2: Rq<T, N> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: Tn<N> = TGLWE::<N, K>::encode::<T>(&m2); // scaled by delta
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: Tn = TGLWE::encode(&param, &m2); // scaled by delta
let tgsw = TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?;
let tlwe = TGLWE::<N, K>::encrypt_s(&mut rng, &sk, &p2)?;
let tgsw = TGGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p1)?;
let tlwe = TGLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TGLWE<N, K> = tgsw * tlwe;
let res: TGLWE = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta); // let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk); let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1 // downscaled by delta^-1
let res_recovered = TGLWE::<N, K>::decode::<T>(&p_recovered);
let res_recovered = TGLWE::decode(&param, &p_recovered);
// assert_eq!(m1 * m2, m_recovered); // assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered);
} }
Ok(()) Ok(())

+ 183
- 141
tfhe/src/tglwe.rs

@ -1,161 +1,194 @@
use anyhow::Result; use anyhow::Result;
use itertools::zip_eq;
use rand::distributions::Standard;
use rand::Rng; use rand::Rng;
use rand_distr::{Normal, Uniform};
use std::array;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub}; use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, T64, TR};
use gfhe::{glwe, GLWE};
use arith::{Ring, RingParam, Rq, Tn, T64, TR};
use gfhe::{glwe, glwe::Param, GLWE};
use crate::tlev::TLev;
use crate::{tlwe, tlwe::TLWE}; use crate::{tlwe, tlwe::TLWE};
// pub type SecretKey<const N: usize, const K: usize> = glwe::SecretKey<Tn<N>, K>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SecretKey<const N: usize, const K: usize>(pub glwe::SecretKey<Tn<N>, K>);
// pub struct SecretKey<const K: usize>(pub tlwe::SecretKey<K>);
pub struct SecretKey(pub glwe::SecretKey<Tn>);
impl<const N: usize, const K: usize> SecretKey<N, K> {
pub fn to_tlwe<const KN: usize>(&self) -> tlwe::SecretKey<KN> {
let s: TR<Tn<N>, K> = self.0 .0.clone();
impl SecretKey {
pub fn to_tlwe(&self, param: &Param) -> tlwe::SecretKey {
let s: TR<Tn> = self.0 .0.clone();
debug_assert_eq!(s.r.len(), param.k); // sanity check
let r: Vec<Vec<T64>> = s.0.iter().map(|s_i| s_i.coeffs()).collect();
let kn = param.k * param.ring.n;
let r: Vec<Vec<T64>> = s.r.iter().map(|s_i| s_i.coeffs()).collect();
let r: Vec<T64> = r.into_iter().flatten().collect(); let r: Vec<T64> = r.into_iter().flatten().collect();
tlwe::SecretKey::<KN>(glwe::SecretKey::<T64, KN>(TR::<T64, KN>::new(r)))
debug_assert_eq!(r.len(), kn); // sanity check
tlwe::SecretKey(glwe::SecretKey::<T64>(TR::<T64>::new(kn, r)))
} }
} }
pub type PublicKey<const N: usize, const K: usize> = glwe::PublicKey<Tn<N>, K>;
pub type PublicKey = glwe::PublicKey<Tn>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGLWE<const N: usize, const K: usize>(pub GLWE<Tn<N>, K>);
pub struct TGLWE(pub GLWE<Tn>);
impl<const N: usize, const K: usize> TGLWE<N, K> {
pub fn zero() -> Self {
Self(GLWE::<Tn<N>, K>::zero())
impl TGLWE {
pub fn zero(k: usize, param: &RingParam) -> Self {
Self(GLWE::<Tn>::zero(k, param))
} }
pub fn from_plaintext(p: Tn<N>) -> Self {
Self(GLWE::<Tn<N>, K>::from_plaintext(p))
pub fn from_plaintext(k: usize, param: &RingParam, p: Tn) -> Self {
Self(GLWE::<Tn>::from_plaintext(k, param, p))
} }
pub fn new_key<const KN: usize>(
mut rng: impl Rng,
) -> Result<(SecretKey<N, K>, PublicKey<N, K>)> {
// assert_eq!(KN, K * N); // this is wip, while not being able to compute K*N
let (sk_tlwe, _) = TLWE::<KN>::new_key(&mut rng)?;
// let sk = crate::tlwe::sk_to_tglwe::<N, K, KN>(sk_tlwe);
let sk = sk_tlwe.to_tglwe::<N, K>();
let pk: PublicKey<N, K> = GLWE::pk_from_sk(rng, sk.0.clone())?;
pub fn new_key(mut rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
let (sk_tlwe, _) = TLWE::new_key(&mut rng, &param.lwe())?; //param.lwe() so that it uses K*N
debug_assert_eq!(sk_tlwe.0 .0.r.len(), param.lwe().k); // =KN (sanity check)
let sk = sk_tlwe.to_tglwe(param);
let pk: PublicKey = GLWE::pk_from_sk(rng, param, sk.0.clone())?;
Ok((sk, pk)) Ok((sk, pk))
} }
pub fn encode<const P: u64>(m: &Rq<P, N>) -> Tn<N> {
let delta = u64::MAX / P; // floored
pub fn encode(param: &Param, m: &Rq) -> Tn {
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
let p = param.t;
let delta = u64::MAX / p; // floored
let coeffs = m.coeffs(); let coeffs = m.coeffs();
Tn(array::from_fn(|i| T64(coeffs[i].0 * delta)))
Tn {
param: param.ring,
coeffs: coeffs.iter().map(|c_i| T64(c_i.v * delta)).collect(),
}
}
pub fn decode(param: &Param, pt: &Tn) -> Rq {
let p = param.t;
let pt = pt.mul_div_round(p, u64::MAX);
Rq::from_vec_u64(&param.pt(), pt.coeffs().iter().map(|c| c.0).collect())
} }
pub fn decode<const P: u64>(p: &Tn<N>) -> Rq<P, N> {
let p = p.mul_div_round(P, u64::MAX);
Rq::<P, N>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
/// encodes the given message as a TGLWE constant/public value, for using it
/// in ct-pt-multiplication.
pub fn new_const(param: &Param, m: &Rq) -> Tn {
debug_assert_eq!(param.t, m.param.q);
// don't scale up m, set the Tn element directly from m's coefficients
Tn {
param: param.ring,
coeffs: m.coeffs().iter().map(|c_i| T64(c_i.v)).collect(),
}
} }
// encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<N, K>, p: &Tn<N>) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?;
/// encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &Tn) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn encrypt(rng: impl Rng, pk: &PublicKey<N, K>, p: &Tn<N>) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?;
pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &Tn) -> Result<Self> {
let glwe = GLWE::encrypt(rng, param, &pk, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn decrypt(&self, sk: &SecretKey<N, K>) -> Tn<N> {
pub fn decrypt(&self, sk: &SecretKey) -> Tn {
self.0.decrypt(&sk.0) self.0.decrypt(&sk.0)
} }
/// Sample extraction / Coefficient extraction /// Sample extraction / Coefficient extraction
pub fn sample_extraction<const KN: usize>(&self, h: usize) -> TLWE<KN> {
assert!(h < N);
pub fn sample_extraction(&self, param: &Param, h: usize) -> TLWE {
let n = param.ring.n;
assert!(h < n);
let a: TR<Tn<N>, K> = self.0 .0.clone();
let a: TR<Tn> = self.0 .0.clone();
// set a_{n*i+j} = a_{i, h-j} if j \in {0, h} // set a_{n*i+j} = a_{i, h-j} if j \in {0, h}
// -a_{i, n+h-j} if j \in {h+1, n-1} // -a_{i, n+h-j} if j \in {h+1, n-1}
let new_a: Vec<T64> = a let new_a: Vec<T64> = a
.iter() .iter()
.flat_map(|a_i| { .flat_map(|a_i| {
let a_i = a_i.coeffs(); let a_i = a_i.coeffs();
(0..N)
.map(|j| if j <= h { a_i[h - j] } else { -a_i[N + h - j] })
(0..n)
.map(|j| if j <= h { a_i[h - j] } else { -a_i[n + h - j] })
.collect::<Vec<T64>>() .collect::<Vec<T64>>()
}) })
.collect::<Vec<T64>>(); .collect::<Vec<T64>>();
TLWE(GLWE(TR(new_a), self.0 .1.coeffs()[h]))
debug_assert_eq!(new_a.len(), param.k * param.ring.n); // sanity check
TLWE(GLWE(
TR {
// TODO use constructor `new`, which will check len with k
k: param.k * param.ring.n,
r: new_a,
},
self.0 .1.coeffs()[h],
))
} }
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
dbg!(&h);
let (a, b): (TR<Tn<N>, K>, Tn<N>) = (self.0 .0.clone(), self.0 .1);
let (a, b): (TR<Tn>, Tn) = (self.0 .0.clone(), self.0 .1.clone());
Self(GLWE(a.left_rotate(h), b.left_rotate(h))) Self(GLWE(a.left_rotate(h), b.left_rotate(h)))
} }
} }
impl<const N: usize, const K: usize> Add<TGLWE<N, K>> for TGLWE<N, K> {
impl Add<TGLWE> for TGLWE {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0) Self(self.0 + other.0)
} }
} }
impl<const N: usize, const K: usize> AddAssign for TGLWE<N, K> {
fn add_assign(&mut self, rhs: Self) {
self.0 += rhs.0
impl AddAssign for TGLWE {
fn add_assign(&mut self, other: Self) {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
self.0 += other.0
} }
} }
impl<const N: usize, const K: usize> Sum<TGLWE<N, K>> for TGLWE<N, K> {
fn sum<I>(iter: I) -> Self
impl Sum<TGLWE> for TGLWE {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = TGLWE::<N, K>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
} }
} }
impl<const N: usize, const K: usize> Sub<TGLWE<N, K>> for TGLWE<N, K> {
impl Sub<TGLWE> for TGLWE {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0) Self(self.0 - other.0)
} }
} }
// plaintext addition // plaintext addition
impl<const N: usize, const K: usize> Add<Tn<N>> for TGLWE<N, K> {
impl Add<Tn> for TGLWE {
type Output = Self; type Output = Self;
fn add(self, plaintext: Tn<N>) -> Self {
let a: TR<Tn<N>, K> = self.0 .0;
let b: Tn<N> = self.0 .1 + plaintext;
fn add(self, plaintext: Tn) -> Self {
debug_assert_eq!(self.0 .1.param(), plaintext.param());
let a: TR<Tn> = self.0 .0;
let b: Tn = self.0 .1 + plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext substraction // plaintext substraction
impl<const N: usize, const K: usize> Sub<Tn<N>> for TGLWE<N, K> {
impl Sub<Tn> for TGLWE {
type Output = Self; type Output = Self;
fn sub(self, plaintext: Tn<N>) -> Self {
let a: TR<Tn<N>, K> = self.0 .0;
let b: Tn<N> = self.0 .1 - plaintext;
fn sub(self, plaintext: Tn) -> Self {
debug_assert_eq!(self.0 .1.param(), plaintext.param());
let a: TR<Tn> = self.0 .0;
let b: Tn = self.0 .1 - plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext multiplication // plaintext multiplication
impl<const N: usize, const K: usize> Mul<Tn<N>> for TGLWE<N, K> {
impl Mul<Tn> for TGLWE {
type Output = Self; type Output = Self;
fn mul(self, plaintext: Tn<N>) -> Self {
let a: TR<Tn<N>, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect());
let b: Tn<N> = self.0 .1 * plaintext;
fn mul(self, plaintext: Tn) -> Self {
debug_assert_eq!(self.0 .1.param(), plaintext.param());
let a: TR<Tn> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| r_i * &plaintext).collect(),
};
let b: Tn = self.0 .1 * plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
@ -169,30 +202,32 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = TGLWE::<N, K>::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p: Tn<N> = S::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: Tn = TGLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?;
let c = TGLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TGLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk)) // same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = TGLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TGLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@ -202,31 +237,33 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
const T: u64 = 128;
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = S::encode::<T>(&m1); // plaintext
let p2: Tn<N> = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLWE::encode(&param, &m1); // plaintext
let p2: Tn = TGLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = TGLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@ -234,28 +271,30 @@ mod tests {
#[test] #[test]
fn test_add_plaintext() -> Result<()> { fn test_add_plaintext() -> Result<()> {
const T: u64 = 128;
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = S::encode::<T>(&m1); // plaintext
let p2: Tn<N> = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLWE::encode(&param, &m1); // plaintext
let p2: Tn = TGLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2; let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered); assert_eq!(m1 + m2, m3_recovered);
} }
@ -265,30 +304,31 @@ mod tests {
#[test] #[test]
fn test_mul_plaintext() -> Result<()> { fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128;
const N: usize = 64;
const K: usize = 16;
type S = TGLWE<N, K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key::<{ K * N }>(&mut rng)?;
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p1: Tn<N> = S::encode::<T>(&m1);
// don't scale up p2, set it directly from m2
let p2: Tn<N> = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0)));
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: Tn = TGLWE::encode(&param, &m1);
let p2: Tn = TGLWE::new_const(&param, &m2); // as constant/public value
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TGLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2; let c3 = c1 * p2;
let p3_recovered: Tn<N> = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered);
let p3_recovered: Tn = c3.decrypt(&sk);
let m3_recovered = TGLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
} }
Ok(()) Ok(())
@ -296,28 +336,30 @@ mod tests {
#[test] #[test]
fn test_sample_extraction() -> Result<()> { fn test_sample_extraction() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const N: usize = 64;
const K: usize = 16;
const KN: usize = K * N;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 64 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..20 { for _ in 0..20 {
let (sk, pk) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let sk_tlwe = sk.to_tlwe::<KN>();
let (sk, pk) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe = sk.to_tlwe(&param);
let m = Rq::<T, N>::rand_u64(&mut rng, msg_dist)?;
let p: Tn<N> = TGLWE::<N, K>::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: Tn = TGLWE::encode(&param, &m);
let c = TGLWE::<N, K>::encrypt(&mut rng, &pk, &p)?;
let c = TGLWE::encrypt(&mut rng, &param, &pk, &p)?;
for h in 0..N {
let c_h: TLWE<KN> = c.sample_extraction(h);
for h in 0..param.ring.n {
let c_h: TLWE = c.sample_extraction(&param, h);
let p_recovered = c_h.decrypt(&sk_tlwe); let p_recovered = c_h.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.coeffs()[h], m_recovered.coeffs()[0]); assert_eq!(m.coeffs()[h], m_recovered.coeffs()[0]);
} }
} }

+ 67
- 61
tfhe/src/tgsw.rs

@ -1,65 +1,62 @@
use anyhow::Result; use anyhow::Result;
use itertools::zip_eq; use itertools::zip_eq;
use rand::Rng; use rand::Rng;
use std::array;
use std::ops::{Add, Mul};
use std::ops::Mul;
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, T64};
use crate::tlev::TLev; use crate::tlev::TLev;
use crate::{
tglwe::TGLWE,
tlwe::{PublicKey, SecretKey, TLWE},
};
use gfhe::glwe::GLWE;
use crate::tlwe::{SecretKey, TLWE};
use gfhe::glwe::Param;
/// vector of length K+1 = [K], [1] /// vector of length K+1 = [K], [1]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TGSW<const K: usize>(pub(crate) Vec<TLev<K>>, TLev<K>);
pub struct TGSW(pub(crate) Vec<TLev>, TLev);
impl<const K: usize> TGSW<K> {
impl TGSW {
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<K>,
sk: &SecretKey,
m: &T64, m: &T64,
) -> Result<Self> { ) -> Result<Self> {
let a: Vec<TLev<K>> = (0..K)
.map(|i| TLev::encrypt_s(&mut rng, beta, l, sk, &(-sk.0 .0 .0[i] * *m)))
let a: Vec<TLev> = (0..param.k)
.map(|i| TLev::encrypt_s(&mut rng, &param, beta, l, sk, &(-sk.0 .0.r[i] * *m)))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let b: TLev<K> = TLev::encrypt_s(&mut rng, beta, l, sk, m)?;
let b: TLev = TLev::encrypt_s(&mut rng, &param, beta, l, sk, m)?;
Ok(Self(a, b)) Ok(Self(a, b))
} }
pub fn decrypt(&self, sk: &SecretKey<K>, beta: u32) -> T64 {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 {
self.1.decrypt(sk, beta) self.1.decrypt(sk, beta)
} }
pub fn from_tlwe(_tlwe: TLWE<K>) -> Self {
pub fn from_tlwe(_tlwe: TLWE) -> Self {
todo!() todo!()
} }
pub fn cmux(bit: Self, ct1: TLWE<K>, ct2: TLWE<K>) -> TLWE<K> {
pub fn cmux(bit: Self, ct1: TLWE, ct2: TLWE) -> TLWE {
ct1.clone() + (bit * (ct2 - ct1)) ct1.clone() + (bit * (ct2 - ct1))
} }
} }
/// External product TGSW x TLWE /// External product TGSW x TLWE
impl<const K: usize> Mul<TLWE<K>> for TGSW<K> {
type Output = TLWE<K>;
impl Mul<TLWE> for TGSW {
type Output = TLWE;
fn mul(self, tlwe: TLWE<K>) -> TLWE<K> {
fn mul(self, tlwe: TLWE) -> TLWE {
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; // TODO wip let l: u32 = 64; // TODO wip
// since N=1, each tlwe element is a vector of length=1, decomposed into // since N=1, each tlwe element is a vector of length=1, decomposed into
// l elements, and we have K of them // l elements, and we have K of them
let tlwe_ab: Vec<T64> = [tlwe.0 .0 .0.clone(), vec![tlwe.0 .1]].concat();
let tlwe_ab: Vec<T64> = [tlwe.0 .0.r.clone(), vec![tlwe.0 .1]].concat();
let tgsw_ab: Vec<TLev<K>> = [self.0.clone(), vec![self.1]].concat();
let tgsw_ab: Vec<TLev> = [self.0.clone(), vec![self.1]].concat();
assert_eq!(tgsw_ab.len(), tlwe_ab.len()); assert_eq!(tgsw_ab.len(), tlwe_ab.len());
let r: TLWE<K> = zip_eq(tgsw_ab, tlwe_ab)
let r: TLWE = zip_eq(tgsw_ab, tlwe_ab)
.map(|(tlev_i, tlwe_i)| tlev_i * tlwe_i.decompose(beta, l)) .map(|(tlev_i, tlwe_i)| tlev_i * tlwe_i.decompose(beta, l))
.sum(); .sum();
r r
@ -72,28 +69,31 @@ mod tests {
use rand::distributions::Uniform; use rand::distributions::Uniform;
use super::*; use super::*;
use arith::{RingParam, Rq};
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
type S = TGSW<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?;
let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p: T64 = TLev::<K>::encode::<T>(&m); // plaintext
let m: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLev::encode(&param, &m); // plaintext
let c = S::encrypt_s(&mut rng, beta, l, &sk, &p)?;
let c = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p)?;
let p_recovered = c.decrypt(&sk, beta); let p_recovered = c.decrypt(&sk, beta);
let m_recovered = TLev::<K>::decode::<T>(&p_recovered);
let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@ -103,36 +103,38 @@ mod tests {
#[test] #[test]
fn test_external_product() -> Result<()> { fn test_external_product() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 32;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 32,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?;
let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = TLev::<K>::encode::<T>(&m1);
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLev::encode(&param, &m1);
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: T64 = TLWE::<K>::encode::<T>(&m2); // scaled by delta
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: T64 = TLWE::encode(&param, &m2); // scaled by delta
let tgsw = TGSW::<K>::encrypt_s(&mut rng, beta, l, &sk, &p1)?;
let tlwe = TLWE::<K>::encrypt_s(&mut rng, &sk, &p2)?;
let tgsw = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &p1)?;
let tlwe = TLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TLWE<K> = tgsw * tlwe;
let res: TLWE = tgsw * tlwe;
// let p_recovered = res.decrypt(&sk, beta);
let p_recovered = res.decrypt(&sk); let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1 // downscaled by delta^-1
let res_recovered = TLWE::<K>::decode::<T>(&p_recovered);
let res_recovered = TLWE::decode(&param, &p_recovered);
// assert_eq!(m1 * m2, m_recovered); // assert_eq!(m1 * m2, m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), res_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), res_recovered);
} }
Ok(()) Ok(())
@ -140,35 +142,39 @@ mod tests {
#[test] #[test]
fn test_cmux() -> Result<()> { fn test_cmux() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 32;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 32,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..50 { for _ in 0..50 {
let (sk, _) = TLWE::<K>::new_key(&mut rng)?;
let (sk, _) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = TLWE::<K>::encode::<T>(&m1); // scaled by delta
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // scaled by delta
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p2: T64 = TLWE::<K>::encode::<T>(&m2); // scaled by delta
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p2: T64 = TLWE::encode(&param, &m2); // scaled by delta
for bit_raw in 0..2 { for bit_raw in 0..2 {
let bit = TGSW::<K>::encrypt_s(&mut rng, beta, l, &sk, &T64(bit_raw))?;
let bit = TGSW::encrypt_s(&mut rng, &param, beta, l, &sk, &T64(bit_raw))?;
let c1 = TLWE::<K>::encrypt_s(&mut rng, &sk, &p1)?;
let c2 = TLWE::<K>::encrypt_s(&mut rng, &sk, &p2)?;
let c1 = TLWE::encrypt_s(&mut rng, &param, &sk, &p1)?;
let c2 = TLWE::encrypt_s(&mut rng, &param, &sk, &p2)?;
let res: TLWE<K> = TGSW::cmux(bit, c1, c2);
let res: TLWE = TGSW::cmux(bit, c1, c2);
let p_recovered = res.decrypt(&sk); let p_recovered = res.decrypt(&sk);
// downscaled by delta^-1 // downscaled by delta^-1
let res_recovered = TLWE::<K>::decode::<T>(&p_recovered);
let res_recovered = TLWE::decode(&param, &p_recovered);
if bit_raw == 0 { if bit_raw == 0 {
assert_eq!(m1, res_recovered); assert_eq!(m1, res_recovered);

+ 71
- 45
tfhe/src/tlev.rs

@ -1,35 +1,50 @@
use anyhow::Result; use anyhow::Result;
use itertools::zip_eq; use itertools::zip_eq;
use rand::Rng; use rand::Rng;
use std::array;
use std::ops::{Add, Mul};
use std::ops::Mul;
use arith::{Ring, Rq, Tn, T64, TR};
use arith::{Ring, RingParam, Rq, T64};
use crate::tglwe::TGLWE;
use crate::tlwe::{PublicKey, SecretKey, TLWE}; use crate::tlwe::{PublicKey, SecretKey, TLWE};
use gfhe::glwe::Param;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TLev<const K: usize>(pub(crate) Vec<TLWE<K>>);
pub struct TLev(pub(crate) Vec<TLWE>);
impl TLev {
pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(m.param.n, 1);
assert_eq!(param.t, m.param.q);
impl<const K: usize> TLev<K> {
pub fn encode<const T: u64>(m: &Rq<T, 1>) -> T64 {
let coeffs = m.coeffs(); let coeffs = m.coeffs();
T64(coeffs[0].0) // N=1, so take the only coeff
T64(coeffs[0].v) // N=1, so take the only coeff
} }
pub fn decode<const T: u64>(p: &T64) -> Rq<T, 1> {
Rq::<T, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &T64) -> Rq {
Rq::from_vec_u64(
&RingParam { q: param.t, n: 1 },
p.coeffs().iter().map(|c| c.0).collect(),
)
} }
pub fn encrypt( pub fn encrypt(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
pk: &PublicKey<K>,
pk: &PublicKey,
m: &T64, m: &T64,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TLWE<K>> = (1..l + 1)
debug_assert_eq!(pk.1.k, param.k);
let tlev: Vec<TLWE> = (1..l as u64 + 1)
.map(|i| { .map(|i| {
TLWE::<K>::encrypt(&mut rng, pk, &(*m * (u64::MAX / beta.pow(i as u32) as u64)))
let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i))
} else {
// 1<<64 would overflow, and anyways we're dividing u64::MAX
// by it, which would be equal to 1
*m
};
TLWE::encrypt(&mut rng, param, pk, &aux)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -37,12 +52,15 @@ impl TLev {
} }
pub fn encrypt_s( pub fn encrypt_s(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
_beta: u32, // TODO rm, and make beta=2 always _beta: u32, // TODO rm, and make beta=2 always
l: u32, l: u32,
sk: &SecretKey<K>,
sk: &SecretKey,
m: &T64, m: &T64,
) -> Result<Self> { ) -> Result<Self> {
let tlev: Vec<TLWE<K>> = (1..l as u64 + 1)
debug_assert_eq!(sk.0 .0.k, param.k);
let tlev: Vec<TLWE> = (1..l as u64 + 1)
.map(|i| { .map(|i| {
let aux = if i < 64 { let aux = if i < 64 {
*m * (u64::MAX / (1u64 << i)) *m * (u64::MAX / (1u64 << i))
@ -51,22 +69,22 @@ impl TLev {
// by it, which would be equal to 1 // by it, which would be equal to 1
*m *m
}; };
TLWE::<K>::encrypt_s(&mut rng, sk, &aux)
TLWE::encrypt_s(&mut rng, &param, sk, &aux)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(Self(tlev)) Ok(Self(tlev))
} }
pub fn decrypt(&self, sk: &SecretKey<K>, beta: u32) -> T64 {
pub fn decrypt(&self, sk: &SecretKey, beta: u32) -> T64 {
let pt = self.0[0].decrypt(sk); let pt = self.0[0].decrypt(sk);
pt.mul_div_round(beta as u64, u64::MAX) pt.mul_div_round(beta as u64, u64::MAX)
} }
} }
// TODO review u64::MAX, since is -1 of the value we actually want // TODO review u64::MAX, since is -1 of the value we actually want
impl<const K: usize> TLev<K> {
pub fn iter(&self) -> std::slice::Iter<TLWE<K>> {
impl TLev {
pub fn iter(&self) -> std::slice::Iter<TLWE> {
self.0.iter() self.0.iter()
} }
} }
@ -74,14 +92,14 @@ impl TLev {
// dot product between a TLev and Vec<T64>, usually Vec<T64> comes from a // dot product between a TLev and Vec<T64>, usually Vec<T64> comes from a
// decomposition of T64 // decomposition of T64
// TLev * Vec<T64> --> TLWE // TLev * Vec<T64> --> TLWE
impl<const K: usize> Mul<Vec<T64>> for TLev<K> {
type Output = TLWE<K>;
impl Mul<Vec<T64>> for TLev {
type Output = TLWE;
fn mul(self, v: Vec<T64>) -> Self::Output { fn mul(self, v: Vec<T64>) -> Self::Output {
assert_eq!(self.0.len(), v.len()); assert_eq!(self.0.len(), v.len());
// l TLWES // l TLWES
let tlwes: Vec<TLWE<K>> = self.0;
let r: TLWE<K> = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
let tlwes: Vec<TLWE> = self.0;
let r: TLWE = zip_eq(v, tlwes).map(|(a_d_i, glwe_i)| glwe_i * a_d_i).sum();
r r
} }
} }
@ -95,27 +113,30 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
type S = TLev<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = TLWE::<K>::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p: T64 = S::encode::<T>(&m); // plaintext
let m: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLev::encode(&param, &m); // plaintext
let c = S::encrypt(&mut rng, beta, l, &pk, &p)?;
let c = TLev::encrypt(&mut rng, &param, beta, l, &pk, &p)?;
let p_recovered = c.decrypt(&sk, beta); let p_recovered = c.decrypt(&sk, beta);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@ -123,32 +144,37 @@ mod tests {
#[test] #[test]
fn test_tlev_vect64_product() -> Result<()> { fn test_tlev_vect64_product() -> Result<()> {
const T: u64 = 2; // plaintext modulus
const K: usize = 16;
let param = Param {
err_sigma: 0.1, // WIP
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 2, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16;
// let l: u32 = 16;
let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = TLWE::<K>::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let m2: Rq<T, 1> = Rq::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = TLev::<K>::encode::<T>(&m1);
let p2: T64 = TLev::<K>::encode::<T>(&m2);
let m1: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2: Rq = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLev::encode(&param, &m1);
let p2: T64 = TLev::encode(&param, &m2);
let c1 = TLev::<K>::encrypt(&mut rng, beta, l, &pk, &p1)?;
let c1 = TLev::encrypt(&mut rng, &param, beta, l, &pk, &p1)?;
let c2 = p2.decompose(beta, l); let c2 = p2.decompose(beta, l);
let c3 = c1 * c2; let c3 = c1 * c2;
let p_recovered = c3.decrypt(&sk); let p_recovered = c3.decrypt(&sk);
let m_recovered = TLev::<K>::decode::<T>(&p_recovered);
let m_recovered = TLev::decode(&param, &p_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m_recovered);
} }
Ok(()) Ok(())

+ 228
- 181
tfhe/src/tlwe.rs

@ -4,242 +4,275 @@ use rand::Rng;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Sub}; use std::ops::{Add, AddAssign, Mul, Sub};
use arith::{Ring, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, GLWE};
use arith::{Ring, RingParam, Rq, Tn, Zq, T64, TR};
use gfhe::{glwe, glwe::Param, GLWE};
use crate::tggsw::TGGSW; use crate::tggsw::TGGSW;
use crate::tlev::TLev; use crate::tlev::TLev;
use crate::{tglwe, tglwe::TGLWE}; use crate::{tglwe, tglwe::TGLWE};
pub struct SecretKey<const K: usize>(pub glwe::SecretKey<T64, K>);
pub struct SecretKey(pub glwe::SecretKey<T64>);
impl<const KN: usize> SecretKey<KN> {
impl SecretKey {
/// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a /// from TFHE [2018-421] paper: A TLWE key k \in B^n, can be interpreted as a
/// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and /// TRLWE key K \in B_N[X]^k having the same sequence of coefficients and
/// vice-versa. /// vice-versa.
pub fn to_tglwe<const N: usize, const K: usize>(self) -> crate::tglwe::SecretKey<N, K> {
let s: TR<T64, KN> = self.0 .0;
pub fn to_tglwe(self, param: &Param) -> crate::tglwe::SecretKey {
let s: TR<T64> = self.0 .0; // of length K*N
assert_eq!(s.r.len(), param.k * param.ring.n); // sanity check
// split into K vectors, and interpret each of them as a T_N[X]/(X^N+1) // split into K vectors, and interpret each of them as a T_N[X]/(X^N+1)
// polynomial // polynomial
let r: Vec<Tn<N>> =
s.0.chunks(N)
.map(|v| Tn::<N>::from_vec(v.to_vec()))
let r: Vec<Tn> =
s.r.chunks(param.ring.n)
.map(|v| Tn::from_vec(&param.ring, v.to_vec()))
.collect(); .collect();
crate::tglwe::SecretKey(glwe::SecretKey::<Tn<N>, K>(TR(r)))
crate::tglwe::SecretKey(glwe::SecretKey::<Tn>(TR { k: param.k, r }))
} }
} }
pub type PublicKey<const K: usize> = glwe::PublicKey<T64, K>;
pub type PublicKey = glwe::PublicKey<T64>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KSK<const K: usize>(Vec<TLev<K>>);
pub struct KSK(Vec<TLev>);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TLWE<const K: usize>(pub GLWE<T64, K>);
pub struct TLWE(pub GLWE<T64>);
impl<const K: usize> TLWE<K> {
pub fn zero() -> Self {
Self(GLWE::<T64, K>::zero())
impl TLWE {
pub fn zero(k: usize, ring_param: &RingParam) -> Self {
Self(GLWE::<T64>::zero(k, ring_param))
} }
pub fn new_key(rng: impl Rng) -> Result<(SecretKey<K>, PublicKey<K>)> {
let (sk, pk): (glwe::SecretKey<T64, K>, glwe::PublicKey<T64, K>) = GLWE::new_key(rng)?;
pub fn new_key(rng: impl Rng, param: &Param) -> Result<(SecretKey, PublicKey)> {
let (sk, pk): (glwe::SecretKey<T64>, glwe::PublicKey<T64>) = GLWE::new_key(rng, param)?;
Ok((SecretKey(sk), pk)) Ok((SecretKey(sk), pk))
} }
pub fn encode<const P: u64>(m: &Rq<P, 1>) -> T64 {
let delta = u64::MAX / P; // floored
pub fn encode(param: &Param, m: &Rq) -> T64 {
assert_eq!(param.ring.n, 1);
debug_assert_eq!(param.t, m.param.q); // plaintext modulus
let delta = u64::MAX / param.t; // floored
let coeffs = m.coeffs(); let coeffs = m.coeffs();
T64(coeffs[0].0 * delta)
T64(coeffs[0].v * delta)
} }
pub fn decode<const P: u64>(p: &T64) -> Rq<P, 1> {
let p = p.mul_div_round(P, u64::MAX);
Rq::<P, 1>::from_vec_u64(p.coeffs().iter().map(|c| c.0).collect())
pub fn decode(param: &Param, p: &T64) -> Rq {
let p = p.mul_div_round(param.t, u64::MAX);
Rq::from_vec_u64(&param.pt(), p.coeffs().iter().map(|c| c.0).collect())
}
/// encodes the given message as a TLWE constant/public value, for using it
/// in ct-pt-multiplication.
pub fn new_const(param: &Param, m: &Rq) -> T64 {
debug_assert_eq!(param.t, m.param.q);
T64(m.coeffs()[0].v)
} }
// encrypts with the given SecretKey (instead of PublicKey) // encrypts with the given SecretKey (instead of PublicKey)
pub fn encrypt_s(rng: impl Rng, sk: &SecretKey<K>, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, &sk.0, p)?;
pub fn encrypt_s(rng: impl Rng, param: &Param, sk: &SecretKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt_s(rng, param, &sk.0, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn encrypt(rng: impl Rng, pk: &PublicKey<K>, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, &pk, p)?;
pub fn encrypt(rng: impl Rng, param: &Param, pk: &PublicKey, p: &T64) -> Result<Self> {
let glwe = GLWE::encrypt(rng, param, pk, p)?;
Ok(Self(glwe)) Ok(Self(glwe))
} }
pub fn decrypt(&self, sk: &SecretKey<K>) -> T64 {
pub fn decrypt(&self, sk: &SecretKey) -> T64 {
self.0.decrypt(&sk.0) self.0.decrypt(&sk.0)
} }
pub fn new_ksk( pub fn new_ksk(
mut rng: impl Rng, mut rng: impl Rng,
param: &Param,
beta: u32, beta: u32,
l: u32, l: u32,
sk: &SecretKey<K>,
new_sk: &SecretKey<K>,
) -> Result<KSK<K>> {
let r: Vec<TLev<K>> = (0..K)
sk: &SecretKey,
new_sk: &SecretKey,
) -> Result<KSK> {
let r: Vec<TLev> = (0..param.k)
.into_iter() .into_iter()
.map(|i| .map(|i|
// treat sk_i as the msg being encrypted // treat sk_i as the msg being encrypted
TLev::<K>::encrypt_s(&mut rng, beta, l, &new_sk, &sk.0.0 .0[i]))
TLev::encrypt_s(&mut rng, param, beta, l, &new_sk, &sk.0.0 .r[i]))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(KSK(r)) Ok(KSK(r))
} }
pub fn key_switch(&self, beta: u32, l: u32, ksk: &KSK<K>) -> Self {
let (a, b): (TR<T64, K>, T64) = (self.0 .0.clone(), self.0 .1);
pub fn key_switch(&self, param: &Param, beta: u32, l: u32, ksk: &KSK) -> Self {
let (a, b): (TR<T64>, T64) = (self.0 .0.clone(), self.0 .1);
let lhs: TLWE<K> = TLWE(GLWE(TR::zero(), b));
let lhs: TLWE = TLWE(GLWE(TR::zero(param.k * param.ring.n, &param.ring), b));
// K iterations, ksk.0 contains K times GLev // K iterations, ksk.0 contains K times GLev
let rhs: TLWE<K> = zip_eq(a.0, ksk.0.clone())
let rhs: TLWE = zip_eq(a.r, ksk.0.clone())
.map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product .map(|(a_i, ksk_i)| ksk_i * a_i.decompose(beta, l)) // dot_product
.sum(); .sum();
lhs - rhs lhs - rhs
} }
// modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N) // modulus switch from Q (2^64) to Q2 (in blind_rotation Q2=K*N)
pub fn mod_switch<const Q2: u64>(&self) -> Self {
let a: TR<T64, K> = self.0 .0.mod_switch::<Q2>();
let b: T64 = self.0 .1.mod_switch::<Q2>();
pub fn mod_switch(&self, q2: u64) -> Self {
let a: TR<T64> = self.0 .0.mod_switch(q2);
let b: T64 = self.0 .1.mod_switch(q2);
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// NOTE: the ugly const generics are temporary
pub fn blind_rotation<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
c: TLWE<KN>,
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
) -> TGLWE<N, K> {
let c_kn: TLWE<KN> = c.mod_switch::<KN2>();
let (a, b): (TR<T64, KN>, T64) = (c_kn.0 .0, c_kn.0 .1);
pub fn blind_rotation(
param: &Param,
c: TLWE, // kn
btk: BootstrappingKey,
table: TGLWE, // n,k
) -> TGLWE {
debug_assert_eq!(c.0 .0.k, param.k);
// TODO replace `param.k*param.ring.n` by `param.kn()`
let c_kn: TLWE = c.mod_switch((param.k * param.ring.n) as u64);
let (a, b): (TR<T64>, T64) = (c_kn.0 .0, c_kn.0 .1);
// two main parts: rotate by a known power of X, rotate by a secret // two main parts: rotate by a known power of X, rotate by a secret
// power of X (using the C gate) // power of X (using the C gate)
// table * X^-b, ie. left rotate // table * X^-b, ie. left rotate
let v_xb: TGLWE<N, K> = table.left_rotate(b.0 as usize);
let v_xb: TGLWE = table.left_rotate(b.0 as usize);
// rotate by a secret power of X using the cmux gate // rotate by a secret power of X using the cmux gate
let mut c_j: TGLWE<N, K> = v_xb.clone();
let _ = (1..K).map(|j| {
c_j = TGGSW::<N, K>::cmux(
let mut c_j: TGLWE = v_xb.clone();
let _ = (1..param.k).map(|j| {
c_j = TGGSW::cmux(
btk.0[j].clone(), btk.0[j].clone(),
c_j.clone(), c_j.clone(),
c_j.clone().left_rotate(a.0[j].0 as usize),
c_j.clone().left_rotate(a.r[j].0 as usize),
); );
dbg!(&c_j);
}); });
c_j c_j
} }
pub fn bootstrapping<const N: usize, const K: usize, const KN: usize, const KN2: u64>(
btk: BootstrappingKey<N, K, KN>,
table: TGLWE<N, K>,
c: TLWE<KN>,
) -> TLWE<KN> {
let rotated: TGLWE<N, K> = blind_rotation::<N, K, KN, KN2>(c, btk.clone(), table);
let c_h: TLWE<KN> = rotated.sample_extraction(0);
let r = c_h.key_switch(2, 64, &btk.1);
pub fn bootstrapping(
param: &Param,
btk: BootstrappingKey,
table: TGLWE,
c: TLWE, // kn
) -> TLWE {
// kn
let rotated: TGLWE = blind_rotation(param, c, btk.clone(), table);
let c_h: TLWE = rotated.sample_extraction(&param, 0);
let r = c_h.key_switch(param, 2, 64, &btk.1);
r r
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct BootstrappingKey<const N: usize, const K: usize, const KN: usize>(
pub Vec<TGGSW<N, K>>,
pub KSK<KN>,
pub struct BootstrappingKey(
pub Vec<TGGSW>,
pub KSK, // kn
); );
impl<const N: usize, const K: usize, const KN: usize> BootstrappingKey<N, K, KN> {
pub fn from_sk(mut rng: impl Rng, sk: &tglwe::SecretKey<N, K>) -> Result<Self> {
impl BootstrappingKey {
pub fn from_sk(mut rng: impl Rng, param: &Param, sk: &tglwe::SecretKey) -> Result<Self> {
let (beta, l) = (2u32, 64u32); // TMP let (beta, l) = (2u32, 64u32); // TMP
//
let s: TR<Tn<N>, K> = sk.0 .0.clone();
let (sk2, _) = TLWE::<KN>::new_key(&mut rng)?; // TLWE<KN> compatible with TGLWE<N,K>
let s: TR<Tn> = sk.0 .0.clone();
let (sk2, _) = TLWE::new_key(&mut rng, &param.lwe())?; // TLWE<KN> compatible with TGLWE<N,K>
// each btk_j = TGGSW_sk(s_i) // each btk_j = TGGSW_sk(s_i)
let btk: Vec<TGGSW<N, K>> = s
let btk: Vec<TGGSW> = s
.iter() .iter()
.map(|s_i| TGGSW::<N, K>::encrypt_s(&mut rng, beta, l, sk, s_i))
.map(|s_i| TGGSW::encrypt_s(&mut rng, param, beta, l, sk, s_i))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ksk = TLWE::<KN>::new_ksk(&mut rng, beta, l, &sk.to_tlwe(), &sk2)?;
let ksk = TLWE::new_ksk(
&mut rng,
&param.lwe(),
beta,
l,
&sk.to_tlwe(&param.lwe()), // converted to length k*n
&sk2, // created with length k*n
)?;
debug_assert_eq!(ksk.0.len(), param.lwe().k);
debug_assert_eq!(ksk.0.len(), param.k * param.ring.n);
Ok(Self(btk, ksk)) Ok(Self(btk, ksk))
} }
} }
pub fn compute_lookup_table<const T: u64, const K: usize, const N: usize>() -> TGLWE<N, K> {
pub fn compute_lookup_table(param: &Param) -> TGLWE {
// from 2021-1402: // from 2021-1402:
// v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j // v(x) = \sum_j^{N-1} [(p_j / 2N mod p)/p] X^j
// matrix of coefficients with size K*N = delta x T // matrix of coefficients with size K*N = delta x T
let delta: usize = N / T as usize;
let values: Vec<Zq<T>> = (0..T).map(|v| Zq::<T>::from_u64(v)).collect();
let coeffs: Vec<Zq<T>> = (0..T as usize)
let delta: usize = param.ring.n / param.t as usize;
let values: Vec<Zq> = (0..param.t).map(|v| Zq::from_u64(param.t, v)).collect();
let coeffs: Vec<Zq> = (0..param.t as usize)
.flat_map(|i| vec![values[i]; delta]) .flat_map(|i| vec![values[i]; delta])
.collect(); .collect();
let table = Rq::<T, N>::from_vec(coeffs);
let table = Rq::from_vec(&param.pt(), coeffs);
// encode the table as plaintext // encode the table as plaintext
let v: Tn<N> = TGLWE::<N, K>::encode::<T>(&table);
let v: Tn = TGLWE::encode(param, &table);
// encode the table as TGLWE ciphertext // encode the table as TGLWE ciphertext
let v: TGLWE<N, K> = TGLWE::<N, K>::from_plaintext(v);
let v: TGLWE = TGLWE::from_plaintext(param.k, &param.ring, v);
v v
} }
impl<const K: usize> Add<TLWE<K>> for TLWE<K> {
impl Add<TLWE> for TLWE {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 + other.0) Self(self.0 + other.0)
} }
} }
impl<const K: usize> AddAssign for TLWE<K> {
impl AddAssign for TLWE {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
debug_assert_eq!(self.0 .0.k, rhs.0 .0.k);
debug_assert_eq!(self.0 .1.param(), rhs.0 .1.param());
self.0 += rhs.0 self.0 += rhs.0
} }
} }
impl<const K: usize> Sum<TLWE<K>> for TLWE<K> {
fn sum<I>(iter: I) -> Self
impl Sum<TLWE> for TLWE {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = TLWE::<K>::zero();
for e in iter {
acc += e;
}
acc
let first = iter.next().unwrap();
iter.fold(first, |acc, e| acc + e)
} }
} }
impl<const K: usize> Sub<TLWE<K>> for TLWE<K> {
impl Sub<TLWE> for TLWE {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
debug_assert_eq!(self.0 .0.k, other.0 .0.k);
debug_assert_eq!(self.0 .1.param(), other.0 .1.param());
Self(self.0 - other.0) Self(self.0 - other.0)
} }
} }
// plaintext addition // plaintext addition
impl<const K: usize> Add<T64> for TLWE<K> {
impl Add<T64> for TLWE {
type Output = Self; type Output = Self;
fn add(self, plaintext: T64) -> Self { fn add(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0;
let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 + plaintext; let b: T64 = self.0 .1 + plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext substraction // plaintext substraction
impl<const K: usize> Sub<T64> for TLWE<K> {
impl Sub<T64> for TLWE {
type Output = Self; type Output = Self;
fn sub(self, plaintext: T64) -> Self { fn sub(self, plaintext: T64) -> Self {
let a: TR<T64, K> = self.0 .0;
let a: TR<T64> = self.0 .0;
let b: T64 = self.0 .1 - plaintext; let b: T64 = self.0 .1 - plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
} }
// plaintext multiplication // plaintext multiplication
impl<const K: usize> Mul<T64> for TLWE<K> {
impl Mul<T64> for TLWE {
type Output = Self; type Output = Self;
fn mul(self, plaintext: T64) -> Self { fn mul(self, plaintext: T64) -> Self {
let a: TR<T64, K> = TR(self.0 .0 .0.iter().map(|r_i| *r_i * plaintext).collect());
let a: TR<T64> = TR {
k: self.0 .0.k,
r: self.0 .0.r.iter().map(|r_i| *r_i * plaintext).collect(),
};
let b: T64 = self.0 .1 * plaintext; let b: T64 = self.0 .1 * plaintext;
Self(GLWE(a, b)) Self(GLWE(a, b))
} }
@ -255,29 +288,32 @@ mod tests {
#[test] #[test]
fn test_encrypt_decrypt() -> Result<()> { fn test_encrypt_decrypt() -> Result<()> {
const T: u64 = 128; // msg space (msg modulus)
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p: T64 = S::encode::<T>(&m);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p: T64 = TLWE::encode(&param, &m);
let c = S::encrypt(&mut rng, &pk, &p)?;
let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
// same but using encrypt_s (with sk instead of pk)) // same but using encrypt_s (with sk instead of pk))
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let p_recovered = c.decrypt(&sk); let p_recovered = c.decrypt(&sk);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
} }
@ -287,30 +323,33 @@ mod tests {
#[test] #[test]
fn test_addition() -> Result<()> { fn test_addition() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c2 = S::encrypt(&mut rng, &pk, &p2)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c2 = TLWE::encrypt(&mut rng, &param, &pk, &p2)?;
let c3 = c1 + c2; let c3 = c1 + c2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1 + m2).remodule::<T>(), m3_recovered.remodule::<T>());
assert_eq!((m1 + m2).remodule(param.t), m3_recovered.remodule(param.t));
} }
Ok(()) Ok(())
@ -318,27 +357,30 @@ mod tests {
#[test] #[test]
fn test_add_plaintext() -> Result<()> { fn test_add_plaintext() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1); // plaintext
let p2: T64 = S::encode::<T>(&m2); // plaintext
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1); // plaintext
let p2: T64 = TLWE::encode(&param, &m2); // plaintext
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 + p2; let c3 = c1 + p2;
let p3_recovered = c3.decrypt(&sk); let p3_recovered = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!(m1 + m2, m3_recovered); assert_eq!(m1 + m2, m3_recovered);
} }
@ -348,30 +390,31 @@ mod tests {
#[test] #[test]
fn test_mul_plaintext() -> Result<()> { fn test_mul_plaintext() -> Result<()> {
const T: u64 = 128;
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let msg_dist = Uniform::new(0_u64, T);
let msg_dist = Uniform::new(0_u64, param.t);
for _ in 0..200 { for _ in 0..200 {
let (sk, pk) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let m1 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let m2 = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p1: T64 = S::encode::<T>(&m1);
// don't scale up p2, set it directly from m2
// let p2: T64 = Tn(array::from_fn(|i| T64(m2.coeffs()[i].0)));
let p2: T64 = T64(m2.coeffs()[0].0);
let m1 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let m2 = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p1: T64 = TLWE::encode(&param, &m1);
let p2: T64 = TLWE::new_const(&param, &m2); // as constant/public value
let c1 = S::encrypt(&mut rng, &pk, &p1)?;
let c1 = TLWE::encrypt(&mut rng, &param, &pk, &p1)?;
let c3 = c1 * p2; let c3 = c1 * p2;
let p3_recovered: T64 = c3.decrypt(&sk); let p3_recovered: T64 = c3.decrypt(&sk);
let m3_recovered = S::decode::<T>(&p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq::<T>(), m3_recovered);
let m3_recovered = TLWE::decode(&param, &p3_recovered);
assert_eq!((m1.to_r() * m2.to_r()).to_rq(param.t), m3_recovered);
} }
Ok(()) Ok(())
@ -379,38 +422,41 @@ mod tests {
#[test] #[test]
fn test_key_switch() -> Result<()> { fn test_key_switch() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 16;
type S = TLWE<K>;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam { q: u64::MAX, n: 1 },
k: 16,
t: 128, // plaintext modulus
};
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 64; let l: u32 = 64;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (sk, pk) = S::new_key(&mut rng)?;
let (sk2, _) = S::new_key(&mut rng)?;
let (sk, pk) = TLWE::new_key(&mut rng, &param)?;
let (sk2, _) = TLWE::new_key(&mut rng, &param)?;
// ksk to switch from sk to sk2 // ksk to switch from sk to sk2
let ksk = S::new_ksk(&mut rng, beta, l, &sk, &sk2)?;
let ksk = TLWE::new_ksk(&mut rng, &param, beta, l, &sk, &sk2)?;
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.pt())?;
let p = TLWE::encode(&param, &m); // plaintext
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
let p = S::encode::<T>(&m); // plaintext
//
let c = S::encrypt_s(&mut rng, &sk, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param, &sk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c2 = c.key_switch(&param, beta, l, &ksk);
// decrypt with the 2nd secret key // decrypt with the 2nd secret key
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
assert_eq!(m.remodule::<T>(), m_recovered.remodule::<T>());
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m.remodule(param.t), m_recovered.remodule(param.t));
// do the same but now encrypting with pk // do the same but now encrypting with pk
let c = S::encrypt(&mut rng, &pk, &p)?;
let c2 = c.key_switch(beta, l, &ksk);
let c = TLWE::encrypt(&mut rng, &param, &pk, &p)?;
let c2 = c.key_switch(&param, beta, l, &ksk);
let p_recovered = c2.decrypt(&sk2); let p_recovered = c2.decrypt(&sk2);
let m_recovered = S::decode::<T>(&p_recovered);
let m_recovered = TLWE::decode(&param, &p_recovered);
assert_eq!(m, m_recovered); assert_eq!(m, m_recovered);
Ok(()) Ok(())
@ -418,39 +464,40 @@ mod tests {
#[test] #[test]
fn test_bootstrapping() -> Result<()> { fn test_bootstrapping() -> Result<()> {
const T: u64 = 128; // plaintext modulus
const K: usize = 1;
const N: usize = 1024;
const KN: usize = K * N;
let param = Param {
err_sigma: crate::ERR_SIGMA,
ring: RingParam {
q: u64::MAX,
n: 1024,
},
k: 1,
t: 128, // plaintext modulus
};
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let start = Instant::now(); let start = Instant::now();
let table: TGLWE<N, K> = compute_lookup_table::<T, K, N>();
let table: TGLWE = compute_lookup_table(&param);
println!("table took: {:?}", start.elapsed()); println!("table took: {:?}", start.elapsed());
let (sk, _) = TGLWE::<N, K>::new_key::<KN>(&mut rng)?;
let sk_tlwe: SecretKey<KN> = sk.to_tlwe::<KN>();
let (sk, _) = TGLWE::new_key(&mut rng, &param)?;
let sk_tlwe: SecretKey = sk.to_tlwe(&param);
let start = Instant::now(); let start = Instant::now();
let btk = BootstrappingKey::<N, K, KN>::from_sk(&mut rng, &sk)?;
let btk = BootstrappingKey::from_sk(&mut rng, &param, &sk)?;
println!("btk took: {:?}", start.elapsed()); println!("btk took: {:?}", start.elapsed());
let msg_dist = Uniform::new(0_u64, T);
let m = Rq::<T, 1>::rand_u64(&mut rng, msg_dist)?;
dbg!(&m);
let p = TLWE::<K>::encode::<T>(&m); // plaintext
let msg_dist = Uniform::new(0_u64, param.t);
let m = Rq::rand_u64(&mut rng, msg_dist, &param.lwe().pt())?; // q=t, n=1
let p = TLWE::encode(&param.lwe(), &m); // plaintext
let c = TLWE::<KN>::encrypt_s(&mut rng, &sk_tlwe, &p)?;
let c = TLWE::encrypt_s(&mut rng, &param.lwe(), &sk_tlwe, &p)?;
let start = Instant::now(); let start = Instant::now();
// the ugly const generics are temporary
let bootstrapped: TLWE<KN> =
bootstrapping::<N, K, KN, { K as u64 * N as u64 }>(btk, table, c);
let bootstrapped: TLWE = bootstrapping(&param, btk, table, c);
println!("bootstrapping took: {:?}", start.elapsed()); println!("bootstrapping took: {:?}", start.elapsed());
let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe); let p_recovered: T64 = bootstrapped.decrypt(&sk_tlwe);
let m_recovered = TLWE::<KN>::decode::<T>(&p_recovered);
dbg!(&m_recovered);
let m_recovered = TLWE::decode(&param.lwe(), &p_recovered);
assert_eq!(m_recovered, m); assert_eq!(m_recovered, m);
Ok(()) Ok(())

Loading…
Cancel
Save