mv arithmetic arith

This commit is contained in:
2025-06-22 19:15:14 +02:00
parent 7740a3ef3e
commit 19457c98dd
12 changed files with 30 additions and 32 deletions

3
arith/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/target
Cargo.lock
*.sage.py

9
arith/Cargo.toml Normal file
View File

@@ -0,0 +1,9 @@
[package]
name = "arith"
version = "0.1.0"
edition = "2024"
[dependencies]
anyhow = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }

2
arith/README.md Normal file
View File

@@ -0,0 +1,2 @@
# arithmetic
Contains $\mathbb{Z}_q$ and $\mathbb{Z}_q[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation.

16
arith/src/lib.rs Normal file
View File

@@ -0,0 +1,16 @@
#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(clippy::upper_case_acronyms)]
#![allow(dead_code)] // TMP
mod naive_ntt; // TODO rm
pub mod ntt;
pub mod ring;
pub mod ringq;
pub mod zq;
pub use ntt::NTT;
pub use ring::R;
pub use ringq::Rq;
pub use zq::Zq;

196
arith/src/naive_ntt.rs Normal file
View File

@@ -0,0 +1,196 @@
//! this file implements the non-efficient NTT, which uses multiplication by the
//! Vandermonde matrix.
use crate::zq::Zq;
use anyhow::{Result, anyhow};
#[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {
pub primitive: Zq<Q>,
// nth_roots: Vec<Zq<Q>>,
pub ntt: Vec<Vec<Zq<Q>>>,
pub intt: Vec<Vec<Zq<Q>>>,
}
impl<const Q: u64, const N: usize> NTT<Q, N> {
pub fn new() -> Result<Self> {
// TODO change n to be u64 and ensure that is n<Q
// note: `n` here is not the `N` from `(X^N+1)`
// TODO: in fact n will be N (trait/struct param)
// let primitive = Self::get_primitive_root_of_unity((2 * N) as u64)?;
let primitive = Self::get_primitive_root_of_unity((2 * N) as u64)?;
// let mut nth_roots = vec![Zq(0); N];
// let mut w_i = Zq(1);
// for i in 0..N {
// w_i = w_i * primitive;
// nth_roots[i] = w_i;
// }
let ntt: Vec<Vec<Zq<Q>>> = Self::vandermonde(primitive);
let intt = Self::invert_vandermonde(&ntt);
Ok(Self {
primitive,
// nth_roots,
ntt,
intt,
})
}
pub fn vandermonde(primitive: Zq<Q>) -> Vec<Vec<Zq<Q>>> {
let mut v: Vec<Vec<Zq<Q>>> = vec![];
let n = (2 * N) as u64;
// let n = N as u64;
for i in 0..n {
let mut row: Vec<Zq<Q>> = vec![];
let primitive_i = primitive.exp(Zq(i));
let mut primitive_ij = Zq(1);
for _ in 0..n {
row.push(primitive_ij);
primitive_ij = primitive_ij * primitive_i;
}
v.push(row);
}
v
}
// specifically for the Vandermonde matrix
pub fn invert_vandermonde(v: &Vec<Vec<Zq<Q>>>) -> Vec<Vec<Zq<Q>>> {
let n = 2 * N;
// let n = N;
let mut inv: Vec<Vec<Zq<Q>>> = vec![];
for i in 0..n {
let w_i = v[i][1]; // = w_i^1=w^i^1 = w^i
let w_i_inv = w_i.inv();
let mut row: Vec<Zq<Q>> = vec![];
for j in 0..n {
row.push(w_i_inv.exp(Zq(j as u64)) / Zq(n as u64));
}
inv.push(row);
}
inv
}
pub fn get_primitive_root_of_unity(n: u64) -> Result<Zq<Q>> {
// using the method described by Thomas Pornin in
// https://crypto.stackexchange.com/a/63616
// assert!((Q - 1) % N as u64 == 0);
assert!((Q - 1) % n == 0);
// TODO maybe not using Zq and using u64 directly
let n = Zq(n);
for k in 0..Q {
if k == 0 {
continue;
}
let g = Zq(k);
// g = F.random_element()
if g == Zq(0) {
continue;
}
let w = g.exp((-Zq(1)) / n);
if w.exp(n / Zq(2)) != Zq(1) {
// g is the generator
return Ok(w);
}
}
Err(anyhow!("can not find the primitive root of unity"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand_distr::Uniform;
use crate::ring::Rq;
use crate::ring::matrix_vec_product;
#[test]
fn roots_of_unity() -> Result<()> {
const Q: u64 = 12289;
const N: usize = 512;
let _ntt = NTT::<Q, N>::new()?;
Ok(())
}
#[test]
fn vandermonde_ntt() -> Result<()> {
const Q: u64 = 41;
const N: usize = 4;
let primitive = NTT::<Q, N>::get_primitive_root_of_unity((2 * N) as u64)?;
let v = NTT::<Q, N>::vandermonde(primitive);
// naively compute the Vandermonde matrix, and assert that the one from the method matches
// the naively obtained one
let n2 = (2 * N) as u64;
let mut v2: Vec<Vec<Zq<Q>>> = vec![];
for i in 0..n2 {
let mut row: Vec<Zq<Q>> = vec![];
for j in 0..n2 {
row.push(primitive.exp(Zq(i * j)));
}
v2.push(row);
}
assert_eq!(v, v2);
let v_inv = NTT::<Q, N>::invert_vandermonde(&v);
let mut rng = rand::thread_rng();
let uniform_distr = Uniform::new(0_f64, Q as f64);
let a = Rq::<Q, N>::rand_f64(&mut rng, uniform_distr)?;
// let a = PR::<Q, N>::new_from_u64(vec![36, 21, 9, 19]);
// let a_padded_coeffs: [Zq<Q>; 2 * N] =
// std::array::from_fn(|i| if i < N { a.coeffs[i] } else { Zq::zero() });
let mut a_padded = a.coeffs.to_vec();
a_padded.append(&mut vec![Zq(0); N]);
// let a_ntt = a_padded.mul_by_matrix(&v)?;
let a_ntt = matrix_vec_product(&v, &a_padded)?;
let a_intt: Vec<Zq<Q>> = matrix_vec_product(&v_inv, &a_ntt)?;
assert_eq!(a_intt, a_padded);
let a_intt_arr: [Zq<Q>; N] = std::array::from_fn(|i| a_intt[i]);
assert_eq!(Rq::new(a_intt_arr, None), a);
Ok(())
}
#[test]
fn vec_by_ntt() -> Result<()> {
const Q: u64 = 257;
const N: usize = 4;
// let primitive = NTT::<Q, N>::get_primitive_root_of_unity((2*N) as u64)?;
let ntt = NTT::<Q, N>::new()?;
let a: Vec<Zq<Q>> = vec![256, 256, 256, 256, 0, 0, 0, 0]
.iter()
.map(|&e| Zq::from_u64(e))
.collect();
let a_ntt = matrix_vec_product(&ntt.ntt, &a)?;
let a_intt = matrix_vec_product(&ntt.intt, &a_ntt)?;
assert_eq!(a_intt, a);
Ok(())
}
#[test]
fn bench_ntt() -> Result<()> {
// const Q: u64 = 12289;
// const N: usize = 512;
const Q: u64 = 257;
const N: usize = 4;
// let primitive = NTT::<Q, N>::get_primitive_root_of_unity((2*N) as u64)?;
let ntt = NTT::<Q, N>::new()?;
let rng = rand::thread_rng();
let a = Rq::<Q, { 2 * N }>::rand_f64(rng, Uniform::new(0_f64, (Q - 1) as f64))?;
let a = a.coeffs;
dbg!(&a);
let a_ntt = matrix_vec_product(&ntt.ntt, &a.to_vec())?;
dbg!(&a_ntt);
let a_intt = matrix_vec_product(&ntt.intt, &a_ntt)?;
dbg!(&a_intt);
assert_eq!(a_intt, a);
// TODO bench
Ok(())
}
}

187
arith/src/ntt.rs Normal file
View File

@@ -0,0 +1,187 @@
//! Implementation of the NTT & iNTT, following the CT & GS algorighms, more details in
//! https://eprint.iacr.org/2017/727.pdf, some notes at
//! https://github.com/arnaucube/math/blob/master/notes_ntt.pdf .
use crate::zq::Zq;
#[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {}
impl<const Q: u64, const N: usize> NTT<Q, N> {
const N_INV: Zq<Q> = Zq(const_inv_mod::<Q>(N as u64));
// since we work over Zq[X]/(X^N+1) (negacyclic), get the 2*N-th root of unity
pub(crate) const ROOT_OF_UNITY: u64 = primitive_root_of_unity::<Q>(2 * N);
pub(crate) const ROOTS_OF_UNITY: [Zq<Q>; N] = roots_of_unity(Self::ROOT_OF_UNITY);
const ROOTS_OF_UNITY_INV: [Zq<Q>; N] = roots_of_unity_inv(Self::ROOTS_OF_UNITY);
}
impl<const Q: u64, const N: usize> NTT<Q, N> {
/// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn ntt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = N / 2;
let mut m = 1;
let mut r: [Zq<Q>; N] = a.clone();
while m < N {
let mut k = 0;
for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY[m + i];
for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t] * S;
r[j] = U + V;
r[j + t] = U - V;
}
k = k + 2 * t;
}
t /= 2;
m *= 2;
}
r
}
/// implements the Cooley-Tukey (CT) algorithm. Details at
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn intt(a: [Zq<Q>; N]) -> [Zq<Q>; N] {
let mut t = 1;
let mut m = N / 2;
let mut r: [Zq<Q>; N] = a.clone();
while m > 0 {
let mut k = 0;
for i in 0..m {
let S: Zq<Q> = Self::ROOTS_OF_UNITY_INV[m + i];
for j in k..k + t {
let U: Zq<Q> = r[j];
let V: Zq<Q> = r[j + t];
r[j] = U + V;
r[j + t] = (U - V) * S;
}
k += 2 * t;
}
t *= 2;
m /= 2;
}
for i in 0..N {
r[i] = r[i] * Self::N_INV;
}
r
}
}
/// computes a primitive N-th root of unity using the method described by Thomas
/// Pornin in https://crypto.stackexchange.com/a/63616
const fn primitive_root_of_unity<const Q: u64>(N: usize) -> u64 {
assert!(N.is_power_of_two());
assert!((Q - 1) % N as u64 == 0);
let n: u64 = N as u64;
let mut k = 1;
while k < Q {
// alternatively could get a random k at each iteration, if so, add the following if:
// `if k == 0 { continue; }`
let w = const_exp_mod::<Q>(k, (Q - 1) / n);
if const_exp_mod::<Q>(w, n / 2) != 1 {
return w; // w is a primitive N-th root of unity
}
k += 1;
}
panic!("No primitive root of unity");
}
const fn roots_of_unity<const Q: u64, const N: usize>(w: u64) -> [Zq<Q>; N] {
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut i = 0;
let log_n = N.ilog2();
while i < N {
// (return the roots in bit-reverset order)
let j = ((i as u64).reverse_bits() >> (64 - log_n)) as usize;
r[i] = Zq(const_exp_mod::<Q>(w, j as u64));
i += 1;
}
r
}
const fn roots_of_unity_inv<const Q: u64, const N: usize>(v: [Zq<Q>; N]) -> [Zq<Q>; N] {
// assumes that the inputted roots are already in bit-reverset order
let mut r: [Zq<Q>; N] = [Zq(0u64); N];
let mut i = 0;
while i < N {
r[i] = Zq(const_inv_mod::<Q>(v[i].0));
i += 1;
}
r
}
/// returns x^k mod Q
const fn const_exp_mod<const Q: u64>(x: u64, k: u64) -> u64 {
// work on u128 to avoid overflow
let mut r = 1u128;
let mut x = x as u128;
let mut k = k as u128;
x = x % Q as u128;
// exponentiation by square strategy
while k > 0 {
if k % 2 == 1 {
r = (r * x) % Q as u128;
}
x = (x * x) % Q as u128;
k /= 2;
}
r as u64
}
/// returns x^-1 mod Q
const fn const_inv_mod<const Q: u64>(x: u64) -> u64 {
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
const_exp_mod::<Q>(x, Q - 2)
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use std::array;
#[test]
fn test_ntt() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let a: [u64; N] = [1u64, 2, 3, 4];
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
dbg!(&a);
dbg!(&a_ntt);
dbg!(&a_intt);
dbg!(NTT::<Q, N>::ROOT_OF_UNITY);
dbg!(NTT::<Q, N>::ROOTS_OF_UNITY);
assert_eq!(a, a_intt);
Ok(())
}
#[test]
fn test_ntt_loop() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 512;
use rand::distributions::Distribution;
use rand::distributions::Uniform;
let mut rng = rand::thread_rng();
let dist = Uniform::new(0_f64, Q as f64);
for _ in 0..100 {
let a: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
let a_ntt = NTT::<Q, N>::ntt(a);
let a_intt = NTT::<Q, N>::intt(a_ntt);
assert_eq!(a, a_intt);
}
Ok(())
}
}

