Browse Source

arith: get rid of constant generics. Reason:

using constant generics was great for allocating the arrays in the
stack, which is faster, but when started to use bigger parameter values,
in some cases it was overflowing the stack. This commit removes all the
constant generics in all of the `arith` crate, which in some cases slows
a bit the performance, but allows for bigger parameter values (on the
ones that affect lengths, like N and K).
pull/2/head
arnaucube 3 months ago
parent
commit
d484f29b17
12 changed files with 1312 additions and 686 deletions
  1. +1
    -1
      README.md
  2. +1
    -0
      arith/Cargo.toml
  3. +7
    -7
      arith/src/lib.rs
  4. +119
    -66
      arith/src/ntt.rs
  5. +187
    -0
      arith/src/ntt_fixedsize.rs
  6. +9
    -8
      arith/src/ring.rs
  7. +197
    -146
      arith/src/ring_n.rs
  8. +353
    -199
      arith/src/ring_nq.rs
  9. +179
    -95
      arith/src/ring_torus.rs
  10. +12
    -10
      arith/src/torus.rs
  11. +79
    -40
      arith/src/tuple_ring.rs
  12. +168
    -114
      arith/src/zq.rs

+ 1
- 1
README.md

@ -62,7 +62,7 @@ let m4 = S::decode::(&p4_recovered);
- external products of ciphertexts - external products of ciphertexts
- TGSW x TLWE - TGSW x TLWE
- TGGSW x TGLWE - TGGSW x TGLWE
- TGSW & TGGSW CMux gate
- {TGSW, TGGSW} CMux gate
- blind rotation, key switching, mod switching - blind rotation, key switching, mod switching
- bootstrapping - bootstrapping
- CKKS - CKKS

+ 1
- 0
arith/Cargo.toml

@ -8,6 +8,7 @@ anyhow = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
rand_distr = { workspace = true } rand_distr = { workspace = true }
itertools = { workspace = true } itertools = { workspace = true }
lazy_static = "1.5.0"
# TMP: the next 4 imports are TMP, to solve systems of linear equations. Used # TMP: the next 4 imports are TMP, to solve systems of linear equations. Used
# for the CKKS encoding step, probably remvoed once in ckks the encoding is done # for the CKKS encoding step, probably remvoed once in ckks the encoding is done

+ 7
- 7
arith/src/lib.rs

@ -6,29 +6,29 @@
pub mod complex; pub mod complex;
pub mod matrix; pub mod matrix;
pub mod torus;
// pub mod torus;
pub mod zq; pub mod zq;
pub mod ring; pub mod ring;
pub mod ring_n; pub mod ring_n;
pub mod ring_nq; pub mod ring_nq;
pub mod ring_torus;
pub mod tuple_ring;
// pub mod ring_torus;
// pub mod tuple_ring;
mod naive_ntt; // note: for dev only
// mod naive_ntt; // note: for dev only
pub mod ntt; pub mod ntt;
// expose objects // expose objects
pub use complex::C; pub use complex::C;
pub use matrix::Matrix; pub use matrix::Matrix;
pub use torus::T64;
// pub use torus::T64;
pub use zq::Zq; pub use zq::Zq;
pub use ring::Ring; pub use ring::Ring;
pub use ring_n::R; pub use ring_n::R;
pub use ring_nq::Rq; pub use ring_nq::Rq;
pub use ring_torus::Tn;
pub use tuple_ring::TR;
// pub use ring_torus::Tn;
// pub use tuple_ring::TR;
pub use ntt::NTT; pub use ntt::NTT;

+ 119
- 66
arith/src/ntt.rs

@ -1,34 +1,68 @@
//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in //! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in
//! https://eprint.iacr.org/2017/727.pdf, some notes at //! https://eprint.iacr.org/2017/727.pdf, some notes at
//! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf . //! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
use crate::zq::Zq;
//!
//! NOTE: initially I implemented it with fixed Q & N, given as constant
//! generics; but once using real-world parameters, the stack could not handle
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
//! implementation to that too.
use crate::{ring::Ring, ring_nq::Rq, zq::Zq};
use std::collections::HashMap;
#[derive(Debug)] #[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {}
impl<const Q: u64, const N: usize> NTT<Q, N> {
const N_INV: Zq<Q> = Zq(const_inv_mod::<Q>(N as u64));
// since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity
pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::<Q>(2 * N);
pub(crate) const ROOTS_OF_UNITY: [Zq<Q>; N] = roots_of_unity(Self::ROOT_OF_UNITY);
const ROOTS_OF_UNITY_INV: [Zq<Q>; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY);
pub struct NTT {}
use std::sync::{Mutex, OnceLock};
static CACHE: OnceLock<Mutex<HashMap<(u64, usize), (Vec<Zq>, Vec<Zq>, Zq)>>> = OnceLock::new();
fn roots(q: u64, n: usize) -> (Vec<Zq>, Vec<Zq>, Zq) {
// Initialize CACHE with an empty HashMap on first use
let cache_lock = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
// Lock the HashMap for this thread
let mut cache = cache_lock.lock().unwrap();
if let Some(value) = cache.get(&(q, n)) {
// Found an existing value — return a clone
return value.clone();
}
// Not found — compute the new triple
let n_inv: Zq = Zq {
q,
v: const_inv_mod(q, n as u64),
};
let root_of_unity: u64 = primitive_root_of_unity(q, 2 * n);
let roots_of_unity: Vec<Zq> = roots_of_unity(q, n, root_of_unity);
let roots_of_unity_inv: Vec<Zq> = roots_of_unity_inv(q, n, roots_of_unity.clone());
let value = (roots_of_unity, roots_of_unity_inv, n_inv);
// Store and return
cache.insert((q, n), value.clone());
value
} }
impl<const Q: u64, const N: usize> NTT<Q, N> {
impl NTT {
/// implements the Cooley-Tukey (CT) algorithm. Details at /// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn ntt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = N / 2;
pub fn ntt(a: &Rq) -> Rq {
let (q, n) = (a.q, a.n);
let (roots_of_unity, _, _) = roots(q, n);
let mut t = n / 2;
let mut m = 1; let mut m = 1;
let mut r: [Zq<Q>; N] = a.clone();
while m < N {
let mut r: Vec<Zq> = a.coeffs.clone();
while m < n {
let mut k = 0; let mut k = 0;
for i in 0..m { for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY[m + i];
let S: Zq = roots_of_unity[m + i];
for j in k..k + t { for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t] * S;
let U: Zq = r[j];
let V: Zq = r[j + t] * S;
r[j] = U + V; r[j] = U + V;
r[j + t] = U - V; r[j + t] = U - V;
} }
@ -37,23 +71,32 @@ impl NTT {
t /= 2; t /= 2;
m *= 2; m *= 2;
} }
r
// Rq::from_vec((a.q, n), r)
Rq {
q,
n,
coeffs: r,
evals: None,
}
} }
/// implements the Cooley-Tukey (CT) algorithm. Details at /// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of /// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf /// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn intt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
pub fn intt(a: &Rq) -> Rq {
let (q, n) = (a.q, a.n);
let (_, roots_of_unity_inv, n_inv) = roots(q, n);
let mut t = 1; let mut t = 1;
let mut m = N / 2;
let mut r: [Zq<Q>; N] = a.clone();
let mut m = n / 2;
let mut r: Vec<Zq> = a.coeffs.clone();
while m > 0 { while m > 0 {
let mut k = 0; let mut k = 0;
for i in 0..m { for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY_INV[m + i];
let S: Zq = roots_of_unity_inv[m + i];
for j in k..k + t { for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t];
let U: Zq = r[j];
let V: Zq = r[j + t];
r[j] = U + V; r[j] = U + V;
r[j + t] = (U - V) * S; r[j + t] = (U - V) * S;
} }
@ -62,26 +105,32 @@ impl NTT {
t *= 2; t *= 2;
m /= 2; m /= 2;
} }
for i in 0..N {
r[i] = r[i] * Self::N_INV;
for i in 0..n {
r[i] = r[i] * n_inv;
}
// Rq::from_vec((a.q, n), r)
Rq {
q,
n,
coeffs: r,
evals: None,
} }
r
} }
} }
/// computes a primitive N-th root of unity using the method described by Thomas /// computes a primitive N-th root of unity using the method described by Thomas
/// Pornin in https://crypto.stackexchange.com/a/63616 /// Pornin in https://crypto.stackexchange.com/a/63616
const fn primitive_root_of_unity<const Q: u64>(N: usize) -> u64 {
assert!(N.is_power_of_two());
assert!((Q - 1) % N as u64 == 0);
const fn primitive_root_of_unity(q: u64, n: usize) -> u64 {
assert!(n.is_power_of_two());
assert!((q - 1) % n as u64 == 0);
let n_u64 = n as u64;
let n: u64 = N as u64;
let mut k = 1; let mut k = 1;
while k < Q {
while k < q {
// alternatively could get a random k at each iteration, if so, add the following if: // alternatively could get a random k at each iteration, if so, add the following if:
// `if k == 0 { continue; }` // `if k == 0 { continue; }`
let w = const_exp_mod::<Q>(k, (Q - 1) / n);
if const_exp_mod::<Q>(w, n / 2) != 1 {
let w = const_exp_mod(q, k, (q - 1) / n_u64);
if const_exp_mod(q, w, n_u64 / 2) != 1 {
return w; // w is a primitive N-th root of unity return w; // w is a primitive N-th root of unity
} }
k += 1; k += 1;
@ -89,52 +138,58 @@ const fn primitive_root_of_unity(N: usize) -> u64 {
panic!("No primitive root of unity"); panic!("No primitive root of unity");
} }
const fn roots_of_unity<const Q: u64, const N: usize>(w: u64) -> [Zq<Q>; N] {
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
fn roots_of_unity(q: u64, n: usize, w: u64) -> Vec<Zq> {
let mut r: Vec<Zq> = vec![Zq { q, v: 0 }; n];
let mut i = 0; let mut i = 0;
let log_n = N.ilog2();
while i < N {
let log_n = n.ilog2();
while i < n {
// (return the roots in bit-reverset order) // (return the roots in bit-reverset order)
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize; let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
r[i] = Zq(const_exp_mod::<Q>(w, j as u64));
r[i] = Zq {
q,
v: const_exp_mod(q, w, j as u64),
};
i += 1; i += 1;
} }
r r
} }
const fn roots_of_unity_inv<const Q: u64, const N: usize>(v: [Zq<Q>; N]) -> [Zq<Q>; N] {
fn roots_of_unity_inv(q: u64, n: usize, v: Vec<Zq>) -> Vec<Zq> {
// assumes that the inputted roots are already in bit-reverset order // assumes that the inputted roots are already in bit-reverset order
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut r: Vec<Zq> = vec![Zq { q, v: 0 }; n];
let mut i = 0; let mut i = 0;
while i < N {
r[i] = Zq(const_inv_mod::<Q>(v[i].0));
while i < n {
r[i] = Zq {
q,
v: const_inv_mod(q, v[i].v),
};
i += 1; i += 1;
} }
r r
} }
/// returns x^k mod Q /// returns x^k mod Q
const fn const_exp_mod<const Q: u64>(x: u64, k: u64) -> u64 {
const fn const_exp_mod(q: u64, x: u64, k: u64) -> u64 {
// work on u128 to avoid overflow // work on u128 to avoid overflow
let mut r = 1u128; let mut r = 1u128;
let mut x = x as u128; let mut x = x as u128;
let mut k = k as u128; let mut k = k as u128;
x = x % Q as u128;
x = x % q as u128;
// exponentiation by square strategy // exponentiation by square strategy
while k > 0 { while k > 0 {
if k % 2 == 1 { if k % 2 == 1 {
r = (r * x) % Q as u128;
r = (r * x) % q as u128;
} }
x = (x * x) % Q as u128;
x = (x * x) % q as u128;
k /= 2; k /= 2;
} }
r as u64 r as u64
} }
/// returns x^-1 mod Q /// returns x^-1 mod Q
const fn const_inv_mod<const Q: u64>(x: u64) -> u64 {
const fn const_inv_mod(q: u64, x: u64) -> u64 {
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
const_exp_mod::<Q>(x, Q - 2)
const_exp_mod(q, x, q - 2)
} }
#[cfg(test)] #[cfg(test)]
@ -142,25 +197,24 @@ mod tests {
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
use std::array;
#[test] #[test]
fn test_ntt() -> Result<()> { fn test_ntt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 4;
let a: [u64; N] = [1u64, 2, 3, 4];
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let a: Vec<u64> = vec![1u64, 2, 3, 4];
let a: Rq = Rq::from_vec_u64(q, n, a);
let a_ntt = NTT::<Q, N>::ntt(a);
let a_ntt = NTT::ntt(&a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
let a_intt = NTT::intt(&a_ntt);
dbg!(&a); dbg!(&a);
dbg!(&a_ntt); dbg!(&a_ntt);
dbg!(&a_intt); dbg!(&a_intt);
dbg!(NTT::<Q, N>::ROOT_OF_UNITY);
dbg!(NTT::<Q, N>::ROOTS_OF_UNITY);
// dbg!(NTT::ROOT_OF_UNITY);
// dbg!(NTT::ROOTS_OF_UNITY);
assert_eq!(a, a_intt); assert_eq!(a, a_intt);
Ok(()) Ok(())
@ -168,18 +222,17 @@ mod tests {
#[test] #[test]
fn test_ntt_loop() -> Result<()> { fn test_ntt_loop() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 512;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 512;
use rand::distributions::Distribution;
use rand::distributions::Uniform; use rand::distributions::Uniform;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let dist = Uniform::new(0_f64, Q as f64);
let dist = Uniform::new(0_f64, q as f64);
for _ in 0..100 {
let a: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
for _ in 0..10_000 {
let a: Rq = Rq::rand(&mut rng, dist, (q, n));
let a_ntt = NTT::ntt(&a);
let a_intt = NTT::intt(&a_ntt);
assert_eq!(a, a_intt); assert_eq!(a, a_intt);
} }
Ok(()) Ok(())

+ 187
- 0
arith/src/ntt_fixedsize.rs

@ -0,0 +1,187 @@
//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in
//! https://eprint.iacr.org/2017/727.pdf, some notes at
//! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
use crate::zq::Zq;
#[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {}
impl<const Q: u64, const N: usize> NTT<Q, N> {
const N_INV: Zq<Q> = Zq(const_inv_mod::<Q>(N as u64));
// since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity
pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::<Q>(2 * N);
pub(crate) const ROOTS_OF_UNITY: [Zq<Q>; N] = roots_of_unity(Self::ROOT_OF_UNITY);
const ROOTS_OF_UNITY_INV: [Zq<Q>; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY);
}
impl<const Q: u64, const N: usize> NTT<Q, N> {
/// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn ntt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = N / 2;
let mut m = 1;
let mut r: [Zq<Q>; N] = a.clone();
while m < N {
let mut k = 0;
for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY[m + i];
for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t] * S;
r[j] = U + V;
r[j + t] = U - V;
}
k = k + 2 * t;
}
t /= 2;
m *= 2;
}
r
}
/// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn intt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = 1;
let mut m = N / 2;
let mut r: [Zq<Q>; N] = a.clone();
while m > 0 {
let mut k = 0;
for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY_INV[m + i];
for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t];
r[j] = U + V;
r[j + t] = (U - V) * S;
}
k += 2 * t;
}
t *= 2;
m /= 2;
}
for i in 0..N {
r[i] = r[i] * Self::N_INV;
}
r
}
}
/// computes a primitive N-th root of unity using the method described by Thomas
/// Pornin in https://crypto.stackexchange.com/a/63616
const fn primitive_root_of_unity<const Q: u64>(N: usize) -> u64 {
assert!(N.is_power_of_two());
assert!((Q - 1) % N as u64 == 0);
let n: u64 = N as u64;
let mut k = 1;
while k < Q {
// alternatively could get a random k at each iteration, if so, add the following if:
// `if k == 0 { continue; }`
let w = const_exp_mod::<Q>(k, (Q - 1) / n);
if const_exp_mod::<Q>(w, n / 2) != 1 {
return w; // w is a primitive N-th root of unity
}
k += 1;
}
panic!("No primitive root of unity");
}
const fn roots_of_unity<const Q: u64, const N: usize>(w: u64) -> [Zq<Q>; N] {
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut i = 0;
let log_n = N.ilog2();
while i < N {
// (return the roots in bit-reverset order)
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
r[i] = Zq(const_exp_mod::<Q>(w, j as u64));
i += 1;
}
r
}
const fn roots_of_unity_inv<const Q: u64, const N: usize>(v: [Zq<Q>; N]) -> [Zq<Q>; N] {
// assumes that the inputted roots are already in bit-reverset order
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut i = 0;
while i < N {
r[i] = Zq(const_inv_mod::<Q>(v[i].0));
i += 1;
}
r
}
/// returns x^k mod Q
const fn const_exp_mod<const Q: u64>(x: u64, k: u64) -> u64 {
// work on u128 to avoid overflow
let mut r = 1u128;
let mut x = x as u128;
let mut k = k as u128;
x = x % Q as u128;
// exponentiation by square strategy
while k > 0 {
if k % 2 == 1 {
r = (r * x) % Q as u128;
}
x = (x * x) % Q as u128;
k /= 2;
}
r as u64
}
/// returns x^-1 mod Q
const fn const_inv_mod<const Q: u64>(x: u64) -> u64 {
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
const_exp_mod::<Q>(x, Q - 2)
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use std::array;
#[test]
fn test_ntt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let a: [u64; N] = [1u64, 2, 3, 4];
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
dbg!(&a);
dbg!(&a_ntt);
dbg!(&a_intt);
dbg!(NTT::<Q, N>::ROOT_OF_UNITY);
dbg!(NTT::<Q, N>::ROOTS_OF_UNITY);
assert_eq!(a, a_intt);
Ok(())
}
#[test]
fn test_ntt_loop() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 512;
use rand::distributions::Distribution;
use rand::distributions::Uniform;
let mut rng = rand::thread_rng();
let dist = Uniform::new(0_f64, Q as f64);
for _ in 0..100 {
let a: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
assert_eq!(a, a_intt);
}
Ok(())
}
}

