use itertools::{izip, Itertools};
use rand::{thread_rng, Rng, RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;

use crate::{
    backend::{ArithmeticOps, ModInit, ModularOpsU64, Modulus},
    utils::{mod_exponent, mod_inverse, shoup_representation_fq},
};

pub trait NttInit<M> {
    /// Ntt instances constructed with the same `q` and `n` must be compatible
    /// with each other
    fn new(q: &M, n: usize) -> Self;
}

pub trait Ntt {
    type Element;
    fn forward_lazy(&self, v: &mut [Self::Element]);
    fn forward(&self, v: &mut [Self::Element]);
    fn backward_lazy(&self, v: &mut [Self::Element]);
    fn backward(&self, v: &mut [Self::Element]);
}

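// A rough usage sketch (hypothetical `q`, `n`, `a`; the backend is the
// `NttBackendU64` defined below):
//
//     let ntt = NttBackendU64::_new(q, n);
//     ntt.forward(&mut a);  // a <- NTT(a), fully reduced into [0, q)
//     ntt.backward(&mut a); // round-trips back to the original coefficients
//
// The `_lazy` variants may leave outputs only partially reduced (below 2q),
// which callers are expected to handle.
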
/// Forward butterfly routine for the number theoretic transform. Given inputs
/// `x < 4q` and `y < 4q`, returns `(x', y')` where
/// x' = x + wy
/// y' = x - wy
/// and both x' and y' are \in [0, 4q)
///
/// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
pub fn forward_butterly_0_to_4q(
    mut x: u64,
    y: u64,
    w: u64,
    w_shoup: u64,
    q: u64,
    q_twice: u64,
) -> (u64, u64) {
    debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
    debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);

    if x >= q_twice {
        x = x - q_twice;
    }

    // TODO (Jay): Hot path expected. How expensive is it?
    let k = ((w_shoup as u128 * y as u128) >> 64) as u64;
    let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q));

    (x + t, x + q_twice - t)
}

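/// Same as [`forward_butterly_0_to_4q`], except that the outputs are reduced
/// into [0, 2q) before being returned. Used in the final NTT stage, where the
/// caller expects values below 2q.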
pub fn forward_butterly_0_to_2q(
    mut x: u64,
    y: u64,
    w: u64,
    w_shoup: u64,
    q: u64,
    q_twice: u64,
) -> (u64, u64) {
    debug_assert!(x < q * 4, "{} >= (4q){}", x, 4 * q);
    debug_assert!(y < q * 4, "{} >= (4q){}", y, 4 * q);

    if x >= q_twice {
        x = x - q_twice;
    }

    let k = ((w_shoup as u128 * y as u128) >> 64) as u64;
    let t = w.wrapping_mul(y).wrapping_sub(k.wrapping_mul(q));

    let ox = x.wrapping_add(t);
    let oy = x.wrapping_sub(t);

    (
        ox.min(ox.wrapping_sub(q_twice)),
        oy.min(oy.wrapping_add(q_twice)),
    )
}

/// Inverse butterfly routine of the inverse number theoretic transform. Given
/// inputs `x < 2q` and `y < 2q`, mutates x and y in place to equal x' and y' where
/// x' = x + y
/// y' = w(x - y)
/// and both x' and y' are \in [0, 2q)
///
/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf)
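///
/// # Safety
/// `x` and `y` must be valid pointers to initialized `u64` values; the inverse
/// NTT below passes pointers to two distinct elements of the same slice.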
pub unsafe fn inverse_butterfly(
    x: *mut u64,
    y: *mut u64,
    w_inv: &u64,
    w_inv_shoup: &u64,
    q: &u64,
    q_twice: &u64,
) {
    debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x);
    debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y);

    let mut x_dash = *x + *y;
    if x_dash >= *q_twice {
        x_dash -= q_twice;
    }

    let t = *x + q_twice - *y;
    let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path
    *y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q));

    *x = x_dash;
}

/// Number theoretic transform of vector `a` where each element can be in range
/// [0, 2q). Outputs NTT(a) where each element is in range [0, 2q)
///
/// Implements the Cooley-Tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf.
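///
/// Runs log2(n) stages: in each stage, `m` blocks of length `2t` are processed
/// (with `m` doubling and `t` halving per stage), using the twiddle factors
/// `psi[m..2m]` (the powers of psi are expected in bit-reversed order).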
pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
    assert!(a.len() == psi.len());

    let n = a.len();
    let mut t = n;

    let mut m = 1;
    while m < n {
        t >>= 1;
        let w = &psi[m..];
        let w_shoup = &psi_shoup[m..];

        if t == 1 {
            for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
                let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
                a[0] = ox;
                a[1] = oy;
            }
        } else {
            for i in 0..m {
                let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
                let (left, right) = a.split_at_mut(t);

                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
                    let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
                    *x = ox;
                    *y = oy;
                }
            }
        }

        m <<= 1;
    }
}

/// Same as `ntt_lazy` with output in range [0, q)
pub fn ntt(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) {
    assert!(a.len() == psi.len());

    let n = a.len();
    let mut t = n;

    let mut m = 1;
    while m < n {
        t >>= 1;
        let w = &psi[m..];
        let w_shoup = &psi_shoup[m..];

        if t == 1 {
            for (a, w, w_shoup) in izip!(a.chunks_mut(2), w.iter(), w_shoup.iter()) {
                let (ox, oy) = forward_butterly_0_to_2q(a[0], a[1], *w, *w_shoup, q, q_twice);
                // `ox` and `oy` are already in [0, 2q); subtract q (not 2q) to
                // bring them into [0, q) as documented.
                a[0] = ox.min(ox.wrapping_sub(q));
                a[1] = oy.min(oy.wrapping_sub(q));
            }
        } else {
            for i in 0..m {
                let a = &mut a[2 * i * t..(2 * (i + 1) * t)];
                let (left, right) = a.split_at_mut(t);

                for (x, y) in izip!(left.iter_mut(), right.iter_mut()) {
                    let (ox, oy) = forward_butterly_0_to_4q(*x, *y, w[i], w_shoup[i], q, q_twice);
                    *x = ox;
                    *y = oy;
                }
            }
        }

        m <<= 1;
    }
}

/// Inverse number theoretic transform of input vector `a` where each element
/// can be in range [0, 2q). Outputs vector INTT(a) with each element in range
/// [0, 2q)
///
/// Implements the backward number theoretic transform using the Gentleman-Sande (GS) algorithm as given in Algorithm 2 of https://eprint.iacr.org/2016/504.pdf
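///
/// The final pass scales every coefficient by `n_inv = n^{-1} \mod{q}`,
/// reducing modulo q in the process.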
pub fn ntt_inv_lazy(
    a: &mut [u64],
    psi_inv: &[u64],
    psi_inv_shoup: &[u64],
    n_inv: u64,
    q: u64,
    q_twice: u64,
) {
    debug_assert!(a.len() == psi_inv.len());

    let mut m = a.len();
    let mut t = 1;
    while m > 1 {
        let mut j_1: usize = 0;
        let h = m >> 1;
        for i in 0..h {
            let j_2 = j_1 + t;
            unsafe {
                let w_inv = psi_inv.get_unchecked(h + i);
                let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i);

                for j in j_1..j_2 {
                    let x = a.get_unchecked_mut(j) as *mut u64;
                    let y = a.get_unchecked_mut(j + t) as *mut u64;
                    inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice);
                }
            }
            j_1 = j_1 + 2 * t;
        }
        t *= 2;
        m >>= 1;
    }

    a.iter_mut()
        .for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64);
}

/// Finds a primitive n^{th} root of unity in field F_q, if one exists
///
/// Note: a primitive n^{th} root of unity exists if and only if $q \equiv 1 \pmod{n}$
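///
/// Example: for q = 17 and n = 8, t = (q - 1)/n = 2; sampling \omega = 3 gives
/// 3^2 = 9, and since 9^{4} = 16 != 1 (mod 17), 9 is a primitive 8^{th} root of
/// unity.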
pub(crate) fn find_primitive_root<R: RngCore>(q: u64, n: u64, rng: &mut R) -> Option<u64> {
    assert!(n.is_power_of_two(), "{n} is not power of two");

    // n^th root of unity only exists if n|(q-1)
    assert!(q % n == 1, "{n}^th root of unity in F_{q} does not exist");

    let t = (q - 1) / n;

    for _ in 0..100 {
        let mut omega = rng.gen::<u64>() % q;

        // \omega = \omega^t. \omega is now an n^th root of unity
        omega = mod_exponent(omega, t, q);

        // We restrict n to be a power of 2. Thus checking whether \omega is a
        // primitive n^th root of unity is as simple as checking: \omega^{n/2} != 1
        if mod_exponent(omega, n >> 1, q) == 1 {
            continue;
        } else {
            return Some(omega);
        }
    }

    None
}

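/// Negacyclic NTT backend over Z_q for a power-of-two ring dimension `n`.
/// Stores the powers of \psi (a primitive 2n^{th} root of unity) and their
/// inverses in bit-reversed order, together with their Shoup representations.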
#[derive(Debug)]
pub struct NttBackendU64 {
    q: u64,
    q_twice: u64,
    n: u64,
    n_inv: u64,
    psi_powers_bo: Box<[u64]>,
    psi_inv_powers_bo: Box<[u64]>,
    psi_powers_bo_shoup: Box<[u64]>,
    psi_inv_powers_bo_shoup: Box<[u64]>,
}

impl NttBackendU64 {
    fn _new(q: u64, n: usize) -> Self {
        // \psi = 2n^{th} primitive root of unity in F_q
        let mut rng = ChaCha8Rng::from_seed([0u8; 32]);
        let psi = find_primitive_root(q, (n * 2) as u64, &mut rng)
            .expect("Unable to find 2n^th root of unity");
        let psi_inv = mod_inverse(psi, q);

        // assert!(
        //     ((psi_inv as u128 * psi as u128) % q as u128) == 1,
        //     "psi:{psi}, psi_inv:{psi_inv}"
        // );

        let modulus = ModularOpsU64::new(q);

        let mut psi_powers = Vec::with_capacity(n);
        let mut psi_inv_powers = Vec::with_capacity(n);
        let mut running_psi = 1;
        let mut running_psi_inv = 1;
        for _ in 0..n {
            psi_powers.push(running_psi);
            psi_inv_powers.push(running_psi_inv);

            running_psi = modulus.mul(&running_psi, &psi);
            running_psi_inv = modulus.mul(&running_psi_inv, &psi_inv);
        }

        // powers stored in bit reversed order
        let mut psi_powers_bo = vec![0u64; n];
        let mut psi_inv_powers_bo = vec![0u64; n];
        let shift_by = n.leading_zeros() + 1;
        for i in 0..n {
            // i in bit reversed order
            let bo_index = i.reverse_bits() >> shift_by;

            psi_powers_bo[bo_index] = psi_powers[i];
            psi_inv_powers_bo[bo_index] = psi_inv_powers[i];
        }

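        // Shoup representation: presumably w_shoup = floor(w * 2^64 / q), so
        // that w * y mod q can later be computed with two multiplications and
        // no division (see the butterfly routines above: k = (w_shoup * y) >> 64,
        // then w*y - k*q lies in [0, 2q)).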
        // shoup representation
        let psi_powers_bo_shoup = psi_powers_bo
            .iter()
            .map(|v| shoup_representation_fq(*v, q))
            .collect_vec();
        let psi_inv_powers_bo_shoup = psi_inv_powers_bo
            .iter()
            .map(|v| shoup_representation_fq(*v, q))
            .collect_vec();

        // n^{-1} \mod{q}
        let n_inv = mod_inverse(n as u64, q);

        NttBackendU64 {
            q,
            q_twice: 2 * q,
            n: n as u64,
            n_inv,
            psi_powers_bo: psi_powers_bo.into_boxed_slice(),
            psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(),
            psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(),
            psi_inv_powers_bo_shoup: psi_inv_powers_bo_shoup.into_boxed_slice(),
        }
    }
}

impl<M: Modulus<Element = u64>> NttInit<M> for NttBackendU64 {
    fn new(q: &M, n: usize) -> Self {
        // This NTT does not support native modulus
        assert!(!q.is_native());
        NttBackendU64::_new(q.q().unwrap(), n)
    }
}

impl NttBackendU64 {
    fn reduce_from_lazy(&self, a: &mut [u64]) {
        let q = self.q;
        a.iter_mut().for_each(|a0| {
            if *a0 >= q {
                *a0 = *a0 - q;
            }
        });
    }
}

impl Ntt for NttBackendU64 {
    type Element = u64;

    fn forward_lazy(&self, v: &mut [Self::Element]) {
        ntt_lazy(
            v,
            &self.psi_powers_bo,
            &self.psi_powers_bo_shoup,
            self.q,
            self.q_twice,
        )
    }

    fn forward(&self, v: &mut [Self::Element]) {
        ntt(
            v,
            &self.psi_powers_bo,
            &self.psi_powers_bo_shoup,
            self.q,
            self.q_twice,
        );
    }

    fn backward_lazy(&self, v: &mut [Self::Element]) {
        ntt_inv_lazy(
            v,
            &self.psi_inv_powers_bo,
            &self.psi_inv_powers_bo_shoup,
            self.n_inv,
            self.q,
            self.q_twice,
        )
    }

    fn backward(&self, v: &mut [Self::Element]) {
        ntt_inv_lazy(
            v,
            &self.psi_inv_powers_bo,
            &self.psi_inv_powers_bo_shoup,
            self.n_inv,
            self.q,
            self.q_twice,
        );
        self.reduce_from_lazy(v);
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{thread_rng, Rng};
    use rand_distr::Uniform;

    use super::{NttBackendU64, NttInit};
    use crate::{
        backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps},
        ntt::Ntt,
        utils::{generate_prime, negacyclic_mul},
    };

    const Q_60_BITS: u64 = 1152921504606748673;
    const N: usize = 1 << 4;

    const K: usize = 128;

    fn random_vec_in_fq(size: usize, q: u64) -> Vec<u64> {
        thread_rng()
            .sample_iter(Uniform::new(0, q))
            .take(size)
            .collect_vec()
    }

    #[test]
    fn native_ntt_backend_works() {
        // TODO(Jay): Improve tests. Add tests for different primes and ring size.
        let ntt_backend = NttBackendU64::_new(Q_60_BITS, N);
        for _ in 0..K {
            let mut a = random_vec_in_fq(N, Q_60_BITS);
            let a_clone = a.clone();

            ntt_backend.forward(&mut a);
            assert_ne!(a, a_clone);
            ntt_backend.backward(&mut a);
            assert_eq!(a, a_clone);

            ntt_backend.forward_lazy(&mut a);
            assert_ne!(a, a_clone);
            ntt_backend.backward(&mut a);
            assert_eq!(a, a_clone);

            ntt_backend.forward(&mut a);
            ntt_backend.backward_lazy(&mut a);
            // reduce from the lazy range into [0, q)
            a.iter_mut().for_each(|a0| {
                if *a0 >= Q_60_BITS {
                    *a0 -= Q_60_BITS;
                }
            });
            assert_eq!(a, a_clone);
        }
    }

    #[test]
    fn native_ntt_negacylic_mul() {
        let primes = [25, 40, 50, 60]
            .iter()
            .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap())
            .collect_vec();

        for p in primes.into_iter() {
            let ntt_backend = NttBackendU64::_new(p, N);
            let modulus_backend = ModularOpsU64::new(p);
            for _ in 0..K {
                let a = random_vec_in_fq(N, p);
                let b = random_vec_in_fq(N, p);

                let mut a_clone = a.clone();
                let mut b_clone = b.clone();
                ntt_backend.forward_lazy(&mut a_clone);
                ntt_backend.forward_lazy(&mut b_clone);
                modulus_backend.elwise_mul_mut(&mut a_clone, &b_clone);
                ntt_backend.backward(&mut a_clone);

                let mul = |a: &u64, b: &u64| {
                    let tmp = *a as u128 * *b as u128;
                    (tmp % p as u128) as u64
                };
                let expected_out = negacyclic_mul(&a, &b, mul, p);

                assert_eq!(a_clone, expected_out);
            }
        }
    }
}