216
arith/src/ring.rs Normal file
View File

@@ -0,0 +1,216 @@
//! Polynomial ring Z[X]/(X^N+1)
//!
use anyhow::{Result, anyhow};
use rand::{Rng, distributions::Distribution};
use std::array;
use std::fmt;
use std::ops;
use crate::ntt::NTT;
use crate::zq::Zq;
// PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1)
#[derive(Clone, Copy, Debug)]
pub struct R<const N: usize>([i64; N]);
impl<const Q: u64, const N: usize> From<crate::ringq::Rq<Q, N>> for R<N> {
fn from(rq: crate::ringq::Rq<Q, N>) -> Self {
Self::from_vec_u64(rq.coeffs().to_vec().iter().map(|e| e.0).collect())
}
}
impl<const N: usize> R<N> {
pub fn coeffs(&self) -> [i64; N] {
self.0
}
pub fn to_rq<const Q: u64>(self) -> crate::Rq<Q, N> {
crate::Rq::<Q, N>::from(self)
}
pub fn from_vec(coeffs: Vec<i64>) -> Self {
let mut p = coeffs;
modulus::<N>(&mut p);
Self(array::from_fn(|i| p[i]))
}
// this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
let coeffs_i64 = coeffs.iter().map(|c| *c as i64).collect();
Self::from_vec(coeffs_i64)
}
pub fn from_vec_f64(coeffs: Vec<f64>) -> Self {
let coeffs_i64 = coeffs.iter().map(|c| c.round() as i64).collect();
Self::from_vec(coeffs_i64)
}
pub fn new(coeffs: [i64; N]) -> Self {
Self(coeffs)
}
pub fn mul_by_i64(&self, s: i64) -> Self {
Self(array::from_fn(|i| self.0[i] * s))
}
// performs the multiplication and division over f64, and then it rounds the
// result, only applying the mod Q at the end
pub fn mul_div_round<const Q: u64>(&self, num: u64, den: u64) -> crate::Rq<Q, N> {
let r: Vec<f64> = self
.coeffs()
.iter()
.map(|e| ((num as f64 * *e as f64) / den as f64).round())
.collect();
crate::Rq::<Q, N>::from_vec_f64(r)
}
}
pub fn mul_div_round<const Q: u64, const N: usize>(
v: Vec<i64>,
num: u64,
den: u64,
) -> crate::Rq<Q, N> {
// dbg!(&v);
let r: Vec<f64> = v
.iter()
.map(|e| ((num as f64 * *e as f64) / den as f64).round())
.collect();
// dbg!(&r);
crate::Rq::<Q, N>::from_vec_f64(r)
}
// TODO rename to make it clear that is not mod q, but mod X^N+1
// apply mod (X^N+1)
pub fn modulus<const N: usize>(p: &mut Vec<i64>) {
if p.len() < N {
return;
}
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = 0;
}
p.truncate(N);
}
pub fn modulus_i128<const N: usize>(p: &mut Vec<i128>) {
if p.len() < N {
return;
}
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = 0;
}
p.truncate(N);
}
impl<const N: usize> PartialEq for R<N> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const N: usize> ops::Add<R<N>> for R<N> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] + rhs.0[i]))
}
}
impl<const N: usize> ops::Add<&R<N>> for &R<N> {
type Output = R<N>;
fn add(self, rhs: &R<N>) -> Self::Output {
R(array::from_fn(|i| self.0[i] + rhs.0[i]))
}
}
impl<const N: usize> ops::Sub<R<N>> for R<N> {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self(array::from_fn(|i| self.0[i] - rhs.0[i]))
}
}
impl<const N: usize> ops::Sub<&R<N>> for &R<N> {
type Output = R<N>;
fn sub(self, rhs: &R<N>) -> Self::Output {
R(array::from_fn(|i| self.0[i] - rhs.0[i]))
}
}
impl<const N: usize> ops::Mul<R<N>> for R<N> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
naive_poly_mul(&self, &rhs)
}
}
impl<const N: usize> ops::Mul<&R<N>> for &R<N> {
type Output = R<N>;
fn mul(self, rhs: &R<N>) -> Self::Output {
naive_poly_mul(self, rhs)
}
}
// TODO WIP
pub fn naive_poly_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> R<N> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
result[i + j] = result[i + j] + poly1[i] * poly2[j];
}
}
// apply mod (X^N + 1))
R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
}
pub fn naive_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
result[i + j] = result[i + j] + poly1[i] * poly2[j];
}
}
modulus_i128::<N>(&mut result);
// for c_i in result.iter() {
// println!("---");
// println!("{:?}", &c_i);
// println!("{:?}", *c_i as i64);
// println!("{:?}", (*c_i as i64) as i128);
// assert_eq!(*c_i, (*c_i as i64) as i128, "{:?}", c_i);
// }
// let q: i128 = 65537;
// let result: Vec<i64> = result
// .iter()
// // .map(|c_i| ((c_i % q + q) % q) as i64)
// .map(|c_i| (c_i % q) as i64)
// // .map(|c_i| *c_i as i64)
// .collect();
// result
result.iter().map(|c| *c as i64).collect()
}
// mul by u64
impl<const N: usize> ops::Mul<u64> for R<N> {
type Output = Self;
fn mul(self, s: u64) -> Self {
self.mul_by_i64(s as i64)
}
}
impl<const N: usize> ops::Mul<&u64> for &R<N> {
type Output = R<N>;
fn mul(self, s: &u64) -> Self::Output {
self.mul_by_i64(*s as i64)
}
}
impl<const N: usize> ops::Neg for R<N> {
type Output = Self;
fn neg(self) -> Self::Output {
Self(array::from_fn(|i| -self.0[i]))
}
}

