mirror of
https://github.com/arnaucube/fhe-study.git
synced 2026-01-24 04:33:52 +01:00
add NTT implementation, and use it for the negacyclic poly ring multiplication, more details on the NTT can be found at https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
This commit is contained in:
4
README.md
Normal file
4
README.md
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# fhe-study
|
||||||
|
Code done while studying some FHE papers.
|
||||||
|
|
||||||
|
- arithmetic: contains $\mathbb{Z}_q$ and $\mathbb{Z}_q[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation.
|
||||||
3
arithmetic/.gitignore
vendored
Normal file
3
arithmetic/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
/target
|
||||||
|
Cargo.lock
|
||||||
|
*.sage.py
|
||||||
2
arithmetic/README.md
Normal file
2
arithmetic/README.md
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# arithmetic
|
||||||
|
Contains $\mathbb{Z}_q$ and $\mathbb{Z}_q[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation.
|
||||||
@@ -4,8 +4,11 @@
|
|||||||
#![allow(clippy::upper_case_acronyms)]
|
#![allow(clippy::upper_case_acronyms)]
|
||||||
#![allow(dead_code)] // TMP
|
#![allow(dead_code)] // TMP
|
||||||
|
|
||||||
|
mod naive; // TODO rm
|
||||||
|
pub mod ntt;
|
||||||
pub mod ring;
|
pub mod ring;
|
||||||
pub mod zq;
|
pub mod zq;
|
||||||
|
|
||||||
|
pub use ntt::NTT;
|
||||||
pub use ring::PR;
|
pub use ring::PR;
|
||||||
pub use zq::Zq;
|
pub use zq::Zq;
|
||||||
|
|||||||
195
arithmetic/src/naive.rs
Normal file
195
arithmetic/src/naive.rs
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
//! this file implements the non-efficient NTT, which uses multiplication by the
|
||||||
|
//! Vandermonde matrix.
|
||||||
|
use crate::zq::Zq;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct NTT<const Q: u64, const N: usize> {
|
||||||
|
pub primitive: Zq<Q>,
|
||||||
|
// nth_roots: Vec<Zq<Q>>,
|
||||||
|
pub ntt: Vec<Vec<Zq<Q>>>,
|
||||||
|
pub intt: Vec<Vec<Zq<Q>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const Q: u64, const N: usize> NTT<Q, N> {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
// TODO change n to be u64 and ensure that is n<Q
|
||||||
|
// note: `n` here is not the `N` from `(X^N+1)`
|
||||||
|
// TODO: in fact n will be N (trait/struct param)
|
||||||
|
|
||||||
|
// let primitive = Self::get_primitive_root_of_unity((2 * N) as u64)?;
|
||||||
|
let primitive = Self::get_primitive_root_of_unity((2 * N) as u64)?;
|
||||||
|
// let mut nth_roots = vec![Zq(0); N];
|
||||||
|
// let mut w_i = Zq(1);
|
||||||
|
// for i in 0..N {
|
||||||
|
// w_i = w_i * primitive;
|
||||||
|
// nth_roots[i] = w_i;
|
||||||
|
// }
|
||||||
|
let ntt: Vec<Vec<Zq<Q>>> = Self::vandermonde(primitive);
|
||||||
|
let intt = Self::invert_vandermonde(&ntt);
|
||||||
|
Ok(Self {
|
||||||
|
primitive,
|
||||||
|
// nth_roots,
|
||||||
|
ntt,
|
||||||
|
intt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
pub fn vandermonde(primitive: Zq<Q>) -> Vec<Vec<Zq<Q>>> {
|
||||||
|
let mut v: Vec<Vec<Zq<Q>>> = vec![];
|
||||||
|
let n = (2 * N) as u64;
|
||||||
|
// let n = N as u64;
|
||||||
|
for i in 0..n {
|
||||||
|
let mut row: Vec<Zq<Q>> = vec![];
|
||||||
|
let primitive_i = primitive.exp(Zq(i));
|
||||||
|
let mut primitive_ij = Zq(1);
|
||||||
|
for _ in 0..n {
|
||||||
|
row.push(primitive_ij);
|
||||||
|
primitive_ij = primitive_ij * primitive_i;
|
||||||
|
}
|
||||||
|
v.push(row);
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
|
// specifically for the Vandermonde matrix
|
||||||
|
pub fn invert_vandermonde(v: &Vec<Vec<Zq<Q>>>) -> Vec<Vec<Zq<Q>>> {
|
||||||
|
let n = 2 * N;
|
||||||
|
// let n = N;
|
||||||
|
let mut inv: Vec<Vec<Zq<Q>>> = vec![];
|
||||||
|
for i in 0..n {
|
||||||
|
let w_i = v[i][1]; // = w_i^1=w^i^1 = w^i
|
||||||
|
let w_i_inv = w_i.inv();
|
||||||
|
let mut row: Vec<Zq<Q>> = vec![];
|
||||||
|
for j in 0..n {
|
||||||
|
row.push(w_i_inv.exp(Zq(j as u64)) / Zq(n as u64));
|
||||||
|
}
|
||||||
|
inv.push(row);
|
||||||
|
}
|
||||||
|
inv
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_primitive_root_of_unity(n: u64) -> Result<Zq<Q>> {
|
||||||
|
// using the method described by Thomas Pornin in
|
||||||
|
// https://crypto.stackexchange.com/a/63616
|
||||||
|
|
||||||
|
// assert!((Q - 1) % N as u64 == 0);
|
||||||
|
assert!((Q - 1) % n == 0);
|
||||||
|
|
||||||
|
// TODO maybe not using Zq and using u64 directly
|
||||||
|
let n = Zq(n);
|
||||||
|
for k in 0..Q {
|
||||||
|
if k == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let g = Zq(k);
|
||||||
|
// g = F.random_element()
|
||||||
|
if g == Zq(0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let w = g.exp((-Zq(1)) / n);
|
||||||
|
if w.exp(n / Zq(2)) != Zq(1) {
|
||||||
|
// g is the generator
|
||||||
|
return Ok(w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(anyhow!("can not find the primitive root of unity"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rand_distr::Uniform;
|
||||||
|
|
||||||
|
use crate::ring::matrix_vec_product;
|
||||||
|
use crate::ring::PR;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn roots_of_unity() -> Result<()> {
|
||||||
|
const Q: u64 = 12289;
|
||||||
|
const N: usize = 512;
|
||||||
|
let _ntt = NTT::<Q, N>::new()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn vandermonde_ntt() -> Result<()> {
|
||||||
|
const Q: u64 = 41;
|
||||||
|
const N: usize = 4;
|
||||||
|
let primitive = NTT::<Q, N>::get_primitive_root_of_unity((2 * N) as u64)?;
|
||||||
|
let v = NTT::<Q, N>::vandermonde(primitive);
|
||||||
|
|
||||||
|
// naively compute the Vandermonde matrix, and assert that the one from the method matches
|
||||||
|
// the naively obtained one
|
||||||
|
let n2 = (2 * N) as u64;
|
||||||
|
let mut v2: Vec<Vec<Zq<Q>>> = vec![];
|
||||||
|
for i in 0..n2 {
|
||||||
|
let mut row: Vec<Zq<Q>> = vec![];
|
||||||
|
for j in 0..n2 {
|
||||||
|
row.push(primitive.exp(Zq(i * j)));
|
||||||
|
}
|
||||||
|
v2.push(row);
|
||||||
|
}
|
||||||
|
assert_eq!(v, v2);
|
||||||
|
|
||||||
|
let v_inv = NTT::<Q, N>::invert_vandermonde(&v);
|
||||||
|
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let uniform_distr = Uniform::new(0_f64, Q as f64);
|
||||||
|
let a = PR::<Q, N>::rand(&mut rng, uniform_distr)?;
|
||||||
|
// let a = PR::<Q, N>::new_from_u64(vec![36, 21, 9, 19]);
|
||||||
|
|
||||||
|
// let a_padded_coeffs: [Zq<Q>; 2 * N] =
|
||||||
|
// std::array::from_fn(|i| if i < N { a.coeffs[i] } else { Zq::zero() });
|
||||||
|
let mut a_padded = a.coeffs.to_vec();
|
||||||
|
a_padded.append(&mut vec![Zq(0); N]);
|
||||||
|
// let a_ntt = a_padded.mul_by_matrix(&v)?;
|
||||||
|
let a_ntt = matrix_vec_product(&v, &a_padded)?;
|
||||||
|
let a_intt: Vec<Zq<Q>> = matrix_vec_product(&v_inv, &a_ntt)?;
|
||||||
|
assert_eq!(a_intt, a_padded);
|
||||||
|
let a_intt_arr: [Zq<Q>; N] = std::array::from_fn(|i| a_intt[i]);
|
||||||
|
assert_eq!(PR::new(a_intt_arr, None), a);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn vec_by_ntt() -> Result<()> {
|
||||||
|
const Q: u64 = 257;
|
||||||
|
const N: usize = 4;
|
||||||
|
// let primitive = NTT::<Q, N>::get_primitive_root_of_unity((2*N) as u64)?;
|
||||||
|
let ntt = NTT::<Q, N>::new()?;
|
||||||
|
|
||||||
|
let a: Vec<Zq<Q>> = vec![256, 256, 256, 256, 0, 0, 0, 0]
|
||||||
|
.iter()
|
||||||
|
.map(|&e| Zq::new(e))
|
||||||
|
.collect();
|
||||||
|
let a_ntt = matrix_vec_product(&ntt.ntt, &a)?;
|
||||||
|
let a_intt = matrix_vec_product(&ntt.intt, &a_ntt)?;
|
||||||
|
assert_eq!(a_intt, a);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bench_ntt() -> Result<()> {
|
||||||
|
// const Q: u64 = 12289;
|
||||||
|
// const N: usize = 512;
|
||||||
|
const Q: u64 = 257;
|
||||||
|
const N: usize = 4;
|
||||||
|
// let primitive = NTT::<Q, N>::get_primitive_root_of_unity((2*N) as u64)?;
|
||||||
|
let ntt = NTT::<Q, N>::new()?;
|
||||||
|
|
||||||
|
let rng = rand::thread_rng();
|
||||||
|
let a = PR::<Q, { 2 * N }>::rand(rng, Uniform::new(0_f64, (Q - 1) as f64))?;
|
||||||
|
let a = a.coeffs;
|
||||||
|
dbg!(&a);
|
||||||
|
let a_ntt = matrix_vec_product(&ntt.ntt, &a.to_vec())?;
|
||||||
|
dbg!(&a_ntt);
|
||||||
|
let a_intt = matrix_vec_product(&ntt.intt, &a_ntt)?;
|
||||||
|
dbg!(&a_intt);
|
||||||
|
assert_eq!(a_intt, a);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
183
arithmetic/src/ntt.rs
Normal file
183
arithmetic/src/ntt.rs
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more
|
||||||
|
//! details in https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
|
||||||
|
use crate::zq::Zq;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct NTT<const Q: u64, const N: usize> {}
|
||||||
|
|
||||||
|
impl<const Q: u64, const N: usize> NTT<Q, N> {
|
||||||
|
const N_INV: Zq<Q> = Zq(const_inv_mod::<Q>(N as u64));
|
||||||
|
// since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity
|
||||||
|
pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::<Q>(2 * N);
|
||||||
|
pub(crate) const ROOTS_OF_UNITY: [Zq<Q>; N] = roots_of_unity(Self::ROOT_OF_UNITY);
|
||||||
|
const ROOTS_OF_UNITY_INV: [Zq<Q>; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const Q: u64, const N: usize> NTT<Q, N> {
|
||||||
|
/// implements the Cooley-Tukey (CT) algorithm. Details at section 3.1 of
|
||||||
|
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
|
||||||
|
pub fn ntt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
|
||||||
|
let mut t = N / 2;
|
||||||
|
let mut m = 1;
|
||||||
|
let mut r: [Zq<Q>; N] = a.clone();
|
||||||
|
while m < N {
|
||||||
|
let mut k = 0;
|
||||||
|
for i in 0..m {
|
||||||
|
let S: Zq<Q> = Self::ROOTS_OF_UNITY[m + i];
|
||||||
|
for j in k..k + t {
|
||||||
|
let U: Zq<Q> = r[j];
|
||||||
|
let V: Zq<Q> = r[j + t] * S;
|
||||||
|
r[j] = U + V;
|
||||||
|
r[j + t] = U - V;
|
||||||
|
}
|
||||||
|
k = k + 2 * t;
|
||||||
|
}
|
||||||
|
t /= 2;
|
||||||
|
m *= 2;
|
||||||
|
}
|
||||||
|
r
|
||||||
|
}
|
||||||
|
|
||||||
|
/// implements the Gentleman-Sande (GS) algorithm. Details at section 3.2 of
|
||||||
|
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
|
||||||
|
pub fn intt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
|
||||||
|
let mut t = 1;
|
||||||
|
let mut m = N / 2;
|
||||||
|
let mut r: [Zq<Q>; N] = a.clone();
|
||||||
|
while m > 0 {
|
||||||
|
let mut k = 0;
|
||||||
|
for i in 0..m {
|
||||||
|
let S: Zq<Q> = Self::ROOTS_OF_UNITY_INV[m + i];
|
||||||
|
for j in k..k + t {
|
||||||
|
let U: Zq<Q> = r[j];
|
||||||
|
let V: Zq<Q> = r[j + t];
|
||||||
|
r[j] = U + V;
|
||||||
|
r[j + t] = (U - V) * S;
|
||||||
|
}
|
||||||
|
k += 2 * t;
|
||||||
|
}
|
||||||
|
t *= 2;
|
||||||
|
m /= 2;
|
||||||
|
}
|
||||||
|
for i in 0..N {
|
||||||
|
r[i] = r[i] * Self::N_INV;
|
||||||
|
}
|
||||||
|
r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// computes a primitive N-th root of unity using the method described by Thomas
|
||||||
|
/// Pornin in https://crypto.stackexchange.com/a/63616
|
||||||
|
const fn primitive_root_of_unity<const Q: u64>(N: usize) -> u64 {
|
||||||
|
assert!(N.is_power_of_two());
|
||||||
|
assert!((Q - 1) % N as u64 == 0);
|
||||||
|
|
||||||
|
let n: u64 = N as u64;
|
||||||
|
let mut k = 1;
|
||||||
|
while k < Q {
|
||||||
|
// alternatively could get a random k at each iteration, if so, add the following if:
|
||||||
|
// `if k == 0 { continue; }`
|
||||||
|
let w = const_exp_mod::<Q>(k, (Q - 1) / n);
|
||||||
|
if const_exp_mod::<Q>(w, n / 2) != 1 {
|
||||||
|
return w; // w is a primitive N-th root of unity
|
||||||
|
}
|
||||||
|
k += 1;
|
||||||
|
}
|
||||||
|
panic!("No primitive root of unity");
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn roots_of_unity<const Q: u64, const N: usize>(w: u64) -> [Zq<Q>; N] {
|
||||||
|
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
|
||||||
|
let mut i = 0;
|
||||||
|
let log_n = N.ilog2();
|
||||||
|
while i < N {
|
||||||
|
// (return the roots in bit-reverset order)
|
||||||
|
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
|
||||||
|
r[i] = Zq(const_exp_mod::<Q>(w, j as u64));
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
r
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn roots_of_unity_inv<const Q: u64, const N: usize>(v: [Zq<Q>; N]) -> [Zq<Q>; N] {
|
||||||
|
// assumes that the inputted roots are already in bit-reverset order
|
||||||
|
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
|
||||||
|
let mut i = 0;
|
||||||
|
while i < N {
|
||||||
|
r[i] = Zq(const_inv_mod::<Q>(v[i].0));
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
r
|
||||||
|
}
|
||||||
|
|
||||||
|
/// returns x^k mod Q
|
||||||
|
const fn const_exp_mod<const Q: u64>(x: u64, k: u64) -> u64 {
|
||||||
|
let mut r = 1u64;
|
||||||
|
let mut x = x;
|
||||||
|
let mut k = k;
|
||||||
|
x = x % Q;
|
||||||
|
// exponentiation by square strategy
|
||||||
|
while k > 0 {
|
||||||
|
if k % 2 == 1 {
|
||||||
|
r = (r * x) % Q;
|
||||||
|
}
|
||||||
|
x = (x * x) % Q;
|
||||||
|
k /= 2;
|
||||||
|
}
|
||||||
|
r
|
||||||
|
}
|
||||||
|
|
||||||
|
/// returns x^-1 mod Q
|
||||||
|
const fn const_inv_mod<const Q: u64>(x: u64) -> u64 {
|
||||||
|
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
|
||||||
|
const_exp_mod::<Q>(x, Q - 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::array;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ntt() -> Result<()> {
|
||||||
|
const Q: u64 = 2u64.pow(16) + 1;
|
||||||
|
const N: usize = 4;
|
||||||
|
|
||||||
|
let a: [u64; N] = [1u64, 2, 3, 4];
|
||||||
|
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::new(a[i]));
|
||||||
|
|
||||||
|
let a_ntt = NTT::<Q, N>::ntt(a);
|
||||||
|
|
||||||
|
let a_intt = NTT::<Q, N>::intt(a_ntt);
|
||||||
|
|
||||||
|
dbg!(&a);
|
||||||
|
dbg!(&a_ntt);
|
||||||
|
dbg!(&a_intt);
|
||||||
|
dbg!(NTT::<Q, N>::ROOT_OF_UNITY);
|
||||||
|
dbg!(NTT::<Q, N>::ROOTS_OF_UNITY);
|
||||||
|
|
||||||
|
assert_eq!(a, a_intt);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ntt_loop() -> Result<()> {
|
||||||
|
const Q: u64 = 2u64.pow(16) + 1;
|
||||||
|
const N: usize = 512;
|
||||||
|
|
||||||
|
use rand::distributions::Distribution;
|
||||||
|
use rand::distributions::Uniform;
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let dist = Uniform::new(0_f64, Q as f64);
|
||||||
|
|
||||||
|
for _ in 0..100 {
|
||||||
|
let a: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
|
||||||
|
let a_ntt = NTT::<Q, N>::ntt(a);
|
||||||
|
let a_intt = NTT::<Q, N>::intt(a_ntt);
|
||||||
|
assert_eq!(a, a_intt);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ use std::array;
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops;
|
use std::ops;
|
||||||
|
|
||||||
|
use crate::ntt::NTT;
|
||||||
use crate::zq::Zq;
|
use crate::zq::Zq;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
|
||||||
@@ -78,6 +79,35 @@ impl<const Q: u64, const N: usize> PR<Q, N> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO review if needed, or if with this interface
|
||||||
|
pub fn mul_by_matrix(&self, m: &Vec<Vec<Zq<Q>>>) -> Result<Vec<Zq<Q>>> {
|
||||||
|
matrix_vec_product(m, &self.coeffs.to_vec())
|
||||||
|
}
|
||||||
|
pub fn mul_by_zq(&self, s: &Zq<Q>) -> Self {
|
||||||
|
Self {
|
||||||
|
coeffs: array::from_fn(|i| self.coeffs[i] * *s),
|
||||||
|
evals: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn mul_by_u64(&self, s: u64) -> Self {
|
||||||
|
let s = Zq::new(s);
|
||||||
|
Self {
|
||||||
|
coeffs: array::from_fn(|i| self.coeffs[i] * s),
|
||||||
|
// coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
|
||||||
|
evals: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn mul_by_f64(&self, s: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)),
|
||||||
|
evals: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mul(&mut self, rhs: &mut Self) -> Self {
|
||||||
|
mul_mut(self, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
// TODO simplify
|
// TODO simplify
|
||||||
let mut str = "";
|
let mut str = "";
|
||||||
@@ -207,6 +237,51 @@ impl<const Q: u64, const N: usize> ops::Sub<&PR<Q, N>> for &PR<Q, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl<const Q: u64, const N: usize> ops::Mul<PR<Q, N>> for PR<Q, N> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn mul(self, rhs: Self) -> Self {
|
||||||
|
mul(&self, &rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<const Q: u64, const N: usize> ops::Mul<&PR<Q, N>> for &PR<Q, N> {
|
||||||
|
type Output = PR<Q, N>;
|
||||||
|
|
||||||
|
fn mul(self, rhs: &PR<Q, N>) -> Self::Output {
|
||||||
|
mul(self, rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mul by Zq element
|
||||||
|
impl<const Q: u64, const N: usize> ops::Mul<Zq<Q>> for PR<Q, N> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn mul(self, s: Zq<Q>) -> Self {
|
||||||
|
self.mul_by_zq(&s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<const Q: u64, const N: usize> ops::Mul<&Zq<Q>> for &PR<Q, N> {
|
||||||
|
type Output = PR<Q, N>;
|
||||||
|
|
||||||
|
fn mul(self, s: &Zq<Q>) -> Self::Output {
|
||||||
|
self.mul_by_zq(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// mul by u64
|
||||||
|
impl<const Q: u64, const N: usize> ops::Mul<u64> for PR<Q, N> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn mul(self, s: u64) -> Self {
|
||||||
|
self.mul_by_u64(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<const Q: u64, const N: usize> ops::Mul<&u64> for &PR<Q, N> {
|
||||||
|
type Output = PR<Q, N>;
|
||||||
|
|
||||||
|
fn mul(self, s: &u64) -> Self::Output {
|
||||||
|
self.mul_by_u64(*s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<const Q: u64, const N: usize> ops::Neg for PR<Q, N> {
|
impl<const Q: u64, const N: usize> ops::Neg for PR<Q, N> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
@@ -219,6 +294,39 @@ impl<const Q: u64, const N: usize> ops::Neg for PR<Q, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mul_mut<const Q: u64, const N: usize>(lhs: &mut PR<Q, N>, rhs: &mut PR<Q, N>) -> PR<Q, N> {
|
||||||
|
// reuse evaluations if already computed
|
||||||
|
if !lhs.evals.is_some() {
|
||||||
|
lhs.evals = Some(NTT::<Q, N>::ntt(lhs.coeffs));
|
||||||
|
};
|
||||||
|
if !rhs.evals.is_some() {
|
||||||
|
rhs.evals = Some(NTT::<Q, N>::ntt(rhs.coeffs));
|
||||||
|
};
|
||||||
|
let lhs_evals = lhs.evals.unwrap();
|
||||||
|
let rhs_evals = rhs.evals.unwrap();
|
||||||
|
|
||||||
|
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
|
||||||
|
let c = NTT::<Q, { N }>::intt(c_ntt);
|
||||||
|
PR::new(c, Some(c_ntt))
|
||||||
|
}
|
||||||
|
fn mul<const Q: u64, const N: usize>(lhs: &PR<Q, N>, rhs: &PR<Q, N>) -> PR<Q, N> {
|
||||||
|
// reuse evaluations if already computed
|
||||||
|
let lhs_evals = if lhs.evals.is_some() {
|
||||||
|
lhs.evals.unwrap()
|
||||||
|
} else {
|
||||||
|
NTT::<Q, N>::ntt(lhs.coeffs)
|
||||||
|
};
|
||||||
|
let rhs_evals = if rhs.evals.is_some() {
|
||||||
|
rhs.evals.unwrap()
|
||||||
|
} else {
|
||||||
|
NTT::<Q, N>::ntt(rhs.coeffs)
|
||||||
|
};
|
||||||
|
|
||||||
|
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
|
||||||
|
let c = NTT::<Q, { N }>::intt(c_ntt);
|
||||||
|
PR::new(c, Some(c_ntt))
|
||||||
|
}
|
||||||
|
|
||||||
impl<const Q: u64, const N: usize> fmt::Display for PR<Q, N> {
|
impl<const Q: u64, const N: usize> fmt::Display for PR<Q, N> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
self.fmt(f)?;
|
self.fmt(f)?;
|
||||||
@@ -277,4 +385,40 @@ mod tests {
|
|||||||
"x^2 + x + 1 mod Z_7/(X^3+1)"
|
"x^2 + x + 1 mod Z_7/(X^3+1)"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_mul_opt<const Q: u64, const N: usize>(
|
||||||
|
a: [u64; N],
|
||||||
|
b: [u64; N],
|
||||||
|
expected_c: [u64; N],
|
||||||
|
) -> Result<()> {
|
||||||
|
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::new(a[i]));
|
||||||
|
let mut a = PR::new(a, None);
|
||||||
|
let b: [Zq<Q>; N] = array::from_fn(|i| Zq::new(b[i]));
|
||||||
|
let mut b = PR::new(b, None);
|
||||||
|
let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::new(expected_c[i]));
|
||||||
|
let expected_c = PR::new(expected_c, None);
|
||||||
|
|
||||||
|
let c = mul_mut(&mut a, &mut b);
|
||||||
|
assert_eq!(c, expected_c);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_mul() -> Result<()> {
|
||||||
|
const Q: u64 = 2u64.pow(16) + 1;
|
||||||
|
const N: usize = 4;
|
||||||
|
|
||||||
|
let a: [u64; N] = [1u64, 2, 3, 4];
|
||||||
|
let b: [u64; N] = [1u64, 2, 3, 4];
|
||||||
|
let c: [u64; N] = [65513, 65517, 65531, 20];
|
||||||
|
test_mul_opt::<Q, N>(a, b, c)?;
|
||||||
|
|
||||||
|
let a: [u64; N] = [0u64, 0, 0, 2];
|
||||||
|
let b: [u64; N] = [0u64, 0, 0, 2];
|
||||||
|
let c: [u64; N] = [0u64, 0, 65533, 0];
|
||||||
|
test_mul_opt::<Q, N>(a, b, c)?;
|
||||||
|
|
||||||
|
// TODO more testvectors
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user