+ 9
- 8
arith/src/ring.rs

@ -21,27 +21,28 @@ pub trait Ring:
+ PartialEq + PartialEq
+ Debug + Debug
+ Clone + Clone
+ Copy
// + Copy
+ Sum<<Self as Add>::Output> + Sum<<Self as Add>::Output>
+ Sum<<Self as Mul>::Output> + Sum<<Self as Mul>::Output>
{ {
/// C defines the coefficient type /// C defines the coefficient type
type C: Debug + Clone; type C: Debug + Clone;
type Params: Debug+Clone+Copy;
const Q: u64;
const N: usize;
// const Q: u64;
// const N: usize;
fn coeffs(&self) -> Vec<Self::C>; fn coeffs(&self) -> Vec<Self::C>;
fn zero() -> Self;
fn zero(params: Self::Params) -> Self;
// note/wip/warning: dist (0,q) with f64, will output more '0=q' elements than other values // note/wip/warning: dist (0,q) with f64, will output more '0=q' elements than other values
fn rand(rng: impl Rng, dist: impl Distribution<f64>) -> Self;
fn rand(rng: impl Rng, dist: impl Distribution<f64>, params: Self::Params) -> Self;
fn from_vec(coeffs: Vec<Self::C>) -> Self;
fn from_vec(params: Self::Params, coeffs: Vec<Self::C>) -> Self;
fn decompose(&self, beta: u32, l: u32) -> Vec<Self>; fn decompose(&self, beta: u32, l: u32) -> Vec<Self>;
fn remodule<const P: u64>(&self) -> impl Ring;
fn mod_switch<const P: u64>(&self) -> impl Ring;
fn remodule(&self, p:u64) -> impl Ring;
fn mod_switch(&self, p:u64) -> impl Ring;
/// returns [ [(num/den) * self].round() ] mod q /// returns [ [(num/den) * self].round() ] mod q
/// ie. performs the multiplication and division over f64, and then it /// ie. performs the multiplication and division over f64, and then it

+ 197
- 146
arith/src/ring_n.rs

@ -2,8 +2,10 @@
//! //!
use anyhow::Result; use anyhow::Result;
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array; use std::array;
use std::borrow::Borrow;
use std::fmt; use std::fmt;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
@ -12,34 +14,46 @@ use crate::Ring;
// TODO rename to not have name conflicts with the Ring trait (R: Ring) // TODO rename to not have name conflicts with the Ring trait (R: Ring)
// PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1) // PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1)
#[derive(Clone, Copy)]
pub struct R<const N: usize>(pub [i64; N]);
#[derive(Clone)]
pub struct R {
pub n: usize,
pub coeffs: Vec<i64>,
}
// impl<const N: usize> Ring for R<N> { // impl<const N: usize> Ring for R<N> {
impl<const N: usize> R<N> {
impl R {
// type C = i64; // type C = i64;
// type Params = usize; // n
// const Q: u64 = i64::MAX as u64; // WIP // const Q: u64 = i64::MAX as u64; // WIP
// const N: usize = N; // const N: usize = N;
pub fn coeffs(&self) -> Vec<i64> { pub fn coeffs(&self) -> Vec<i64> {
self.0.to_vec()
self.coeffs.clone()
} }
fn zero() -> Self {
let coeffs: [i64; N] = array::from_fn(|_| 0i64);
Self(coeffs)
fn zero(n: usize) -> Self {
Self {
n,
coeffs: vec![0i64; n],
}
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, n: usize) -> Self {
// let coeffs: [i64; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist)); // let coeffs: [i64; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist));
let coeffs: [i64; N] = array::from_fn(|_| dist.sample(&mut rng).round() as i64);
Self(coeffs)
// let coeffs: [i64; N] = array::from_fn(|_| dist.sample(&mut rng).round() as i64);
Self {
n,
coeffs: std::iter::repeat_with(|| dist.sample(&mut rng).round() as i64)
.take(n)
.collect(),
}
// let coeffs: [C; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng))); // let coeffs: [C; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
// Self(coeffs) // Self(coeffs)
} }
pub fn from_vec(coeffs: Vec<i64>) -> Self {
pub fn from_vec(n: usize, coeffs: Vec<i64>) -> Self {
let mut p = coeffs; let mut p = coeffs;
modulus::<N>(&mut p);
Self(array::from_fn(|i| p[i]))
modulus(n, &mut p);
Self { n, coeffs: p }
} }
/* /*
@ -71,34 +85,38 @@ impl R {
*/ */
} }
impl<const Q: u64, const N: usize> From<crate::ring_nq::Rq<Q, N>> for R<N> {
fn from(rq: crate::ring_nq::Rq<Q, N>) -> Self {
Self::from_vec_u64(rq.coeffs().to_vec().iter().map(|e| e.0).collect())
impl From<crate::ring_nq::Rq> for R {
fn from(rq: crate::ring_nq::Rq) -> Self {
Self::from_vec_u64(rq.n, rq.coeffs().to_vec().iter().map(|e| e.v).collect())
} }
} }
impl<const N: usize> R<N> {
impl R {
// pub fn coeffs(&self) -> [i64; N] { // pub fn coeffs(&self) -> [i64; N] {
// self.0 // self.0
// } // }
pub fn to_rq<const Q: u64>(self) -> crate::Rq<Q, N> {
crate::Rq::<Q, N>::from(self)
pub fn to_rq(self, q: u64) -> crate::Rq {
crate::Rq::from((q, self))
} }
// this method is mostly for tests // this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
pub fn from_vec_u64(n: usize, coeffs: Vec<u64>) -> Self {
let coeffs_i64 = coeffs.iter().map(|c| *c as i64).collect(); let coeffs_i64 = coeffs.iter().map(|c| *c as i64).collect();
Self::from_vec(coeffs_i64)
Self::from_vec(n, coeffs_i64)
} }
pub fn from_vec_f64(coeffs: Vec<f64>) -> Self {
pub fn from_vec_f64(n: usize, coeffs: Vec<f64>) -> Self {
let coeffs_i64 = coeffs.iter().map(|c| c.round() as i64).collect(); let coeffs_i64 = coeffs.iter().map(|c| c.round() as i64).collect();
Self::from_vec(coeffs_i64)
Self::from_vec(n, coeffs_i64)
} }
pub fn new(coeffs: [i64; N]) -> Self {
Self(coeffs)
pub fn new(n: usize, coeffs: Vec<i64>) -> Self {
assert_eq!(n, coeffs.len());
Self { n, coeffs }
} }
pub fn mul_by_i64(&self, s: i64) -> Self { pub fn mul_by_i64(&self, s: i64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
Self {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| c_i * s).collect(),
}
} }
pub fn infinity_norm(&self) -> u64 { pub fn infinity_norm(&self) -> u64 {
@ -108,10 +126,10 @@ impl R {
.map(|x| x.abs() as u64) .map(|x| x.abs() as u64)
.fold(0, |a, b| a.max(b)) .fold(0, |a, b| a.max(b))
} }
pub fn mod_centered_q<const Q: u64>(&self) -> R<N> {
let q = Q as i64;
pub fn mod_centered_q(&self, q: u64) -> R {
let q = q as i64;
let r = self let r = self
.0
.coeffs
.iter() .iter()
.map(|v| { .map(|v| {
let mut res = v % q; let mut res = v % q;
@ -121,190 +139,224 @@ impl R {
res res
}) })
.collect::<Vec<i64>>(); .collect::<Vec<i64>>();
R::<N>::from_vec(r)
R::from_vec(self.n, r)
} }
} }
pub fn mul_div_round<const Q: u64, const N: usize>(
v: Vec<i64>,
num: u64,
den: u64,
) -> crate::Rq<Q, N> {
pub fn mul_div_round(q: u64, n: usize, v: Vec<i64>, num: u64, den: u64) -> crate::Rq {
// dbg!(&v); // dbg!(&v);
let r: Vec<f64> = v let r: Vec<f64> = v
.iter() .iter()
.map(|e| ((num as f64 * *e as f64) / den as f64).round()) .map(|e| ((num as f64 * *e as f64) / den as f64).round())
.collect(); .collect();
// dbg!(&r); // dbg!(&r);
crate::Rq::<Q, N>::from_vec_f64(r)
crate::Rq::from_vec_f64(q, n, r)
} }
// TODO rename to make it clear that is not mod q, but mod X^N+1 // TODO rename to make it clear that is not mod q, but mod X^N+1
// apply mod (X^N+1) // apply mod (X^N+1)
pub fn modulus<const N: usize>(p: &mut Vec<i64>) {
if p.len() < N {
pub fn modulus(n: usize, p: &mut Vec<i64>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = 0; p[i] = 0;
} }
p.truncate(N);
p.truncate(n);
} }
pub fn modulus_i128<const N: usize>(p: &mut Vec<i128>) {
if p.len() < N {
pub fn modulus_i128(n: usize, p: &mut Vec<i128>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = 0; p[i] = 0;
} }
p.truncate(N);
p.truncate(n);
} }
impl<const N: usize> PartialEq for R<N> {
impl PartialEq for R {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0 == other.0
self.coeffs == other.coeffs && self.n == other.n
} }
} }
impl<const N: usize> Add<R<N>> for R<N> {
impl Add<R> for R {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
assert_eq!(self.n, rhs.n);
Self {
n: self.n,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> Add<&R<N>> for &R<N> {
type Output = R<N>;
fn add(self, rhs: &R<N>) -> Self::Output {
R(array::from_fn(|i| self.0[i] + rhs.0[i]))
impl Add<&R> for &R {
type Output = R;
fn add(self, rhs: &R) -> Self::Output {
// R(array::from_fn(|i| self.0[i] + rhs.0[i]))
assert_eq!(self.n, rhs.n);
R {
n: self.n,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> AddAssign for R<N> {
impl AddAssign for R {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] += rhs.0[i];
assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] += rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Sum<R<N>> for R<N> {
fn sum<I>(iter: I) -> Self
impl Sum<R> for R {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = R::<N>::zero();
for e in iter {
acc += e;
}
acc
// let mut acc = R::zero();
// for e in iter {
// acc += e;
// }
// acc
let first = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const N: usize> Sub<R<N>> for R<N> {
impl Sub<R> for R {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
// Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
assert_eq!(self.n, rhs.n);
Self {
n: self.n,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> Sub<&R<N>> for &R<N> {
type Output = R<N>;
fn sub(self, rhs: &R<N>) -> Self::Output {
R(array::from_fn(|i| self.0[i] - rhs.0[i]))
impl Sub<&R> for &R {
type Output = R;
fn sub(self, rhs: &R) -> Self::Output {
// R(array::from_fn(|i| self.0[i] - rhs.0[i]))
assert_eq!(self.n, rhs.n);
R {
n: self.n,
coeffs: zip_eq(&self.coeffs, &rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> SubAssign for R<N> {
impl SubAssign for R {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] -= rhs.0[i];
assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] -= rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Mul<R<N>> for R<N> {
impl Mul<R> for R {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
naive_poly_mul(&self, &rhs) naive_poly_mul(&self, &rhs)
} }
} }
impl<const N: usize> Mul<&R<N>> for &R<N> {
type Output = R<N>;
impl Mul<&R> for &R {
type Output = R;
fn mul(self, rhs: &R<N>) -> Self::Output {
fn mul(self, rhs: &R) -> Self::Output {
naive_poly_mul(self, rhs) naive_poly_mul(self, rhs)
} }
} }
// TODO WIP // TODO WIP
pub fn naive_poly_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> R<N> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_poly_mul(poly1: &R, poly2: &R) -> R {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<i128> = poly1.coeffs.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.coeffs.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// apply mod (X^N + 1)) // apply mod (X^N + 1))
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect()) // R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
modulus_i128::<N>(&mut result);
modulus_i128(n, &mut result);
// dbg!(&result); // dbg!(&result);
// dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs()); // dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs());
let result_i64: Vec<i64> = result.iter().map(|c_i| *c_i as i64).collect();
let r = R::from_vec(n, result_i64);
// sanity check: check that there are no coeffs > i64_max // sanity check: check that there are no coeffs > i64_max
assert_eq!( assert_eq!(
result, result,
R::<N>(array::from_fn(|i| result[i] as i64))
.coeffs()
.iter()
.map(|c| *c as i128)
.collect::<Vec<_>>()
r.coeffs.iter().map(|c| *c as i128).collect::<Vec<_>>()
); );
R(array::from_fn(|i| result[i] as i64))
r
} }
pub fn naive_mul_2<const N: usize>(poly1: &Vec<i128>, poly2: &Vec<i128>) -> Vec<i128> {
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_mul_2(n: usize, poly1: &Vec<i128>, poly2: &Vec<i128>) -> Vec<i128> {
let mut result: Vec<i128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// apply mod (X^N + 1)) // apply mod (X^N + 1))
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect()) // R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
modulus_i128::<N>(&mut result);
modulus_i128(n, &mut result);
result result
} }
pub fn naive_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_mul(poly1: &R, poly2: &R) -> Vec<i64> {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<i128> = poly1.coeffs.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.coeffs.iter().map(|c| *c as i128).collect();
let mut result = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
result.iter().map(|c| *c as i64).collect() result.iter().map(|c| *c as i64).collect()
} }
pub fn naive_mul_TMP<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec<i64> {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<i128> = poly1.coeffs.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.coeffs.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// dbg!(&result); // dbg!(&result);
modulus_i128::<N>(&mut result);
modulus_i128(n, &mut result);
// for c_i in result.iter() { // for c_i in result.iter() {
// println!("---"); // println!("---");
// println!("{:?}", &c_i); // println!("{:?}", &c_i);
@ -316,8 +368,8 @@ pub fn naive_mul_TMP(poly1: &R, poly2: &R) -> Vec {
} }
// wip // wip
pub fn mod_centered_q<const Q: u64, const N: usize>(p: Vec<i128>) -> R<N> {
let q: i128 = Q as i128;
pub fn mod_centered_q(q: u64, n: usize, p: Vec<i128>) -> R {
let q: i128 = q as i128;
let r = p let r = p
.iter() .iter()
.map(|v| { .map(|v| {
@ -328,10 +380,10 @@ pub fn mod_centered_q(p: Vec) -> R {
res res
}) })
.collect::<Vec<i128>>(); .collect::<Vec<i128>>();
R::<N>::from_vec(r.iter().map(|v| *v as i64).collect::<Vec<i64>>())
R::from_vec(n, r.iter().map(|v| *v as i64).collect::<Vec<i64>>())
} }
impl<const N: usize> Mul<i64> for R<N> {
impl Mul<i64> for R {
type Output = Self; type Output = Self;
fn mul(self, s: i64) -> Self { fn mul(self, s: i64) -> Self {
@ -339,34 +391,38 @@ impl Mul for R {
} }
} }
// mul by u64 // mul by u64
impl<const N: usize> Mul<u64> for R<N> {
impl Mul<u64> for R {
type Output = Self; type Output = Self;
fn mul(self, s: u64) -> Self { fn mul(self, s: u64) -> Self {
self.mul_by_i64(s as i64) self.mul_by_i64(s as i64)
} }
} }
impl<const N: usize> Mul<&u64> for &R<N> {
type Output = R<N>;
impl Mul<&u64> for &R {
type Output = R;
fn mul(self, s: &u64) -> Self::Output { fn mul(self, s: &u64) -> Self::Output {
self.mul_by_i64(*s as i64) self.mul_by_i64(*s as i64)
} }
} }
impl<const N: usize> Neg for R<N> {
impl Neg for R {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Self(array::from_fn(|i| -self.0[i]))
// Self(array::from_fn(|i| -self.0[i]))
Self {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| -c_i).collect(),
}
} }
} }
impl<const N: usize> R<N> {
impl R {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut str = ""; let mut str = "";
let mut zero = true; let mut zero = true;
for (i, coeff) in self.0.iter().enumerate().rev() {
for (i, coeff) in self.coeffs.iter().enumerate().rev() {
if *coeff == 0 { if *coeff == 0 {
continue; continue;
} }
@ -395,18 +451,18 @@ impl R {
f.write_str(" mod Z")?; f.write_str(" mod Z")?;
f.write_str("/(X^")?; f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str(self.n.to_string().as_str())?;
f.write_str("+1)")?; f.write_str("+1)")?;
Ok(()) Ok(())
} }
} }
impl<const N: usize> fmt::Display for R<N> {
impl fmt::Display for R {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
} }
} }
impl<const N: usize> fmt::Debug for R<N> {
impl fmt::Debug for R {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
@ -420,38 +476,33 @@ mod tests {
#[test] #[test]
fn test_mul() -> Result<()> { fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 2;
let q: i64 = Q as i64;
let n: usize = 2;
let q: i64 = (2u64.pow(16) + 1) as i64;
// test vectors generated with SageMath // test vectors generated with SageMath
let a: [i64; N] = [q - 1, q - 1];
let b: [i64; N] = [q - 1, q - 1];
let c: [i64; N] = [0, 8589934592];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<i64> = vec![q - 1, q - 1];
let b: Vec<i64> = vec![q - 1, q - 1];
let c: Vec<i64> = vec![0, 8589934592];
test_mul_opt(n, a, b, c)?;
let a: [i64; N] = [1, q - 1];
let b: [i64; N] = [1, q - 1];
let c: [i64; N] = [-4294967295, 131072];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<i64> = vec![1, q - 1];
let b: Vec<i64> = vec![1, q - 1];
let c: Vec<i64> = vec![-4294967295, 131072];
test_mul_opt(n, a, b, c)?;
Ok(()) Ok(())
} }
fn test_mul_opt<const Q: u64, const N: usize>(
a: [i64; N],
b: [i64; N],
expected_c: [i64; N],
) -> Result<()> {
let mut a = R::new(a);
let mut b = R::new(b);
fn test_mul_opt(n: usize, a: Vec<i64>, b: Vec<i64>, expected_c: Vec<i64>) -> Result<()> {
let mut a = R::new(n, a);
let mut b = R::new(n, b);
dbg!(&a); dbg!(&a);
dbg!(&b); dbg!(&b);
let expected_c = R::new(expected_c);
let expected_c = R::new(n, expected_c);
let mut c = naive_mul(&mut a, &mut b); let mut c = naive_mul(&mut a, &mut b);
modulus::<N>(&mut c);
dbg!(R::<N>::from_vec(c.clone()));
assert_eq!(c, expected_c.0.to_vec());
modulus(n, &mut c);
dbg!(R::from_vec(n, c.clone()));
assert_eq!(c, expected_c.coeffs);
Ok(()) Ok(())
} }
} }

+ 353
- 199
arith/src/ring_nq.rs

@ -2,8 +2,10 @@
//! //!
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array; use std::array;
use std::borrow::Borrow;
use std::fmt; use std::fmt;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
@ -18,75 +20,91 @@ use crate::Ring;
// use Vec. // use Vec.
/// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1) /// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1)
/// The implementation assumes that q is prime. /// The implementation assumes that q is prime.
#[derive(Clone, Copy)]
pub struct Rq<const Q: u64, const N: usize> {
pub(crate) coeffs: [Zq<Q>; N],
#[derive(Clone)]
pub struct Rq {
pub(crate) q: u64, // TODO think if really needed or it's fine with coeffs[0].q
pub(crate) n: usize,
pub(crate) coeffs: Vec<Zq>,
// evals are set when doig a PRxPR multiplication, so it can be reused in future // evals are set when doig a PRxPR multiplication, so it can be reused in future
// multiplications avoiding recomputing it // multiplications avoiding recomputing it
pub(crate) evals: Option<[Zq<Q>; N]>,
pub(crate) evals: Option<Vec<Zq>>,
} }
impl<const Q: u64, const N: usize> Ring for Rq<Q, N> {
type C = Zq<Q>;
const Q: u64 = Q;
const N: usize = N;
impl Ring for Rq {
type C = Zq;
type Params = (u64, usize);
fn coeffs(&self) -> Vec<Self::C> { fn coeffs(&self) -> Vec<Self::C> {
self.coeffs.to_vec() self.coeffs.to_vec()
} }
fn zero() -> Self {
let coeffs = array::from_fn(|_| Zq::zero());
// fn zero(q: u64, n: usize) -> Self {
fn zero(param: (u64, usize)) -> Self {
let (q, n) = param;
Self { Self {
coeffs,
q,
n,
coeffs: vec![Zq::zero(q); n],
evals: None, evals: None,
} }
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, params: Self::Params) -> Self {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng))); // let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist));
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist));
let (q, n) = params;
Self { Self {
coeffs,
q,
n,
coeffs: std::iter::repeat_with(|| Self::C::rand(&mut rng, &dist, q))
.take(n)
.collect(),
evals: None, evals: None,
} }
} }
fn from_vec(coeffs: Vec<Zq<Q>>) -> Self {
fn from_vec(params: Self::Params, coeffs: Vec<Zq>) -> Self {
let (q, n) = params;
let mut p = coeffs; let mut p = coeffs;
modulus::<Q, N>(&mut p);
let coeffs = array::from_fn(|i| p[i]);
modulus(q, n, &mut p);
Self { Self {
coeffs,
q,
n,
coeffs: p,
evals: None, evals: None,
} }
} }
// returns the decomposition of each polynomial coefficient, such // returns the decomposition of each polynomial coefficient, such
// decomposition will be a vecotor of length N, containint N vectors of Zq
// decomposition will be a vector of length N, containing N vectors of Zq
fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
let elems: Vec<Vec<Zq<Q>>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect();
let elems: Vec<Vec<Zq>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect();
// transpose it // transpose it
let r: Vec<Vec<Zq<Q>>> = (0..elems[0].len())
let r: Vec<Vec<Zq>> = (0..elems[0].len())
.map(|i| (0..elems.len()).map(|j| elems[j][i]).collect()) .map(|i| (0..elems.len()).map(|j| elems[j][i]).collect())
.collect(); .collect();
// convert it to Rq<Q,N> // convert it to Rq<Q,N>
r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect()
r.iter()
.map(|a_i| Self::from_vec((self.q, self.n), a_i.clone()))
.collect()
} }
// Warning: this method will behave differently depending on the values P and Q: // Warning: this method will behave differently depending on the values P and Q:
// if Q<P, it just 'renames' the modulus parameter to P // if Q<P, it just 'renames' the modulus parameter to P
// if Q>=P, it crops to mod P // if Q>=P, it crops to mod P
fn remodule<const P: u64>(&self) -> Rq<P, N> {
Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
fn remodule(&self, p: u64) -> Rq {
Rq::from_vec_u64(p, self.n, self.coeffs().iter().map(|m_i| m_i.v).collect())
} }
/// perform the mod switch operation from Q to Q', where Q2=Q' /// perform the mod switch operation from Q to Q', where Q2=Q'
// fn mod_switch<const P: u64, const M: usize>(&self) -> impl Ring { // fn mod_switch<const P: u64, const M: usize>(&self) -> impl Ring {
fn mod_switch<const P: u64>(&self) -> Rq<P, N> {
fn mod_switch(&self, p: u64) -> Rq {
// assert_eq!(N, M); // sanity check // assert_eq!(N, M); // sanity check
Rq::<P, N> {
coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::<P>()),
Rq {
q: p,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::<P>()),
coeffs: self.coeffs.iter().map(|c_i| c_i.mod_switch(p)).collect(),
evals: None, evals: None,
} }
} }
@ -98,45 +116,49 @@ impl Ring for Rq {
let r: Vec<f64> = self let r: Vec<f64> = self
.coeffs() .coeffs()
.iter() .iter()
.map(|e| ((num as f64 * e.0 as f64) / den as f64).round())
.map(|e| ((num as f64 * e.v as f64) / den as f64).round())
.collect(); .collect();
Rq::<Q, N>::from_vec_f64(r)
Rq::from_vec_f64(self.q, self.n, r)
} }
} }
impl<const Q: u64, const N: usize> From<crate::ring_n::R<N>> for Rq<Q, N> {
fn from(r: crate::ring_n::R<N>) -> Self {
impl From<(u64, crate::ring_n::R)> for Rq {
fn from(qr: (u64, crate::ring_n::R)) -> Self {
let (q, r) = qr;
assert_eq!(r.n, r.coeffs.len());
Self::from_vec( Self::from_vec(
(q, r.n),
r.coeffs() r.coeffs()
.iter() .iter()
.map(|e| Zq::<Q>::from_f64(*e as f64))
.map(|e| Zq::from_f64(q, *e as f64))
.collect(), .collect(),
) )
} }
} }
// apply mod (X^N+1) // apply mod (X^N+1)
pub fn modulus<const Q: u64, const N: usize>(p: &mut Vec<Zq<Q>>) {
if p.len() < N {
pub fn modulus(q: u64, n: usize, p: &mut Vec<Zq>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = Zq(0);
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = Zq::zero(q);
} }
p.truncate(N);
p.truncate(n);
} }
// PR stands for PolynomialRing
impl<const Q: u64, const N: usize> Rq<Q, N> {
pub fn coeffs(&self) -> [Zq<Q>; N] {
self.coeffs
impl Rq {
pub fn coeffs(&self) -> Vec<Zq> {
self.coeffs.clone()
} }
pub fn compute_evals(&mut self) { pub fn compute_evals(&mut self) {
self.evals = Some(NTT::<Q, N>::ntt(self.coeffs));
self.evals = Some(NTT::ntt(self).coeffs); // TODO improve, ntt returns Rq but here
// just needs Vec<Zq>
} }
pub fn to_r(self) -> crate::R<N> {
crate::R::<N>::from(self)
pub fn to_r(self) -> crate::R {
crate::R::from(self)
} }
// TODO rm since it is implemented in Ring trait impl // TODO rm since it is implemented in Ring trait impl
@ -148,55 +170,105 @@ impl Rq {
// } // }
// } // }
// this method is mostly for tests // this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_u64(*c)).collect();
Self::from_vec(coeffs_mod_q)
pub fn from_vec_u64(q: u64, n: usize, coeffs: Vec<u64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_u64(q, *c)).collect();
Self::from_vec((q, n), coeffs_mod_q)
} }
pub fn from_vec_f64(coeffs: Vec<f64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c)).collect();
Self::from_vec(coeffs_mod_q)
pub fn from_vec_f64(q: u64, n: usize, coeffs: Vec<f64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_f64(q, *c)).collect();
Self::from_vec((q, n), coeffs_mod_q)
} }
pub fn from_vec_i64(coeffs: Vec<i64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c as f64)).collect();
Self::from_vec(coeffs_mod_q)
pub fn from_vec_i64(q: u64, n: usize, coeffs: Vec<i64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_f64(q, *c as f64)).collect();
Self::from_vec((q, n), coeffs_mod_q)
} }
pub fn new(coeffs: [Zq<Q>; N], evals: Option<[Zq<Q>; N]>) -> Self {
Self { coeffs, evals }
pub fn new(q: u64, n: usize, coeffs: Vec<Zq>, evals: Option<Vec<Zq>>) -> Self {
Self {
q,
n,
coeffs,
evals,
}
} }
pub fn rand_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
pub fn rand_abs(
mut rng: impl Rng,
dist: impl Distribution<f64>,
q: u64,
n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self { Ok(Self {
coeffs,
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(q, dist.sample(&mut rng).abs()))
.take(n)
.collect(),
evals: None, evals: None,
}) })
} }
pub fn rand_f64_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
pub fn rand_f64_abs(
mut rng: impl Rng,
dist: impl Distribution<f64>,
q: u64,
n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self { Ok(Self {
coeffs,
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(q, dist.sample(&mut rng).abs()))
.take(n)
.collect(),
evals: None, evals: None,
}) })
} }
pub fn rand_f64(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
pub fn rand_f64(
mut rng: impl Rng,
dist: impl Distribution<f64>,
q: u64,
n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
Ok(Self { Ok(Self {
coeffs,
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(q, dist.sample(&mut rng)))
.take(n)
.collect(),
evals: None, evals: None,
}) })
} }
pub fn rand_u64(mut rng: impl Rng, dist: impl Distribution<u64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
pub fn rand_u64(
mut rng: impl Rng,
dist: impl Distribution<u64>,
q: u64,
n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
Ok(Self { Ok(Self {
coeffs,
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_u64(q, dist.sample(&mut rng)))
.take(n)
.collect(),
evals: None, evals: None,
}) })
} }
// WIP. returns random v \in {0,1}. // TODO {-1, 0, 1} // WIP. returns random v \in {0,1}. // TODO {-1, 0, 1}
pub fn rand_bin(mut rng: impl Rng, dist: impl Distribution<bool>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
pub fn rand_bin(
mut rng: impl Rng,
dist: impl Distribution<bool>,
q: u64,
n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
Ok(Rq { Ok(Rq {
coeffs,
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_bool(q, dist.sample(&mut rng)))
.take(n)
.collect(),
evals: None, evals: None,
}) })
} }
@ -208,36 +280,50 @@ impl Rq {
// } // }
// applies mod(T) to all coefficients of self // applies mod(T) to all coefficients of self
pub fn coeffs_mod<const T: u64>(&self) -> Self {
Rq::<Q, N>::from_vec_u64(
pub fn coeffs_mod<const T: u64>(&self, q: u64, n: usize, t: u64) -> Self {
Rq::from_vec_u64(
q,
n,
self.coeffs() self.coeffs()
.iter() .iter()
.map(|m_i| modulus_u64::<T>(m_i.0))
.map(|m_i| modulus_u64(t, m_i.v))
.collect(), .collect(),
) )
} }
// TODO review if needed, or if with this interface // TODO review if needed, or if with this interface
pub fn mul_by_matrix(&self, m: &Vec<Vec<Zq<Q>>>) -> Result<Vec<Zq<Q>>> {
pub fn mul_by_matrix(&self, m: &Vec<Vec<Zq>>) -> Result<Vec<Zq>> {
matrix_vec_product(m, &self.coeffs.to_vec()) matrix_vec_product(m, &self.coeffs.to_vec())
} }
pub fn mul_by_zq(&self, s: &Zq<Q>) -> Self {
pub fn mul_by_zq(&self, s: &Zq) -> Self {
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] * *s),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] * *s),
coeffs: self.coeffs.iter().map(|c_i| *c_i * *s).collect(),
evals: None, evals: None,
} }
} }
pub fn mul_by_u64(&self, s: u64) -> Self { pub fn mul_by_u64(&self, s: u64) -> Self {
let s = Zq::from_u64(s);
let s = Zq::from_u64(self.q, s);
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] * s),
// coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] * s),
coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
evals: None, evals: None,
} }
} }
pub fn mul_by_f64(&self, s: f64) -> Self { pub fn mul_by_f64(&self, s: f64) -> Self {
Self { Self {
coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)),
coeffs: self
.coeffs
.iter()
.map(|c_i| Zq::from_f64(self.q, c_i.v as f64 * s))
.collect(),
evals: None, evals: None,
} }
} }
@ -251,9 +337,9 @@ impl Rq {
let r: Vec<f64> = self let r: Vec<f64> = self
.coeffs() .coeffs()
.iter() .iter()
.map(|e| (e.0 as f64 / s as f64).round())
.map(|e| (e.v as f64 / s as f64).round())
.collect(); .collect();
Rq::<Q, N>::from_vec_f64(r)
Rq::from_vec_f64(self.q, self.n, r)
} }
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@ -261,19 +347,19 @@ impl Rq {
let mut str = ""; let mut str = "";
let mut zero = true; let mut zero = true;
for (i, coeff) in self.coeffs.iter().enumerate().rev() { for (i, coeff) in self.coeffs.iter().enumerate().rev() {
if coeff.0 == 0 {
if coeff.v == 0 {
continue; continue;
} }
zero = false; zero = false;
f.write_str(str)?; f.write_str(str)?;
if coeff.0 != 1 {
f.write_str(coeff.0.to_string().as_str())?;
if coeff.v != 1 {
f.write_str(coeff.v.to_string().as_str())?;
if i > 0 { if i > 0 {
f.write_str("*")?; f.write_str("*")?;
} }
} }
if coeff.0 == 1 && i == 0 {
f.write_str(coeff.0.to_string().as_str())?;
if coeff.v == 1 && i == 0 {
f.write_str(coeff.v.to_string().as_str())?;
} }
if i == 1 { if i == 1 {
f.write_str("x")?; f.write_str("x")?;
@ -288,9 +374,9 @@ impl Rq {
} }
f.write_str(" mod Z_")?; f.write_str(" mod Z_")?;
f.write_str(Q.to_string().as_str())?;
f.write_str(self.q.to_string().as_str())?;
f.write_str("/(X^")?; f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str(self.n.to_string().as_str())?;
f.write_str("+1)")?; f.write_str("+1)")?;
Ok(()) Ok(())
} }
@ -298,16 +384,20 @@ impl Rq {
pub fn infinity_norm(&self) -> u64 { pub fn infinity_norm(&self) -> u64 {
self.coeffs() self.coeffs()
.iter() .iter()
.map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
.map(|x| {
if x.v > (self.q / 2) {
self.q - x.v
} else {
x.v
}
})
.fold(0, |a, b| a.max(b)) .fold(0, |a, b| a.max(b))
} }
pub fn mod_centered_q(&self) -> crate::ring_n::R<N> {
self.to_r().mod_centered_q::<Q>()
pub fn mod_centered_q(&self) -> crate::ring_n::R {
self.clone().to_r().mod_centered_q(self.q)
} }
} }
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
// assert_eq!(m.len(), v.len());
pub fn matrix_vec_product(m: &Vec<Vec<Zq>>, v: &Vec<Zq>) -> Result<Vec<Zq>> {
if m.len() != m[0].len() { if m.len() != m[0].len() {
return Err(anyhow!("expected 'm' to be a square matrix")); return Err(anyhow!("expected 'm' to be a square matrix"));
} }
@ -319,6 +409,8 @@ pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) ->
)); ));
} }
assert_eq!(m[0][0].q, v[0].q); // TODO change to returning err
Ok(m.iter() Ok(m.iter()
.map(|row| { .map(|row| {
row.iter() row.iter()
@ -326,12 +418,15 @@ pub fn matrix_vec_product(m: &Vec>>, v: &Vec>) ->
.map(|(&row_i, &v_i)| row_i * v_i) .map(|(&row_i, &v_i)| row_i * v_i)
.sum() .sum()
}) })
.collect::<Vec<Zq<Q>>>())
.collect::<Vec<Zq>>())
} }
pub fn transpose<const Q: u64>(m: &[Vec<Zq<Q>>]) -> Vec<Vec<Zq<Q>>> {
pub fn transpose(m: &[Vec<Zq>]) -> Vec<Vec<Zq>> {
assert!(m.len() > 0);
assert!(m[0].len() > 0);
let q = m[0][0].q;
// TODO case when m[0].len()=0 // TODO case when m[0].len()=0
// TODO non square matrix // TODO non square matrix
let mut r: Vec<Vec<Zq<Q>>> = vec![vec![Zq(0); m[0].len()]; m.len()];
let mut r: Vec<Vec<Zq>> = vec![vec![Zq::zero(q); m[0].len()]; m.len()];
for (i, m_row) in m.iter().enumerate() { for (i, m_row) in m.iter().enumerate() {
for (j, m_ij) in m_row.iter().enumerate() { for (j, m_ij) in m_row.iter().enumerate() {
r[j][i] = *m_ij; r[j][i] = *m_ij;
@ -340,17 +435,24 @@ pub fn transpose(m: &[Vec>]) -> Vec>> {
r r
} }
impl<const Q: u64, const N: usize> PartialEq for Rq<Q, N> {
impl PartialEq for Rq {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.coeffs == other.coeffs
self.coeffs == other.coeffs && self.q == other.q && self.n == other.n
} }
} }
impl<const Q: u64, const N: usize> Add<Rq<Q, N>> for Rq<Q, N> {
impl Add<Rq> for Rq {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
assert_eq!(self.q, rhs.q);
assert_eq!(self.n, rhs.n);
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
.collect(),
evals: None, evals: None,
} }
// Self { // Self {
@ -365,180 +467,227 @@ impl Add> for Rq {
// Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in + // Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in +
} }
} }
impl<const Q: u64, const N: usize> Add<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Add<&Rq> for &Rq {
type Output = Rq;
fn add(self, rhs: &Rq<Q, N>) -> Self::Output {
fn add(self, rhs: &Rq) -> Self::Output {
assert_eq!(self.q, rhs.q);
assert_eq!(self.n, rhs.n);
Rq { Rq {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
.collect(),
evals: None, evals: None,
} }
} }
} }
impl<const Q: u64, const N: usize> AddAssign for Rq<Q, N> {
impl AddAssign for Rq {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..N {
debug_assert_eq!(self.q, rhs.q);
debug_assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] += rhs.coeffs[i]; self.coeffs[i] += rhs.coeffs[i];
} }
} }
} }
impl<const Q: u64, const N: usize> Sum<Rq<Q, N>> for Rq<Q, N> {
fn sum<I>(iter: I) -> Self
impl Sum<Rq> for Rq {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = Rq::<Q, N>::zero();
for e in iter {
acc += e;
}
acc
// let mut acc = Rq::zero();
// for e in iter {
// acc += e;
// }
// acc
let first = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const Q: u64, const N: usize> Sub<Rq<Q, N>> for Rq<Q, N> {
impl Sub<Rq> for Rq {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
assert_eq!(self.q, rhs.q);
assert_eq!(self.n, rhs.n);
Self { Self {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
evals: None, evals: None,
} }
} }
} }
impl<const Q: u64, const N: usize> Sub<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Sub<&Rq> for &Rq {
type Output = Rq;
fn sub(self, rhs: &Rq<Q, N>) -> Self::Output {
fn sub(self, rhs: &Rq) -> Self::Output {
assert_eq!(self.q, rhs.q); // TODO replace all those with debug_assert_eq
debug_assert_eq!(self.n, rhs.n);
Rq { Rq {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l - r)
.collect(),
evals: None, evals: None,
} }
} }
} }
impl<const Q: u64, const N: usize> SubAssign for Rq<Q, N> {
impl SubAssign for Rq {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
for i in 0..N {
debug_assert_eq!(self.q, rhs.q);
debug_assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] -= rhs.coeffs[i]; self.coeffs[i] -= rhs.coeffs[i];
} }
} }
} }
impl<const Q: u64, const N: usize> Mul<Rq<Q, N>> for Rq<Q, N> {
impl Mul<Rq> for Rq {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
mul(&self, &rhs) mul(&self, &rhs)
} }
} }
impl<const Q: u64, const N: usize> Mul<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&Rq> for &Rq {
type Output = Rq;
fn mul(self, rhs: &Rq<Q, N>) -> Self::Output {
fn mul(self, rhs: &Rq) -> Self::Output {
mul(self, rhs) mul(self, rhs)
} }
} }
// mul by Zq element // mul by Zq element
impl<const Q: u64, const N: usize> Mul<Zq<Q>> for Rq<Q, N> {
impl Mul<Zq> for Rq {
type Output = Self; type Output = Self;
fn mul(self, s: Zq<Q>) -> Self {
fn mul(self, s: Zq) -> Self {
self.mul_by_zq(&s) self.mul_by_zq(&s)
} }
} }
impl<const Q: u64, const N: usize> Mul<&Zq<Q>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&Zq> for &Rq {
type Output = Rq;
fn mul(self, s: &Zq<Q>) -> Self::Output {
fn mul(self, s: &Zq) -> Self::Output {
self.mul_by_zq(s) self.mul_by_zq(s)
} }
} }
// mul by u64 // mul by u64
impl<const Q: u64, const N: usize> Mul<u64> for Rq<Q, N> {
impl Mul<u64> for Rq {
type Output = Self; type Output = Self;
fn mul(self, s: u64) -> Self { fn mul(self, s: u64) -> Self {
self.mul_by_u64(s) self.mul_by_u64(s)
} }
} }
impl<const Q: u64, const N: usize> Mul<&u64> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&u64> for &Rq {
type Output = Rq;
fn mul(self, s: &u64) -> Self::Output { fn mul(self, s: &u64) -> Self::Output {
self.mul_by_u64(*s) self.mul_by_u64(*s)
} }
} }
// mul by f64 // mul by f64
impl<const Q: u64, const N: usize> Mul<f64> for Rq<Q, N> {
impl Mul<f64> for Rq {
type Output = Self; type Output = Self;
fn mul(self, s: f64) -> Self { fn mul(self, s: f64) -> Self {
self.mul_by_f64(s) self.mul_by_f64(s)
} }
} }
impl<const Q: u64, const N: usize> Mul<&f64> for &Rq<Q, N> {
type Output = Rq<Q, N>;
impl Mul<&f64> for &Rq {
type Output = Rq;
fn mul(self, s: &f64) -> Self::Output { fn mul(self, s: &f64) -> Self::Output {
self.mul_by_f64(*s) self.mul_by_f64(*s)
} }
} }
impl<const Q: u64, const N: usize> Neg for Rq<Q, N> {
impl Neg for Rq {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Self { Self {
coeffs: array::from_fn(|i| -self.coeffs[i]),
q: self.q,
n: self.n,
// coeffs: array::from_fn(|i| -self.coeffs[i]),
// coeffs: self.coeffs.iter().map(|c_i| -c_i).collect(),
coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(),
evals: None, evals: None,
} }
} }
} }
// note: this assumes that Q is prime // note: this assumes that Q is prime
fn mul_mut<const Q: u64, const N: usize>(lhs: &mut Rq<Q, N>, rhs: &mut Rq<Q, N>) -> Rq<Q, N> {
fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq {
assert_eq!(lhs.q, rhs.q);
assert_eq!(lhs.n, rhs.n);
let (q, n) = (lhs.q, lhs.n);
// reuse evaluations if already computed // reuse evaluations if already computed
if !lhs.evals.is_some() { if !lhs.evals.is_some() {
lhs.evals = Some(NTT::<Q, N>::ntt(lhs.coeffs));
lhs.evals = Some(NTT::ntt(lhs).coeffs);
}; };
if !rhs.evals.is_some() { if !rhs.evals.is_some() {
rhs.evals = Some(NTT::<Q, N>::ntt(rhs.coeffs));
rhs.evals = Some(NTT::ntt(rhs).coeffs);
}; };
let lhs_evals = lhs.evals.unwrap();
let rhs_evals = rhs.evals.unwrap();
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
let lhs_evals = lhs.evals.clone().unwrap();
let rhs_evals = rhs.evals.clone().unwrap();
// let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c_ntt: Rq = Rq::from_vec(
(q, n),
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
);
let c = NTT::intt(&c_ntt);
Rq::new(q, n, c.coeffs, Some(c_ntt.coeffs))
} }
// note: this assumes that Q is prime // note: this assumes that Q is prime
// TODO impl karatsuba for non-prime Q // TODO impl karatsuba for non-prime Q
fn mul<const Q: u64, const N: usize>(lhs: &Rq<Q, N>, rhs: &Rq<Q, N>) -> Rq<Q, N> {
fn mul(lhs: &Rq, rhs: &Rq) -> Rq {
assert_eq!(lhs.q, rhs.q);
assert_eq!(lhs.n, rhs.n);
let (q, n) = (lhs.q, lhs.n);
// reuse evaluations if already computed // reuse evaluations if already computed
let lhs_evals = if lhs.evals.is_some() {
lhs.evals.unwrap()
let lhs_evals: Vec<Zq> = if lhs.evals.is_some() {
lhs.evals.clone().unwrap()
} else { } else {
NTT::<Q, N>::ntt(lhs.coeffs)
NTT::ntt(lhs).coeffs
}; };
let rhs_evals = if rhs.evals.is_some() {
rhs.evals.unwrap()
let rhs_evals: Vec<Zq> = if rhs.evals.is_some() {
rhs.evals.clone().unwrap()
} else { } else {
NTT::<Q, N>::ntt(rhs.coeffs)
NTT::ntt(rhs).coeffs
}; };
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
// let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c_ntt: Rq = Rq::from_vec(
(q, n),
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
);
let c = NTT::intt(&c_ntt);
Rq::new(q, n, c.coeffs, Some(c_ntt.coeffs))
} }
impl<const Q: u64, const N: usize> fmt::Display for Rq<Q, N> {
impl fmt::Display for Rq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
} }
} }
impl<const Q: u64, const N: usize> fmt::Debug for Rq<Q, N> {
impl fmt::Debug for Rq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?; self.fmt(f)?;
Ok(()) Ok(())
@ -552,31 +701,31 @@ mod tests {
#[test] #[test]
fn test_polynomial_ring() { fn test_polynomial_ring() {
// the test values used are generated with SageMath // the test values used are generated with SageMath
const Q: u64 = 7;
const N: usize = 3;
let q: u64 = 7;
let n: usize = 3;
// p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1) // p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1)
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
let p = Rq::from_vec_u64(q, n, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with coefficients bigger than Q // try with coefficients bigger than Q
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, Q + 2, 3, 4, 5]);
let p = Rq::from_vec_u64(q, n, vec![0u64, 1, q + 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with other ring // try with other ring
let p = Rq::<7, 4>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
let p = Rq::from_vec_u64(7, 4, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)"); assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)");
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 0, 0, 0, 4, 5]);
let p = Rq::from_vec_u64(q, n, vec![0u64, 0, 0, 0, 4, 5]);
assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)");
let p = Rq::<Q, N>::from_vec_u64(vec![5u64, 4, 5, 2, 1, 0]);
let p = Rq::from_vec_u64(q, n, vec![5u64, 4, 5, 2, 1, 0]);
assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)"); assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
let a = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
let a = Rq::from_vec_u64(q, n, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)"); assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
let b = Rq::<Q, N>::from_vec_u64(vec![5u64, 4, 3, 2, 1, 0]);
let b = Rq::from_vec_u64(q, n, vec![5u64, 4, 3, 2, 1, 0]);
assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)"); assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
// add // add
@ -593,34 +742,39 @@ mod tests {
#[test] #[test]
fn test_mul() -> Result<()> { fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 4;
let a: [u64; N] = [1u64, 2, 3, 4];
let b: [u64; N] = [1u64, 2, 3, 4];
let c: [u64; N] = [65513, 65517, 65531, 20];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<u64> = vec![1u64, 2, 3, 4];
let b: Vec<u64> = vec![1u64, 2, 3, 4];
let c: Vec<u64> = vec![65513, 65517, 65531, 20];
test_mul_opt(q, n, a, b, c)?;
let a: [u64; N] = [0u64, 0, 0, 2];
let b: [u64; N] = [0u64, 0, 0, 2];
let c: [u64; N] = [0u64, 0, 65533, 0];
test_mul_opt::<Q, N>(a, b, c)?;
let a: Vec<u64> = vec![0u64, 0, 0, 2];
let b: Vec<u64> = vec![0u64, 0, 0, 2];
let c: Vec<u64> = vec![0u64, 0, 65533, 0];
test_mul_opt(q, n, a, b, c)?;
// TODO more testvectors // TODO more testvectors
Ok(()) Ok(())
} }
fn test_mul_opt<const Q: u64, const N: usize>(
a: [u64; N],
b: [u64; N],
expected_c: [u64; N],
fn test_mul_opt(
q: u64,
n: usize,
a: Vec<u64>,
b: Vec<u64>,
expected_c: Vec<u64>,
) -> Result<()> { ) -> Result<()> {
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::new(a, None);
let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::new(b, None);
let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::new(expected_c, None);
assert_eq!(a.len(), n);
assert_eq!(b.len(), n);
// let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::from_vec_u64(q, n, a);
// let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::from_vec_u64(q, n, b);
// let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::from_vec_u64(q, n, expected_c);
let c = mul_mut(&mut a, &mut b); let c = mul_mut(&mut a, &mut b);
assert_eq!(c, expected_c); assert_eq!(c, expected_c);
@ -629,26 +783,26 @@ mod tests {
#[test] #[test]
fn test_rq_decompose() -> Result<()> { fn test_rq_decompose() -> Result<()> {
const Q: u64 = 16;
const N: usize = 4;
let q: u64 = 16;
let n: usize = 4;
let beta = 4; let beta = 4;
let l = 2; let l = 2;
let a = Rq::<Q, N>::from_vec_u64(vec![7u64, 14, 3, 6]);
let a = Rq::from_vec_u64(q, n, vec![7u64, 14, 3, 6]);
let d = a.decompose(beta, l); let d = a.decompose(beta, l);
assert_eq!( assert_eq!(
d[0].coeffs().to_vec(),
d[0].coeffs(),
vec![1u64, 3, 0, 1] vec![1u64, 3, 0, 1]
.iter() .iter()
.map(|e| Zq::<Q>::from_u64(*e))
.map(|e| Zq::from_u64(q, *e))
.collect::<Vec<_>>() .collect::<Vec<_>>()
); );
assert_eq!( assert_eq!(
d[1].coeffs().to_vec(),
d[1].coeffs(),
vec![3u64, 2, 3, 2] vec![3u64, 2, 3, 2]
.iter() .iter()
.map(|e| Zq::<Q>::from_u64(*e))
.map(|e| Zq::from_u64(q, *e))
.collect::<Vec<_>>() .collect::<Vec<_>>()
); );
Ok(()) Ok(())

+ 179
- 95
arith/src/ring_torus.rs

@ -7,6 +7,7 @@
//! u64, we fit it into the `Ring` trait (from ring.rs) so that we can compose //! u64, we fit it into the `Ring` trait (from ring.rs) so that we can compose
//! the 𝕋_<N,q> implementation with the other objects from the code. //! the 𝕋_<N,q> implementation with the other objects from the code.
use itertools::zip_eq;
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::array; use std::array;
use std::iter::Sum; use std::iter::Sum;
@ -16,54 +17,75 @@ use crate::{ring::Ring, torus::T64, Rq, Zq};
/// 𝕋_<N,Q>[X] = 𝕋<Q>[X]/(X^N +1), polynomials modulo X^N+1 with coefficients in /// 𝕋_<N,Q>[X] = 𝕋<Q>[X]/(X^N +1), polynomials modulo X^N+1 with coefficients in
/// 𝕋, where Q=2^64. /// 𝕋, where Q=2^64.
#[derive(Clone, Copy, Debug)]
pub struct Tn<const N: usize>(pub [T64; N]);
#[derive(Clone, Debug)]
pub struct Tn {
pub n: usize,
pub coeffs: Vec<T64>,
}
impl<const N: usize> Ring for Tn<N> {
impl Ring for Tn {
type C = T64; type C = T64;
type Params = usize; // n
const Q: u64 = u64::MAX; // WIP
const N: usize = N;
// const Q: u64 = u64::MAX; // WIP
// const N: usize = N;
fn coeffs(&self) -> Vec<T64> { fn coeffs(&self) -> Vec<T64> {
self.0.to_vec()
self.coeffs.to_vec()
} }
fn zero() -> Self {
Self(array::from_fn(|_| T64::zero()))
fn zero(n: usize) -> Self {
Self {
n,
coeffs: vec![T64::zero(()); n],
}
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
Self(array::from_fn(|_| T64::rand(&mut rng, &dist)))
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, n: usize) -> Self {
Self {
n,
coeffs: std::iter::repeat_with(|| T64::rand(&mut rng, &dist, ()))
.take(n)
.collect(),
}
// Self(array::from_fn(|_| T64::rand(&mut rng, &dist)))
} }
fn from_vec(coeffs: Vec<Self::C>) -> Self {
fn from_vec(n: usize, coeffs: Vec<Self::C>) -> Self {
let mut p = coeffs; let mut p = coeffs;
modulus::<N>(&mut p);
Self(array::from_fn(|i| p[i]))
modulus(n, &mut p);
Self { n, coeffs: p }
} }
fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
let elems: Vec<Vec<T64>> = self.0.iter().map(|r| r.decompose(beta, l)).collect();
let elems: Vec<Vec<T64>> = self.coeffs.iter().map(|r| r.decompose(beta, l)).collect();
// transpose it // transpose it
let r: Vec<Vec<T64>> = (0..elems[0].len()) let r: Vec<Vec<T64>> = (0..elems[0].len())
.map(|i| (0..elems.len()).map(|j| elems[j][i]).collect()) .map(|i| (0..elems.len()).map(|j| elems[j][i]).collect())
.collect(); .collect();
// convert it to Tn<N>
r.iter().map(|a_i| Self::from_vec(a_i.clone())).collect()
// convert it to Tn
r.iter()
.map(|a_i| Self::from_vec(self.n, a_i.clone()))
.collect()
} }
fn remodule<const P: u64>(&self) -> Tn<N> {
fn remodule(&self, p: u64) -> Tn {
todo!() todo!()
// Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect()) // Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
} }
// fn mod_switch<const P: u64>(&self) -> impl Ring { // fn mod_switch<const P: u64>(&self) -> impl Ring {
fn mod_switch<const P: u64>(&self) -> Rq<P, N> {
fn mod_switch(&self, p: u64) -> Rq {
// unimplemented!() // unimplemented!()
// TODO WIP // TODO WIP
let coeffs = array::from_fn(|i| Zq::<P>::from_u64(self.0[i].mod_switch::<P>().0));
Rq::<P, N> {
let coeffs = self
.coeffs
.iter()
.map(|c_i| Zq::from_u64(p, c_i.mod_switch(p).0))
.collect();
Rq {
q: p,
n: self.n,
coeffs, coeffs,
evals: None, evals: None,
} }
@ -78,175 +100,234 @@ impl Ring for Tn {
.iter() .iter()
.map(|e| T64(((num as f64 * e.0 as f64) / den as f64).round() as u64)) .map(|e| T64(((num as f64 * e.0 as f64) / den as f64).round() as u64))
.collect(); .collect();
Self::from_vec(r)
Self::from_vec(self.n, r)
} }
} }
impl<const N: usize> Tn<N> {
impl Tn {
// multiply self by X^-h // multiply self by X^-h
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
let h = h % N;
assert!(h < N);
let c = self.0;
let n = self.n;
let h = h % n;
assert!(h < n);
let c = &self.coeffs;
// c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1] // c[h], c[h+1], c[h+2], ..., c[n-1], -c[0], -c[1], ..., -c[h-1]
// let r: Vec<T64> = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat(); // let r: Vec<T64> = vec![c[h..N], c[0..h].iter().map(|&c_i| -c_i).collect()].concat();
let r: Vec<T64> = c[h..N]
let r: Vec<T64> = c[h..n]
.iter() .iter()
.copied() .copied()
.chain(c[0..h].iter().map(|&x| -x)) .chain(c[0..h].iter().map(|&x| -x))
.collect(); .collect();
Self::from_vec(r)
Self::from_vec(self.n, r)
} }
pub fn from_vec_u64(v: Vec<u64>) -> Self {
pub fn from_vec_u64(n: usize, v: Vec<u64>) -> Self {
let coeffs = v.iter().map(|c| T64(*c)).collect(); let coeffs = v.iter().map(|c| T64(*c)).collect();
Self::from_vec(coeffs)
Self::from_vec(n, coeffs)
} }
} }
// apply mod (X^N+1) // apply mod (X^N+1)
pub fn modulus<const N: usize>(p: &mut Vec<T64>) {
if p.len() < N {
pub fn modulus(n: usize, p: &mut Vec<T64>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = T64::zero();
for i in n..p.len() {
p[i - n] = p[i - n].clone() - p[i].clone();
p[i] = T64::zero(());
} }
p.truncate(N);
p.truncate(n);
} }
impl<const N: usize> Add<Tn<N>> for Tn<N> {
impl Add<Tn> for Tn {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
// Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
assert_eq!(self.n, rhs.n);
Self {
n: self.n,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> Add<&Tn<N>> for &Tn<N> {
type Output = Tn<N>;
fn add(self, rhs: &Tn<N>) -> Self::Output {
Tn(array::from_fn(|i| self.0[i] + rhs.0[i]))
impl Add<&Tn> for &Tn {
type Output = Tn;
fn add(self, rhs: &Tn) -> Self::Output {
// Tn(array::from_fn(|i| self.0[i] + rhs.0[i]))
assert_eq!(self.n, rhs.n);
Tn {
n: self.n,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
.collect(),
}
} }
} }
impl<const N: usize> AddAssign for Tn<N> {
impl AddAssign for Tn {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] += rhs.0[i];
assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] += rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Sum<Tn<N>> for Tn<N> {
impl Sum<Tn> for Tn {
fn sum<I>(iter: I) -> Self fn sum<I>(iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
let mut acc = Tn::<N>::zero();
for e in iter {
acc += e;
}
acc
// let mut acc = Tn::<N>::zero();
// for e in iter {
// acc += e;
// }
// acc
let first = *iter.next().unwrap().borrow();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const N: usize> Sub<Tn<N>> for Tn<N> {
impl Sub<Tn> for Tn {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
assert_eq!(self.n, rhs.n);
Self {
n: self.n,
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> Sub<&Tn<N>> for &Tn<N> {
type Output = Tn<N>;
fn sub(self, rhs: &Tn<N>) -> Self::Output {
Tn(array::from_fn(|i| self.0[i] - rhs.0[i]))
impl Sub<&Tn> for &Tn {
type Output = Tn;
fn sub(self, rhs: &Tn) -> Self::Output {
// Tn(array::from_fn(|i| self.0[i] - rhs.0[i]))
assert_eq!(self.n, rhs.n);
Tn {
n: self.n,
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l - r)
.collect(),
}
} }
} }
impl<const N: usize> SubAssign for Tn<N> {
impl SubAssign for Tn {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
for i in 0..N {
self.0[i] -= rhs.0[i];
// for i in 0..N {
// self.0[i] -= rhs.0[i];
// }
assert_eq!(self.n, rhs.n);
for i in 0..self.n {
self.coeffs[i] -= rhs.coeffs[i];
} }
} }
} }
impl<const N: usize> Neg for Tn<N> {
impl Neg for Tn {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Tn(array::from_fn(|i| -self.0[i]))
// Tn(array::from_fn(|i| -self.0[i]))
Self {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(),
}
} }
} }
impl<const N: usize> PartialEq for Tn<N> {
impl PartialEq for Tn {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0 == other.0
self.coeffs == other.coeffs && self.n == other.n
} }
} }
impl<const N: usize> Mul<Tn<N>> for Tn<N> {
impl Mul<Tn> for Tn {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
naive_poly_mul(&self, &rhs) naive_poly_mul(&self, &rhs)
} }
} }
impl<const N: usize> Mul<&Tn<N>> for &Tn<N> {
type Output = Tn<N>;
impl Mul<&Tn> for &Tn {
type Output = Tn;
fn mul(self, rhs: &Tn<N>) -> Self::Output {
fn mul(self, rhs: &Tn) -> Self::Output {
naive_poly_mul(self, rhs) naive_poly_mul(self, rhs)
} }
} }
fn naive_poly_mul<const N: usize>(poly1: &Tn<N>, poly2: &Tn<N>) -> Tn<N> {
let poly1: Vec<u128> = poly1.0.iter().map(|c| c.0 as u128).collect();
let poly2: Vec<u128> = poly2.0.iter().map(|c| c.0 as u128).collect();
let mut result: Vec<u128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
fn naive_poly_mul(poly1: &Tn, poly2: &Tn) -> Tn {
assert_eq!(poly1.n, poly2.n);
let n = poly1.n;
let poly1: Vec<u128> = poly1.coeffs.iter().map(|c| c.0 as u128).collect();
let poly2: Vec<u128> = poly2.coeffs.iter().map(|c| c.0 as u128).collect();
let mut result: Vec<u128> = vec![0; (n * 2) - 1];
for i in 0..n {
for j in 0..n {
result[i + j] = result[i + j] + poly1[i] * poly2[j]; result[i + j] = result[i + j] + poly1[i] * poly2[j];
} }
} }
// apply mod (X^N + 1))
modulus_u128::<N>(&mut result);
// apply mod (X^n + 1))
modulus_u128(n, &mut result);
Tn(array::from_fn(|i| T64(result[i] as u64)))
Tn {
n,
// coeffs: array::from_fn(|i| T64(result[i] as u64)),
coeffs: result.iter().map(|r_i| T64(*r_i as u64)).collect(),
}
} }
fn modulus_u128<const N: usize>(p: &mut Vec<u128>) {
if p.len() < N {
fn modulus_u128(n: usize, p: &mut Vec<u128>) {
if p.len() < n {
return; return;
} }
for i in N..p.len() {
// p[i - N] = p[i - N].clone() - p[i].clone();
p[i - N] = p[i - N].wrapping_sub(p[i]);
for i in n..p.len() {
// p[i - n] = p[i - n].clone() - p[i].clone();
p[i - n] = p[i - n].wrapping_sub(p[i]);
p[i] = 0; p[i] = 0;
} }
p.truncate(N);
p.truncate(n);
} }
impl<const N: usize> Mul<T64> for Tn<N> {
impl Mul<T64> for Tn {
type Output = Self; type Output = Self;
fn mul(self, s: T64) -> Self { fn mul(self, s: T64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
Self {
n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i] * s),
coeffs: self.coeffs.iter().map(|c_i| *c_i * s).collect(),
}
} }
} }
// mul by u64 // mul by u64
impl<const N: usize> Mul<u64> for Tn<N> {
impl Mul<u64> for Tn {
type Output = Self; type Output = Self;
fn mul(self, s: u64) -> Self { fn mul(self, s: u64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
// Self(array::from_fn(|i| self.0[i] * s))
Tn {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| *c_i * s).collect(),
}
} }
} }
impl<const N: usize> Mul<&u64> for &Tn<N> {
type Output = Tn<N>;
impl Mul<&u64> for &Tn {
type Output = Tn;
fn mul(self, s: &u64) -> Self::Output { fn mul(self, s: &u64) -> Self::Output {
Tn::<N>(array::from_fn(|i| self.0[i] * *s))
// Tn::<N>(array::from_fn(|i| self.0[i] * *s))
Self {
n: self.n,
coeffs: self.coeffs.iter().map(|c_i| c_i * s).collect(),
}
} }
} }
@ -256,8 +337,9 @@ mod tests {
#[test] #[test]
fn test_left_rotate() { fn test_left_rotate() {
const N: usize = 4;
let f = Tn::<N>::from_vec(
let n: usize = 4;
let f = Tn::from_vec(
n,
vec![2i64, 3, -4, -1] vec![2i64, 3, -4, -1]
.iter() .iter()
.map(|c| T64(*c as u64)) .map(|c| T64(*c as u64))
@ -267,7 +349,8 @@ mod tests {
// expect f*x^-3 == -1 -2x -3x^2 +4x^3 // expect f*x^-3 == -1 -2x -3x^2 +4x^3
assert_eq!( assert_eq!(
f.left_rotate(3), f.left_rotate(3),
Tn::<N>::from_vec(
Tn::from_vec(
n,
vec![-1i64, -2, -3, 4] vec![-1i64, -2, -3, 4]
.iter() .iter()
.map(|c| T64(*c as u64)) .map(|c| T64(*c as u64))
@ -277,7 +360,8 @@ mod tests {
// expect f*x^-1 == 3 -4x -1x^2 -2x^3 // expect f*x^-1 == 3 -4x -1x^2 -2x^3
assert_eq!( assert_eq!(
f.left_rotate(1), f.left_rotate(1),
Tn::<N>::from_vec(
Tn::from_vec(
n,
vec![3i64, -4, -1, -2] vec![3i64, -4, -1, -2]
.iter() .iter()
.map(|c| T64(*c as u64)) .map(|c| T64(*c as u64))

+ 12
- 10
arith/src/torus.rs

@ -16,20 +16,22 @@ pub struct T64(pub u64);
// `Tn<1>`. // `Tn<1>`.
impl Ring for T64 { impl Ring for T64 {
type C = T64; type C = T64;
const Q: u64 = u64::MAX; // WIP
const N: usize = 1;
type Params = ();
// const Q: u64 = u64::MAX; // WIP
// const N: usize = 1;
fn coeffs(&self) -> Vec<T64> { fn coeffs(&self) -> Vec<T64> {
vec![self.clone()] vec![self.clone()]
} }
fn zero() -> Self {
fn zero(_: ()) -> Self {
Self(0u64) Self(0u64)
} }
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, _: ()) -> Self {
let r: f64 = dist.sample(&mut rng); let r: f64 = dist.sample(&mut rng);
Self(r.round() as u64) Self(r.round() as u64)
} }
fn from_vec(coeffs: Vec<Self::C>) -> Self {
fn from_vec(_n: (), coeffs: Vec<Self::C>) -> Self {
assert_eq!(coeffs.len(), 1); assert_eq!(coeffs.len(), 1);
coeffs[0] coeffs[0]
} }
@ -46,18 +48,18 @@ impl Ring for T64 {
.map(|i| T64(((self.0 >> i) & 1) as u64)) .map(|i| T64(((self.0 >> i) & 1) as u64))
.collect() .collect()
} }
fn remodule<const P: u64>(&self) -> T64 {
fn remodule(&self, p: u64) -> T64 {
todo!() todo!()
} }
// modulus switch from Q to Q2: self * Q2/Q // modulus switch from Q to Q2: self * Q2/Q
fn mod_switch<const Q2: u64>(&self) -> T64 {
fn mod_switch(&self, q2: u64) -> T64 {
// for the moment we assume Q|Q2, since Q=2^64, check that Q2 is a power // for the moment we assume Q|Q2, since Q=2^64, check that Q2 is a power
// of two: // of two:
assert!(Q2.is_power_of_two());
assert!(q2.is_power_of_two());
// since Q=2^64, dividing Q2/Q is equivalent to dividing 2^log2(Q2)/2^64 // since Q=2^64, dividing Q2/Q is equivalent to dividing 2^log2(Q2)/2^64
// which would be like right-shifting 64-log2(Q2). // which would be like right-shifting 64-log2(Q2).
let log2_q2 = 63 - Q2.leading_zeros();
let log2_q2 = 63 - q2.leading_zeros();
T64(self.0 >> (64 - log2_q2)) T64(self.0 >> (64 - log2_q2))
} }
@ -175,7 +177,7 @@ mod tests {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let x = T64::rand(&mut rng, Standard);
let x = T64::rand(&mut rng, Standard, ());
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(recompose(d), x); assert_eq!(recompose(d), x);
} }

+ 79
- 40
arith/src/tuple_ring.rs

@ -16,25 +16,37 @@ use crate::Ring;
/// Tuple of K Ring (Rq) elements. We use Vec<R> to allocate it in the heap, /// Tuple of K Ring (Rq) elements. We use Vec<R> to allocate it in the heap,
/// since if using a fixed-size array it would overflow the stack. /// since if using a fixed-size array it would overflow the stack.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TR<R: Ring, const K: usize>(pub Vec<R>);
pub struct TR<R: Ring> {
pub k: usize,
pub r: Vec<R>,
}
// TODO rm pub from Vec<R>, so that TR can not be created from a Vec with // TODO rm pub from Vec<R>, so that TR can not be created from a Vec with
// invalid length, since it has to be created using the `new` method. // invalid length, since it has to be created using the `new` method.
impl<R: Ring, const K: usize> TR<R, K> {
pub fn new(v: Vec<R>) -> Self {
assert_eq!(v.len(), K);
Self(v)
impl<R: Ring> TR<R> {
pub fn new(k: usize, r: Vec<R>) -> Self {
assert_eq!(r.len(), k);
Self { k, r }
} }
pub fn zero() -> Self {
Self((0..K).into_iter().map(|_| R::zero()).collect())
pub fn zero(k: usize, r_params: R::Params) -> Self {
Self {
k,
r: (0..k).into_iter().map(|_| R::zero(r_params)).collect(),
}
} }
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
Self(
(0..K)
pub fn rand(
mut rng: impl Rng,
dist: impl Distribution<f64>,
k: usize,
r_params: R::Params,
) -> Self {
Self {
k,
r: (0..k)
.into_iter() .into_iter()
.map(|_| R::rand(&mut rng, &dist))
.map(|_| R::rand(&mut rng, &dist, r_params))
.collect(), .collect(),
)
}
} }
// returns the decomposition of each polynomial element // returns the decomposition of each polynomial element
pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
@ -43,64 +55,85 @@ impl TR {
} }
} }
impl<const K: usize> TR<crate::torus::T64, K> {
pub fn mod_switch<const Q2: u64>(&self) -> TR<crate::torus::T64, K> {
TR(self.0.iter().map(|c_i| c_i.mod_switch::<Q2>()).collect())
impl TR<crate::torus::T64> {
pub fn mod_switch(&self, q2: u64) -> TR<crate::torus::T64> {
TR::<crate::torus::T64> {
k: self.k,
r: self.r.iter().map(|c_i| c_i.mod_switch(q2)).collect(),
}
} }
// pub fn mod_switch(&self, Q2: u64) -> TR<crate::torus::T64, K> { // pub fn mod_switch(&self, Q2: u64) -> TR<crate::torus::T64, K> {
// TR(self.0.iter().map(|c_i| c_i.mod_switch(Q2)).collect()) // TR(self.0.iter().map(|c_i| c_i.mod_switch(Q2)).collect())
// } // }
} }
impl<const N: usize, const K: usize> TR<crate::ring_torus::Tn<N>, K> {
impl TR<crate::ring_torus::Tn> {
pub fn left_rotate(&self, h: usize) -> Self { pub fn left_rotate(&self, h: usize) -> Self {
TR(self.0.iter().map(|c_i| c_i.left_rotate(h)).collect())
TR {
k: self.k,
r: self.r.iter().map(|c_i| c_i.left_rotate(h)).collect(),
}
} }
} }
impl<R: Ring, const K: usize> TR<R, K> {
impl<R: Ring> TR<R> {
pub fn iter(&self) -> std::slice::Iter<R> { pub fn iter(&self) -> std::slice::Iter<R> {
self.0.iter()
self.r.iter()
} }
} }
impl<R: Ring, const K: usize> Add<TR<R, K>> for TR<R, K> {
impl<R: Ring> Add<TR<R>> for TR<R> {
type Output = Self; type Output = Self;
fn add(self, other: Self) -> Self { fn add(self, other: Self) -> Self {
Self(
zip_eq(self.0, other.0)
debug_assert_eq!(self.k, other.k);
Self {
k: self.k,
r: zip_eq(self.r, other.r)
.map(|(s, o)| s + o) .map(|(s, o)| s + o)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
)
}
} }
} }
impl<R: Ring, const K: usize> Sub<TR<R, K>> for TR<R, K> {
impl<R: Ring> Sub<TR<R>> for TR<R> {
type Output = Self; type Output = Self;
fn sub(self, other: Self) -> Self { fn sub(self, other: Self) -> Self {
Self(zip_eq(self.0, other.0).map(|(s, o)| s - o).collect())
debug_assert_eq!(self.k, other.k);
Self {
k: self.k,
r: zip_eq(self.r, other.r).map(|(s, o)| s - o).collect(),
}
} }
} }
impl<R: Ring, const K: usize> Neg for TR<R, K> {
impl<R: Ring> Neg for TR<R> {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
Self(self.0.iter().map(|&e_i| -e_i).collect())
Self {
k: self.k,
r: self.r.iter().map(|e_i| -*e_i).collect(),
}
} }
} }
/// for (TR,TR), the Mul operation is defined as the dot product: /// for (TR,TR), the Mul operation is defined as the dot product:
/// for A, B \in R^k, result = Σ A_i * B_i \in R /// for A, B \in R^k, result = Σ A_i * B_i \in R
impl<R: Ring, const K: usize> Mul<TR<R, K>> for TR<R, K> {
impl<R: Ring> Mul<TR<R>> for TR<R> {
type Output = R; type Output = R;
fn mul(self, other: Self) -> R { fn mul(self, other: Self) -> R {
zip_eq(self.0, other.0).map(|(s, o)| s * o).sum()
debug_assert_eq!(self.k, other.k);
zip_eq(self.r, other.r).map(|(s, o)| s * o).sum()
} }
} }
impl<R: Ring, const K: usize> Mul<&TR<R, K>> for &TR<R, K> {
impl<R: Ring> Mul<&TR<R>> for &TR<R> {
type Output = R; type Output = R;
fn mul(self, other: &TR<R, K>) -> R {
zip_eq(self.0.clone(), other.0.clone())
fn mul(self, other: &TR<R>) -> R {
debug_assert_eq!(self.k, other.k);
zip_eq(self.r.clone(), other.r.clone())
.map(|(s, o)| s * o) .map(|(s, o)| s * o)
.sum() .sum()
} }
@ -108,15 +141,21 @@ impl Mul<&TR> for &TR {
/// for (TR, R), the Mul operation is defined as each element of TR is /// for (TR, R), the Mul operation is defined as each element of TR is
/// multiplied by R /// multiplied by R
impl<R: Ring, const K: usize> Mul<R> for TR<R, K> {
type Output = TR<R, K>;
fn mul(self, other: R) -> TR<R, K> {
Self(self.0.iter().map(|s| s.clone() * other.clone()).collect())
impl<R: Ring> Mul<R> for TR<R> {
type Output = TR<R>;
fn mul(self, other: R) -> TR<R> {
Self {
k: self.k,
r: self.r.iter().map(|s| s.clone() * other.clone()).collect(),
}
} }
} }
impl<R: Ring, const K: usize> Mul<&R> for &TR<R, K> {
type Output = TR<R, K>;
fn mul(self, other: &R) -> TR<R, K> {
TR::<R, K>(self.0.iter().map(|s| s.clone() * other.clone()).collect())
impl<R: Ring> Mul<&R> for &TR<R> {
type Output = TR<R>;
fn mul(self, other: &R) -> TR<R> {
TR::<R> {
k: self.k,
r: self.r.iter().map(|s| s.clone() * other.clone()).collect(),
}
} }
} }

+ 168
- 114
arith/src/zq.rs

@ -1,10 +1,14 @@
use rand::{distributions::Distribution, Rng}; use rand::{distributions::Distribution, Rng};
use std::borrow::Borrow;
use std::fmt; use std::fmt;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
/// Z_q, integers modulus q, not necessarily prime /// Z_q, integers modulus q, not necessarily prime
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
pub struct Zq<const Q: u64>(pub u64);
pub struct Zq {
pub q: u64,
pub v: u64,
}
// WIP // WIP
// impl<const Q: u64> From<Vec<u64>> for Vec<Zq<Q>> { // impl<const Q: u64> From<Vec<u64>> for Vec<Zq<Q>> {
@ -13,32 +17,35 @@ pub struct Zq(pub u64);
// } // }
// } // }
pub(crate) fn modulus_u64<const Q: u64>(e: u64) -> u64 {
(e % Q + Q) % Q
pub(crate) fn modulus_u64(q: u64, e: u64) -> u64 {
(e % q + q) % q
} }
impl<const Q: u64> Zq<Q> {
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Self {
impl Zq {
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, q: u64) -> Self {
// TODO WIP // TODO WIP
let r: f64 = dist.sample(&mut rng); let r: f64 = dist.sample(&mut rng);
Self::from_f64(r)
Self::from_f64(q, r)
// Self::from_u64(r.round() as u64) // Self::from_u64(r.round() as u64)
} }
pub fn from_u64(e: u64) -> Self {
if e >= Q {
// (e % Q + Q) % Q
return Zq(modulus_u64::<Q>(e));
// return Zq(e % Q);
pub fn from_u64(q: u64, v: u64) -> Self {
if v >= q {
// (v % Q + Q) % Q
return Zq {
q,
v: modulus_u64(q, v),
};
// return Zq(v % Q);
} }
Zq(e)
Zq { q, v }
} }
pub fn from_f64(e: f64) -> Self {
pub fn from_f64(q: u64, e: f64) -> Self {
// WIP method // WIP method
let e: i64 = e.round() as i64; let e: i64 = e.round() as i64;
let q = Q as i64;
if e < 0 || e >= q {
return Zq(((e % q + q) % q) as u64);
let q_i64 = q as i64;
if e < 0 || e >= q_i64 {
return Zq::from_u64(q, ((e % q_i64 + q_i64) % q_i64) as u64);
} }
Zq(e as u64)
Zq { q, v: e as u64 }
// if e < 0 { // if e < 0 {
// // dbg!(&e); // // dbg!(&e);
@ -50,15 +57,18 @@ impl Zq {
// } // }
// Zq(e as u64) // Zq(e as u64)
} }
pub fn from_bool(b: bool) -> Self {
pub fn from_bool(q: u64, b: bool) -> Self {
if b { if b {
Zq(1)
Zq { q, v: 1 }
} else { } else {
Zq(0)
Zq { q, v: 0 }
} }
} }
pub fn zero() -> Self {
Self(0u64)
pub fn zero(q: u64) -> Self {
Self { q, v: 0u64 }
}
pub fn one(q: u64) -> Self {
Self { q, v: 1u64 }
} }
pub fn square(self) -> Self { pub fn square(self) -> Self {
self * self self * self
@ -66,18 +76,21 @@ impl Zq {
// modular exponentiation // modular exponentiation
pub fn exp(self, e: Self) -> Self { pub fn exp(self, e: Self) -> Self {
// mul-square approach // mul-square approach
let mut res = Self(1);
let mut res = Self::one(self.q);
let mut rem = e.clone(); let mut rem = e.clone();
let mut exp = self; let mut exp = self;
// for rem != Self(0) { // for rem != Self(0) {
while rem != Self(0) {
while rem != Self::zero(self.q) {
// if odd // if odd
// TODO use a more readible expression // TODO use a more readible expression
if 1 - ((rem.0 & 1) << 1) as i64 == -1 {
if 1 - ((rem.v & 1) << 1) as i64 == -1 {
res = res * exp; res = res * exp;
} }
exp = exp.square(); exp = exp.square();
rem = Self(rem.0 >> 1);
rem = Self {
q: self.q,
v: rem.v >> 1,
};
} }
res res
} }
@ -89,9 +102,9 @@ impl Zq {
// let a = self.0; // let a = self.0;
// let q = Q; // let q = Q;
let mut t = 0; let mut t = 0;
let mut r = Q;
let mut r = self.q;
let mut new_t = 0; let mut new_t = 0;
let mut new_r = self.0.clone();
let mut new_r = self.v.clone();
while new_r != 0 { while new_r != 0 {
let q = r / new_r; let q = r / new_r;
@ -104,16 +117,16 @@ impl Zq {
// if t < 0 { // if t < 0 {
// t = t + q; // t = t + q;
// } // }
return Zq::from_u64(t);
return Zq::from_u64(self.q, t);
} }
pub fn inv(self) -> Zq<Q> {
let (g, x, _) = Self::egcd(self.0 as i128, Q as i128);
pub fn inv(self) -> Zq {
let (g, x, _) = Self::egcd(self.v as i128, self.q as i128);
if g != 1 { if g != 1 {
// None // None
panic!("E"); panic!("E");
} else { } else {
let q = Q as i128;
Zq(((x % q + q) % q) as u64) // TODO maybe just Zq::new(x)
let q = self.q as i128;
Zq::from_u64(self.q, ((x % q + q) % q) as u64) // TODO maybe just Zq::new(x)
} }
} }
fn egcd(a: i128, b: i128) -> (i128, i128, i128) { fn egcd(a: i128, b: i128) -> (i128, i128, i128) {
@ -126,8 +139,11 @@ impl Zq {
} }
/// perform the mod switch operation from Q to Q', where Q2=Q' /// perform the mod switch operation from Q to Q', where Q2=Q'
pub fn mod_switch<const Q2: u64>(&self) -> Zq<Q2> {
Zq::<Q2>::from_u64(((self.0 as f64 * Q2 as f64) / Q as f64).round() as u64)
pub fn mod_switch(&self, q2: u64) -> Zq {
Zq::from_u64(
q2,
((self.v as f64 * q2 as f64) / self.q as f64).round() as u64,
)
} }
pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> { pub fn decompose(&self, beta: u32, l: u32) -> Vec<Self> {
@ -138,19 +154,25 @@ impl Zq {
} }
} }
pub fn decompose_base_beta(&self, beta: u32, l: u32) -> Vec<Self> { pub fn decompose_base_beta(&self, beta: u32, l: u32) -> Vec<Self> {
let mut rem: u64 = self.0;
let mut rem: u64 = self.v;
// next if is for cases in which beta does not divide Q (concretely // next if is for cases in which beta does not divide Q (concretely
// beta^l!=Q). round to the nearest multiple of q/beta^l // beta^l!=Q). round to the nearest multiple of q/beta^l
if rem >= beta.pow(l) as u64 { if rem >= beta.pow(l) as u64 {
// rem = Q - 1 - (Q / beta as u64); // floor // rem = Q - 1 - (Q / beta as u64); // floor
return vec![Zq(beta as u64 - 1); l as usize];
return vec![
Zq {
q: self.q,
v: beta as u64 - 1
};
l as usize
];
} }
let mut x: Vec<Self> = vec![]; let mut x: Vec<Self> = vec![];
for i in 1..l + 1 { for i in 1..l + 1 {
let den = Q / beta.pow(i) as u64;
let den = self.q / beta.pow(i) as u64;
let x_i = rem / den; // division between u64 already does floor let x_i = rem / den; // division between u64 already does floor
x.push(Self::from_u64(x_i));
x.push(Self::from_u64(self.q, x_i));
if x_i != 0 { if x_i != 0 {
rem = rem % den; rem = rem % den;
} }
@ -161,15 +183,15 @@ impl Zq {
pub fn decompose_base2(&self, l: u32) -> Vec<Self> { pub fn decompose_base2(&self, l: u32) -> Vec<Self> {
// next if is for cases in which beta does not divide Q (concretely // next if is for cases in which beta does not divide Q (concretely
// beta^l!=Q). round to the nearest multiple of q/beta^l // beta^l!=Q). round to the nearest multiple of q/beta^l
if self.0 >= 1 << l as u64 {
if self.v >= 1 << l as u64 {
// rem = Q - 1 - (Q / beta as u64); // floor // rem = Q - 1 - (Q / beta as u64); // floor
// (where beta=2) // (where beta=2)
return vec![Zq(1); l as usize];
return vec![Zq::one(self.q); l as usize];
} }
(0..l) (0..l)
.rev() .rev()
.map(|i| Self(((self.0 >> i) & 1) as u64))
.map(|i| Self::from_u64(self.q, ((self.v >> i) & 1) as u64))
.collect() .collect()
// naive ver: // naive ver:
@ -194,114 +216,143 @@ impl Zq {
} }
} }
impl<const Q: u64> Zq<Q> {
impl Zq {
fn r#mod(self) -> Self { fn r#mod(self) -> Self {
if self.0 >= Q {
return Zq(self.0 % Q);
if self.v >= self.q {
return Zq::from_u64(self.q, self.v % self.q);
} }
self self
} }
} }
impl<const Q: u64> Add<Zq<Q>> for Zq<Q> {
impl Add<Zq> for Zq {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self::Output { fn add(self, rhs: Self) -> Self::Output {
let mut r = self.0 + rhs.0;
if r >= Q {
r -= Q;
assert_eq!(self.q, rhs.q);
let mut v = self.v + rhs.v;
if v >= self.q {
v -= self.q;
} }
Zq(r)
Zq { q: self.q, v }
} }
} }
impl<const Q: u64> Add<&Zq<Q>> for &Zq<Q> {
type Output = Zq<Q>;
impl Add<&Zq> for &Zq {
type Output = Zq;
fn add(self, rhs: &Zq) -> Self::Output {
assert_eq!(self.q, rhs.q);
fn add(self, rhs: &Zq<Q>) -> Self::Output {
let mut r = self.0 + rhs.0;
if r >= Q {
r -= Q;
let mut v = self.v + rhs.v;
if v >= self.q {
v -= self.q;
} }
Zq(r)
Zq { q: self.q, v }
} }
} }
impl<const Q: u64> AddAssign<Zq<Q>> for Zq<Q> {
impl AddAssign<Zq> for Zq {
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs *self = *self + rhs
} }
} }
impl<const Q: u64> std::iter::Sum for Zq<Q> {
fn sum<I>(iter: I) -> Self
impl std::iter::Sum for Zq {
fn sum<I>(mut iter: I) -> Self
where where
I: Iterator<Item = Self>, I: Iterator<Item = Self>,
{ {
iter.fold(Zq(0), |acc, x| acc + x)
let first: Zq = iter.next().unwrap();
iter.fold(first, |acc, x| acc + x)
} }
} }
impl<const Q: u64> Sub<Zq<Q>> for Zq<Q> {
impl Sub<Zq> for Zq {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Zq<Q> {
if self.0 >= rhs.0 {
Zq(self.0 - rhs.0)
fn sub(self, rhs: Self) -> Zq {
assert_eq!(self.q, rhs.q);
if self.v >= rhs.v {
Zq {
q: self.q,
v: self.v - rhs.v,
}
} else { } else {
Zq((Q + self.0) - rhs.0)
Zq {
q: self.q,
v: (self.q + self.v) - rhs.v,
}
} }
} }
} }
impl<const Q: u64> Sub<&Zq<Q>> for &Zq<Q> {
type Output = Zq<Q>;
impl Sub<&Zq> for &Zq {
type Output = Zq;
fn sub(self, rhs: &Zq<Q>) -> Self::Output {
if self.0 >= rhs.0 {
Zq(self.0 - rhs.0)
fn sub(self, rhs: &Zq) -> Self::Output {
assert_eq!(self.q, rhs.q);
if self.q >= rhs.q {
Zq {
q: self.q,
v: self.v - rhs.v,
}
} else { } else {
Zq((Q + self.0) - rhs.0)
Zq {
q: self.q,
v: (self.q + self.v) - rhs.v,
}
} }
} }
} }
impl<const Q: u64> SubAssign<Zq<Q>> for Zq<Q> {
impl SubAssign<Zq> for Zq {
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs *self = *self - rhs
} }
} }
impl<const Q: u64> Neg for Zq<Q> {
impl Neg for Zq {
type Output = Self; type Output = Self;
fn neg(self) -> Self::Output { fn neg(self) -> Self::Output {
if self.0 == 0 {
if self.v == 0 {
return self; return self;
} }
Zq(Q - self.0)
Zq {
q: self.q,
v: self.q - self.v,
}
} }
} }
impl<const Q: u64> Mul<Zq<Q>> for Zq<Q> {
impl Mul<Zq> for Zq {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Zq<Q> {
fn mul(self, rhs: Self) -> Zq {
assert_eq!(self.q, rhs.q);
// TODO non-naive way // TODO non-naive way
Zq(((self.0 as u128 * rhs.0 as u128) % Q as u128) as u64)
Zq {
q: self.q,
v: ((self.v as u128 * rhs.v as u128) % self.q as u128) as u64,
}
// Zq((self.0 * rhs.0) % Q) // Zq((self.0 * rhs.0) % Q)
} }
} }
impl<const Q: u64> Div<Zq<Q>> for Zq<Q> {
impl Div<Zq> for Zq {
type Output = Self; type Output = Self;
fn div(self, rhs: Self) -> Zq<Q> {
fn div(self, rhs: Self) -> Zq {
// TODO non-naive way // TODO non-naive way
// Zq((self.0 / rhs.0) % Q) // Zq((self.0 / rhs.0) % Q)
self * rhs.inv() self * rhs.inv()
} }
} }
impl<const Q: u64> fmt::Display for Zq<Q> {
impl fmt::Display for Zq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
write!(f, "{}", self.v)
} }
} }
impl<const Q: u64> fmt::Debug for Zq<Q> {
impl fmt::Debug for Zq {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
write!(f, "{}", self.v)
} }
} }
@ -312,80 +363,83 @@ mod tests {
#[test] #[test]
fn exp() { fn exp() {
const Q: u64 = 1021;
let a = Zq::<Q>(3);
let b = Zq::<Q>(3);
assert_eq!(a.exp(b), Zq(27));
const q: u64 = 1021;
let a = Zq::from_u64(q, 3);
let b = Zq::from_u64(q, 3);
assert_eq!(a.exp(b), Zq::from_u64(q, 27));
let a = Zq::<Q>(1000);
let b = Zq::<Q>(3);
assert_eq!(a.exp(b), Zq(949));
let a = Zq::from_u64(q, 1000);
let b = Zq::from_u64(q, 3);
assert_eq!(a.exp(b), Zq::from_u64(q, 949));
} }
#[test] #[test]
fn neg() { fn neg() {
const Q: u64 = 1021;
let a = Zq::<Q>::from_f64(101.0);
let b = Zq::<Q>::from_f64(-1.0);
let q: u64 = 1021;
let a = Zq::from_f64(q, 101.0);
let b = Zq::from_f64(q, -1.0);
assert_eq!(-a, a * b); assert_eq!(-a, a * b);
} }
fn recompose<const Q: u64>(beta: u32, l: u32, d: Vec<Zq<Q>>) -> Zq<Q> {
fn recompose(q: u64, beta: u32, l: u32, d: Vec<Zq>) -> Zq {
let mut x = 0u64; let mut x = 0u64;
for i in 0..l { for i in 0..l {
x += d[i as usize].0 * Q / beta.pow(i + 1) as u64;
x += d[i as usize].v * q / beta.pow(i + 1) as u64;
} }
Zq::from_u64(x)
Zq::from_u64(q, x)
} }
#[test] #[test]
fn test_decompose() { fn test_decompose() {
const Q1: u64 = 16;
let q1: u64 = 16;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 4; let l: u32 = 4;
let x = Zq::<Q1>::from_u64(9);
let x = Zq::from_u64(q1, 9);
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(recompose::<Q1>(beta, l, d), x);
assert_eq!(recompose(q1, beta, l, d), x);
const Q: u64 = 5u64.pow(3);
let q: u64 = 5u64.pow(3);
let beta: u32 = 5; let beta: u32 = 5;
let l: u32 = 3; let l: u32 = 3;
let dist = Uniform::new(0_u64, Q);
let dist = Uniform::new(0_u64, q);
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
for _ in 0..1000 { for _ in 0..1000 {
let x = Zq::<Q>::from_u64(dist.sample(&mut rng));
let x = Zq::from_u64(q, dist.sample(&mut rng));
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q>(beta, l, d), x);
assert_eq!(recompose(q, beta, l, d), x);
} }
} }
#[test] #[test]
fn test_decompose_approx() { fn test_decompose_approx() {
const Q: u64 = 2u64.pow(4) + 1;
let q: u64 = 2u64.pow(4) + 1;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 4; let l: u32 = 4;
let x = Zq::<Q>::from_u64(16); // in q, but bigger than beta^l
let x = Zq::from_u64(q, 16); // in q, but bigger than beta^l
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q>(beta, l, d), Zq(15));
assert_eq!(recompose(q, beta, l, d), Zq::from_u64(q, 15));
const Q2: u64 = 5u64.pow(3) + 1;
let q2: u64 = 5u64.pow(3) + 1;
let beta: u32 = 5; let beta: u32 = 5;
let l: u32 = 3; let l: u32 = 3;
let x = Zq::<Q2>::from_u64(125); // in q, but bigger than beta^l
let x = Zq::from_u64(q2, 125); // in q, but bigger than beta^l
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q2>(beta, l, d), Zq(124));
assert_eq!(recompose(q2, beta, l, d), Zq::from_u64(q2, 124));
const Q3: u64 = 2u64.pow(16) + 1;
let q3: u64 = 2u64.pow(16) + 1;
let beta: u32 = 2; let beta: u32 = 2;
let l: u32 = 16; let l: u32 = 16;
let x = Zq::<Q3>::from_u64(Q3 - 1); // in q, but bigger than beta^l
let x = Zq::from_u64(q3, q3 - 1); // in q, but bigger than beta^l
let d = x.decompose(beta, l); let d = x.decompose(beta, l);
assert_eq!(d.len(), l as usize); assert_eq!(d.len(), l as usize);
assert_eq!(recompose::<Q3>(beta, l, d), Zq(beta.pow(l) as u64 - 1));
assert_eq!(
recompose(q3, beta, l, d),
Zq::from_u64(q3, beta.pow(l) as u64 - 1)
);
} }
} }

Loading…
Cancel
Save