511
arith/src/ringq.rs Normal file
View File

@@ -0,0 +1,511 @@
//! Polynomial ring Z_q[X]/(X^N+1)
//!
use rand::{Rng, distributions::Distribution};
use std::array;
use std::fmt;
use std::ops;
use crate::ntt::NTT;
use crate::zq::{Zq, modulus_u64};
use anyhow::{Result, anyhow};
/// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1)
/// The implementation assumes that q is prime.
#[derive(Clone, Copy)]
pub struct Rq<const Q: u64, const N: usize> {
pub(crate) coeffs: [Zq<Q>; N],
// evals are set when doig a PRxPR multiplication, so it can be reused in future
// multiplications avoiding recomputing it
pub(crate) evals: Option<[Zq<Q>; N]>,
}
// TODO define a trait "PolynomialRingTrait" or similar, so that when other structs use it can just
// use the trait and not need to add '<Q, N>' to their params
impl<const Q: u64, const N: usize> From<crate::ring::R<N>> for Rq<Q, N> {
fn from(r: crate::ring::R<N>) -> Self {
Self::from_vec(
r.coeffs()
.iter()
.map(|e| Zq::<Q>::from_f64(*e as f64))
.collect(),
)
}
}
// apply mod (X^N+1)
pub fn modulus<const Q: u64, const N: usize>(p: &mut Vec<Zq<Q>>) {
if p.len() < N {
return;
}
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = Zq(0);
}
p.truncate(N);
}
// PR stands for PolynomialRing
impl<const Q: u64, const N: usize> Rq<Q, N> {
pub fn coeffs(&self) -> [Zq<Q>; N] {
self.coeffs
}
pub fn to_r(self) -> crate::R<N> {
crate::R::<N>::from(self)
}
pub fn zero() -> Self {
let coeffs = array::from_fn(|_| Zq::zero());
Self {
coeffs,
evals: None,
}
}
pub fn from_vec(coeffs: Vec<Zq<Q>>) -> Self {
let mut p = coeffs;
modulus::<Q, N>(&mut p);
let coeffs = array::from_fn(|i| p[i]);
Self {
coeffs,
evals: None,
}
}
// this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_u64(*c)).collect();
Self::from_vec(coeffs_mod_q)
}
pub fn from_vec_f64(coeffs: Vec<f64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c)).collect();
Self::from_vec(coeffs_mod_q)
}
pub fn from_vec_i64(coeffs: Vec<i64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::from_f64(*c as f64)).collect();
Self::from_vec(coeffs_mod_q)
}
pub fn new(coeffs: [Zq<Q>; N], evals: Option<[Zq<Q>; N]>) -> Self {
Self { coeffs, evals }
}
pub fn rand_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self {
coeffs,
evals: None,
})
}
pub fn rand_f64_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self {
coeffs,
evals: None,
})
}
pub fn rand_f64(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
Ok(Self {
coeffs,
evals: None,
})
}
pub fn rand_u64(mut rng: impl Rng, dist: impl Distribution<u64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
Ok(Self {
coeffs,
evals: None,
})
}
// WIP. returns random v \in {0,1}. // TODO {-1, 0, 1}
pub fn rand_bin(mut rng: impl Rng, dist: impl Distribution<bool>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
Ok(Rq {
coeffs,
evals: None,
})
}
// Warning: this method will behave differently depending on the values P and Q:
// if Q<P, it just 'renames' the modulus parameter to P
// if Q>=P, it crops to mod P
pub fn remodule<const P: u64>(&self) -> Rq<P, N> {
Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
}
// applies mod(T) to all coefficients of self
pub fn coeffs_mod<const T: u64>(&self) -> Self {
Rq::<Q, N>::from_vec_u64(
self.coeffs()
.iter()
.map(|m_i| modulus_u64::<T>(m_i.0))
.collect(),
)
}
// TODO review if needed, or if with this interface
pub fn mul_by_matrix(&self, m: &Vec<Vec<Zq<Q>>>) -> Result<Vec<Zq<Q>>> {
matrix_vec_product(m, &self.coeffs.to_vec())
}
pub fn mul_by_zq(&self, s: &Zq<Q>) -> Self {
Self {
coeffs: array::from_fn(|i| self.coeffs[i] * *s),
evals: None,
}
}
pub fn mul_by_u64(&self, s: u64) -> Self {
let s = Zq::from_u64(s);
Self {
coeffs: array::from_fn(|i| self.coeffs[i] * s),
// coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
evals: None,
}
}
pub fn mul_by_f64(&self, s: f64) -> Self {
Self {
coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)),
evals: None,
}
}
pub fn mul(&mut self, rhs: &mut Self) -> Self {
mul_mut(self, rhs)
}
// divides by the given scalar 's' and rounds, returning a Rq<Q,N>
// TODO rm
pub fn div_round(&self, s: u64) -> Self {
let r: Vec<f64> = self
.coeffs()
.iter()
.map(|e| (e.0 as f64 / s as f64).round())
.collect();
Rq::<Q, N>::from_vec_f64(r)
}
// returns [ [(num/den) * self].round() ] mod q
// ie. performs the multiplication and division over f64, and then it rounds the
// result, only applying the mod Q at the end
pub fn mul_div_round(&self, num: u64, den: u64) -> Self {
let r: Vec<f64> = self
.coeffs()
.iter()
.map(|e| ((num as f64 * e.0 as f64) / den as f64).round())
.collect();
Rq::<Q, N>::from_vec_f64(r)
}
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// TODO simplify
let mut str = "";
let mut zero = true;
for (i, coeff) in self.coeffs.iter().enumerate().rev() {
if coeff.0 == 0 {
continue;
}
zero = false;
f.write_str(str)?;
if coeff.0 != 1 {
f.write_str(coeff.0.to_string().as_str())?;
if i > 0 {
f.write_str("*")?;
}
}
if coeff.0 == 1 && i == 0 {
f.write_str(coeff.0.to_string().as_str())?;
}
if i == 1 {
f.write_str("x")?;
} else if i > 1 {
f.write_str("x^")?;
f.write_str(i.to_string().as_str())?;
}
str = " + ";
}
if zero {
f.write_str("0")?;
}
f.write_str(" mod Z_")?;
f.write_str(Q.to_string().as_str())?;
f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str("+1)")?;
Ok(())
}
pub fn infinity_norm(&self) -> u64 {
self.coeffs().iter().map(|x| x.0).fold(0, |a, b| a.max(b))
}
}
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
// assert_eq!(m.len(), v.len());
if m.len() != m[0].len() {
return Err(anyhow!("expected 'm' to be a square matrix"));
}
if m.len() != v.len() {
return Err(anyhow!(
"m.len: {} should be equal to v.len(): {}",
m.len(),
v.len(),
));
}
Ok(m.iter()
.map(|row| {
row.iter()
.zip(v.iter())
.map(|(&row_i, &v_i)| row_i * v_i)
.sum()
})
.collect::<Vec<Zq<Q>>>())
}
pub fn transpose<const Q: u64>(m: &[Vec<Zq<Q>>]) -> Vec<Vec<Zq<Q>>> {
// TODO case when m[0].len()=0
// TODO non square matrix
let mut r: Vec<Vec<Zq<Q>>> = vec![vec![Zq(0); m[0].len()]; m.len()];
for (i, m_row) in m.iter().enumerate() {
for (j, m_ij) in m_row.iter().enumerate() {
r[j][i] = *m_ij;
}
}
r
}
impl<const Q: u64, const N: usize> PartialEq for Rq<Q, N> {
fn eq(&self, other: &Self) -> bool {
self.coeffs == other.coeffs
}
}
impl<const Q: u64, const N: usize> ops::Add<Rq<Q, N>> for Rq<Q, N> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
evals: None,
}
// Self {
// coeffs: self
// .coeffs
// .iter()
// .zip(rhs.coeffs)
// .map(|(a, b)| *a + b)
// .collect(),
// evals: None,
// }
// Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in +
}
}
impl<const Q: u64, const N: usize> ops::Add<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
fn add(self, rhs: &Rq<Q, N>) -> Self::Output {
Rq {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> ops::Sub<Rq<Q, N>> for Rq<Q, N> {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> ops::Sub<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
fn sub(self, rhs: &Rq<Q, N>) -> Self::Output {
Rq {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> ops::Mul<Rq<Q, N>> for Rq<Q, N> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
mul(&self, &rhs)
}
}
impl<const Q: u64, const N: usize> ops::Mul<&Rq<Q, N>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
fn mul(self, rhs: &Rq<Q, N>) -> Self::Output {
mul(self, rhs)
}
}
// mul by Zq element
impl<const Q: u64, const N: usize> ops::Mul<Zq<Q>> for Rq<Q, N> {
type Output = Self;
fn mul(self, s: Zq<Q>) -> Self {
self.mul_by_zq(&s)
}
}
impl<const Q: u64, const N: usize> ops::Mul<&Zq<Q>> for &Rq<Q, N> {
type Output = Rq<Q, N>;
fn mul(self, s: &Zq<Q>) -> Self::Output {
self.mul_by_zq(s)
}
}
// mul by u64
impl<const Q: u64, const N: usize> ops::Mul<u64> for Rq<Q, N> {
type Output = Self;
fn mul(self, s: u64) -> Self {
self.mul_by_u64(s)
}
}
impl<const Q: u64, const N: usize> ops::Mul<&u64> for &Rq<Q, N> {
type Output = Rq<Q, N>;
fn mul(self, s: &u64) -> Self::Output {
self.mul_by_u64(*s)
}
}
impl<const Q: u64, const N: usize> ops::Neg for Rq<Q, N> {
type Output = Self;
fn neg(self) -> Self::Output {
Self {
coeffs: array::from_fn(|i| -self.coeffs[i]),
evals: None,
}
}
}
fn mul_mut<const Q: u64, const N: usize>(lhs: &mut Rq<Q, N>, rhs: &mut Rq<Q, N>) -> Rq<Q, N> {
// reuse evaluations if already computed
if !lhs.evals.is_some() {
lhs.evals = Some(NTT::<Q, N>::ntt(lhs.coeffs));
};
if !rhs.evals.is_some() {
rhs.evals = Some(NTT::<Q, N>::ntt(rhs.coeffs));
};
let lhs_evals = lhs.evals.unwrap();
let rhs_evals = rhs.evals.unwrap();
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
}
fn mul<const Q: u64, const N: usize>(lhs: &Rq<Q, N>, rhs: &Rq<Q, N>) -> Rq<Q, N> {
// reuse evaluations if already computed
let lhs_evals = if lhs.evals.is_some() {
lhs.evals.unwrap()
} else {
NTT::<Q, N>::ntt(lhs.coeffs)
};
let rhs_evals = if rhs.evals.is_some() {
rhs.evals.unwrap()
} else {
NTT::<Q, N>::ntt(rhs.coeffs)
};
let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
}
impl<const Q: u64, const N: usize> fmt::Display for Rq<Q, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?;
Ok(())
}
}
impl<const Q: u64, const N: usize> fmt::Debug for Rq<Q, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn poly_ring() {
// the test values used are generated with SageMath
const Q: u64 = 7;
const N: usize = 3;
// p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1)
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with coefficients bigger than Q
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, Q + 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with other ring
let p = Rq::<7, 4>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)");
let p = Rq::<Q, N>::from_vec_u64(vec![0u64, 0, 0, 0, 4, 5]);
assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)");
let p = Rq::<Q, N>::from_vec_u64(vec![5u64, 4, 5, 2, 1, 0]);
assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
let a = Rq::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
let b = Rq::<Q, N>::from_vec_u64(vec![5u64, 4, 3, 2, 1, 0]);
assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
// add
assert_eq!((a.clone() + b.clone()).to_string(), "0 mod Z_7/(X^3+1)");
assert_eq!((&a + &b).to_string(), "0 mod Z_7/(X^3+1)");
// assert_eq!((a.0.clone() + b.0.clone()).to_string(), "[0, 0, 0]"); // TODO
// sub
assert_eq!(
(a.clone() - b.clone()).to_string(),
"x^2 + x + 1 mod Z_7/(X^3+1)"
);
}
fn test_mul_opt<const Q: u64, const N: usize>(
a: [u64; N],
b: [u64; N],
expected_c: [u64; N],
) -> Result<()> {
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::new(a, None);
let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::new(b, None);
let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::new(expected_c, None);
let c = mul_mut(&mut a, &mut b);
assert_eq!(c, expected_c);
Ok(())
}
#[test]
fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 4;
let a: [u64; N] = [1u64, 2, 3, 4];
let b: [u64; N] = [1u64, 2, 3, 4];
let c: [u64; N] = [65513, 65517, 65531, 20];
test_mul_opt::<Q, N>(a, b, c)?;
let a: [u64; N] = [0u64, 0, 0, 2];
let b: [u64; N] = [0u64, 0, 0, 2];
let c: [u64; N] = [0u64, 0, 65533, 0];
test_mul_opt::<Q, N>(a, b, c)?;
// TODO more testvectors
Ok(())
}
}

248
arith/src/zq.rs Normal file
View File

@@ -0,0 +1,248 @@
use std::fmt;
use std::ops;
// Z_q, integers modulus q, not necessarily prime
#[derive(Clone, Copy, PartialEq)]
pub struct Zq<const Q: u64>(pub u64);
// WIP
// impl<const Q: u64> From<Vec<u64>> for Vec<Zq<Q>> {
// fn from(v: Vec<u64>) -> Self {
// v.into_iter().map(Zq::new).collect()
// }
// }
pub(crate) fn modulus_u64<const Q: u64>(e: u64) -> u64 {
(e % Q + Q) % Q
}
impl<const Q: u64> Zq<Q> {
pub fn from_u64(e: u64) -> Self {
if e >= Q {
// (e % Q + Q) % Q
return Zq(modulus_u64::<Q>(e));
// return Zq(e % Q);
}
Zq(e)
}
pub fn from_f64(e: f64) -> Self {
// WIP method
let e: i64 = e.round() as i64;
let q = Q as i64;
if e < 0 || e >= q {
return Zq(((e % q + q) % q) as u64);
}
Zq(e as u64)
// if e < 0 {
// // dbg!(&e);
// // dbg!(Zq::<Q>(((Q as i64 + e) % Q as i64) as u64));
// // return Zq(((Q as i64 + e) % Q as i64) as u64);
// return Zq(e as u64 % Q);
// } else if e >= Q as i64 {
// return Zq((e % Q as i64) as u64);
// }
// Zq(e as u64)
}
pub fn from_bool(b: bool) -> Self {
if b { Zq(1) } else { Zq(0) }
}
pub fn zero() -> Self {
Zq(0u64)
}
pub fn square(self) -> Self {
self * self
}
// modular exponentiation
pub fn exp(self, e: Self) -> Self {
// mul-square approach
let mut res = Self(1);
let mut rem = e.clone();
let mut exp = self;
// for rem != Self(0) {
while rem != Self(0) {
// if odd
// TODO use a more readible expression
if 1 - ((rem.0 & 1) << 1) as i64 == -1 {
res = res * exp;
}
exp = exp.square();
rem = Self(rem.0 >> 1);
}
res
}
// multiplicative inverse
// WARNING: if this is needed, it means that 'Zq' is a Finite Field. For the moment we assume
// we work in a Finite Field
pub fn inv_OLD(self) -> Self {
// TODO
// let a = self.0;
// let q = Q;
let mut t = 0;
let mut r = Q;
let mut new_t = 0;
let mut new_r = self.0.clone();
while new_r != 0 {
let q = r / new_r;
t = new_t.clone();
new_t = t - q;
r = new_r.clone();
new_r = r - (q * new_r);
}
// if t < 0 {
// t = t + q;
// }
return Zq::from_u64(t);
}
pub fn inv(self) -> Zq<Q> {
let (g, x, _) = Self::egcd(self.0 as i128, Q as i128);
if g != 1 {
// None
panic!("E");
} else {
let q = Q as i128;
Zq(((x % q + q) % q) as u64) // TODO maybe just Zq::new(x)
}
}
fn egcd(a: i128, b: i128) -> (i128, i128, i128) {
if a == 0 {
(b, 0, 1)
} else {
let (g, x, y) = Self::egcd(b % a, a);
(g, y - (b / a) * x, x)
}
}
}
impl<const Q: u64> Zq<Q> {
fn r#mod(self) -> Self {
if self.0 >= Q {
return Zq(self.0 % Q);
}
self
}
}
impl<const Q: u64> ops::Add<Zq<Q>> for Zq<Q> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
let mut r = self.0 + rhs.0;
if r >= Q {
r -= Q;
}
Zq(r)
}
}
impl<const Q: u64> ops::Add<&Zq<Q>> for &Zq<Q> {
type Output = Zq<Q>;
fn add(self, rhs: &Zq<Q>) -> Self::Output {
let mut r = self.0 + rhs.0;
if r >= Q {
r -= Q;
}
Zq(r)
}
}
impl<const Q: u64> ops::AddAssign<Zq<Q>> for Zq<Q> {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs
}
}
impl<const Q: u64> std::iter::Sum for Zq<Q> {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = Self>,
{
iter.fold(Zq(0), |acc, x| acc + x)
}
}
impl<const Q: u64> ops::Sub<Zq<Q>> for Zq<Q> {
type Output = Self;
fn sub(self, rhs: Self) -> Zq<Q> {
if self.0 >= rhs.0 {
Zq(self.0 - rhs.0)
} else {
Zq((Q + self.0) - rhs.0)
}
}
}
impl<const Q: u64> ops::Sub<&Zq<Q>> for &Zq<Q> {
type Output = Zq<Q>;
fn sub(self, rhs: &Zq<Q>) -> Self::Output {
if self.0 >= rhs.0 {
Zq(self.0 - rhs.0)
} else {
Zq((Q + self.0) - rhs.0)
}
}
}
impl<const Q: u64> ops::SubAssign<Zq<Q>> for Zq<Q> {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs
}
}
impl<const Q: u64> ops::Neg for Zq<Q> {
type Output = Self;
fn neg(self) -> Self::Output {
Zq(Q - self.0)
}
}
impl<const Q: u64> ops::Mul<Zq<Q>> for Zq<Q> {
type Output = Self;
fn mul(self, rhs: Self) -> Zq<Q> {
// TODO non-naive way
Zq(((self.0 as u128 * rhs.0 as u128) % Q as u128) as u64)
// Zq((self.0 * rhs.0) % Q)
}
}
impl<const Q: u64> ops::Div<Zq<Q>> for Zq<Q> {
type Output = Self;
fn div(self, rhs: Self) -> Zq<Q> {
// TODO non-naive way
// Zq((self.0 / rhs.0) % Q)
self * rhs.inv()
}
}
impl<const Q: u64> fmt::Display for Zq<Q> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl<const Q: u64> fmt::Debug for Zq<Q> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn exp() {
const Q: u64 = 1021;
let a = Zq::<Q>(3);
let b = Zq::<Q>(3);
assert_eq!(a.exp(b), Zq(27));
let a = Zq::<Q>(1000);
let b = Zq::<Q>(3);
assert_eq!(a.exp(b), Zq(949));
}
#[test]
fn neg() {
const Q: u64 = 1021;
let a = Zq::<Q>::from_f64(101.0);
let b = Zq::<Q>::from_f64(-1.0);
assert_eq!(-a, a * b);
}
}