polish, tensor & mul with relinearization works for some parameters choice

This commit is contained in:
2025-06-22 23:04:28 +02:00
parent 19457c98dd
commit b968310ce1
7 changed files with 366 additions and 203 deletions

View File

@@ -1,2 +1,2 @@
# arithmetic
# arith
Contains $\mathbb{Z}_q$ and $\mathbb{Z}_q[X]/(X^N+1)$ arithmetic implementations, together with the NTT implementation.

View File

@@ -2,7 +2,7 @@
//! Vandermonde matrix.
use crate::zq::Zq;
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
#[derive(Debug)]
pub struct NTT<const Q: u64, const N: usize> {
@@ -35,6 +35,8 @@ impl<const Q: u64, const N: usize> NTT<Q, N> {
intt,
})
}
/// returns the Vandermonde matrix for the given primitive root of unity.
/// Vandermonde matrix: https://en.wikipedia.org/wiki/Vandermonde_matrix
pub fn vandermonde(primitive: Zq<Q>) -> Vec<Vec<Zq<Q>>> {
let mut v: Vec<Vec<Zq<Q>>> = vec![];
let n = (2 * N) as u64;
@@ -52,6 +54,7 @@ impl<const Q: u64, const N: usize> NTT<Q, N> {
v
}
// specifically for the Vandermonde matrix
/// returns the inverse Vandermonde matrix
pub fn invert_vandermonde(v: &Vec<Vec<Zq<Q>>>) -> Vec<Vec<Zq<Q>>> {
let n = 2 * N;
// let n = N;
@@ -68,6 +71,8 @@ impl<const Q: u64, const N: usize> NTT<Q, N> {
inv
}
/// computes a primitive N-th root of unity using the method described by
/// Thomas Pornin in https://crypto.stackexchange.com/a/63616
pub fn get_primitive_root_of_unity(n: u64) -> Result<Zq<Q>> {
// using the method described by Thomas Pornin in
// https://crypto.stackexchange.com/a/63616
@@ -101,8 +106,8 @@ mod tests {
use super::*;
use rand_distr::Uniform;
use crate::ring::Rq;
use crate::ring::matrix_vec_product;
use crate::ringq::matrix_vec_product;
use crate::ringq::Rq;
#[test]
fn roots_of_unity() -> Result<()> {

View File

@@ -1,18 +1,13 @@
//! Polynomial ring Z[X]/(X^N+1)
//!
use anyhow::{Result, anyhow};
use rand::{Rng, distributions::Distribution};
use std::array;
use std::fmt;
use std::ops;
use crate::ntt::NTT;
use crate::zq::Zq;
// PolynomialRing element, where the PolynomialRing is R = Z[X]/(X^n +1)
#[derive(Clone, Copy, Debug)]
pub struct R<const N: usize>([i64; N]);
#[derive(Clone, Copy)]
pub struct R<const N: usize>(pub [i64; N]);
impl<const Q: u64, const N: usize> From<crate::ringq::Rq<Q, N>> for R<N> {
fn from(rq: crate::ringq::Rq<Q, N>) -> Self {
@@ -157,10 +152,48 @@ pub fn naive_poly_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> R<N> {
}
// apply mod (X^N + 1))
R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
modulus_i128::<N>(&mut result);
// dbg!(&result);
// dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs());
// sanity check: check that there are no coeffs > i64_max
assert_eq!(
result,
R::<N>(array::from_fn(|i| result[i] as i64))
.coeffs()
.iter()
.map(|c| *c as i128)
.collect::<Vec<_>>()
);
R(array::from_fn(|i| result[i] as i64))
}
pub fn naive_mul_2<const N: usize>(poly1: &Vec<i128>, poly2: &Vec<i128>) -> Vec<i128> {
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
result[i + j] = result[i + j] + poly1[i] * poly2[j];
}
}
// apply mod (X^N + 1))
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
modulus_i128::<N>(&mut result);
result
}
pub fn naive_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result = vec![0; (N * 2) - 1];
for i in 0..N {
for j in 0..N {
result[i + j] = result[i + j] + poly1[i] * poly2[j];
}
}
result.iter().map(|c| *c as i64).collect()
}
pub fn naive_mul_TMP<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
let poly1: Vec<i128> = poly1.0.iter().map(|c| *c as i128).collect();
let poly2: Vec<i128> = poly2.0.iter().map(|c| *c as i128).collect();
let mut result: Vec<i128> = vec![0; (N * 2) - 1];
@@ -170,6 +203,7 @@ pub fn naive_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
}
}
// dbg!(&result);
modulus_i128::<N>(&mut result);
// for c_i in result.iter() {
// println!("---");
@@ -178,19 +212,25 @@ pub fn naive_mul<const N: usize>(poly1: &R<N>, poly2: &R<N>) -> Vec<i64> {
// println!("{:?}", (*c_i as i64) as i128);
// assert_eq!(*c_i, (*c_i as i64) as i128, "{:?}", c_i);
// }
// let q: i128 = 65537;
// let result: Vec<i64> = result
// .iter()
// // .map(|c_i| ((c_i % q + q) % q) as i64)
// .map(|c_i| (c_i % q) as i64)
// // .map(|c_i| *c_i as i64)
// .collect();
// result
result.iter().map(|c| *c as i64).collect()
}
// wip
pub fn mod_centered_q<const Q: u64, const N: usize>(p: Vec<i128>) -> R<N> {
let q: i128 = Q as i128;
let r = p
.iter()
.map(|v| {
let mut res = v % q;
if res > q / 2 {
res = res - q;
}
res
})
.collect::<Vec<i128>>();
R::<N>::from_vec(r.iter().map(|v| *v as i64).collect::<Vec<i64>>())
}
// mul by u64
impl<const N: usize> ops::Mul<u64> for R<N> {
type Output = Self;
@@ -214,3 +254,97 @@ impl<const N: usize> ops::Neg for R<N> {
Self(array::from_fn(|i| -self.0[i]))
}
}
impl<const N: usize> R<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut str = "";
let mut zero = true;
for (i, coeff) in self.0.iter().enumerate().rev() {
if *coeff == 0 {
continue;
}
zero = false;
f.write_str(str)?;
if *coeff != 1 {
f.write_str(coeff.to_string().as_str())?;
if i > 0 {
f.write_str("*")?;
}
}
if *coeff == 1 && i == 0 {
f.write_str(coeff.to_string().as_str())?;
}
if i == 1 {
f.write_str("x")?;
} else if i > 1 {
f.write_str("x^")?;
f.write_str(i.to_string().as_str())?;
}
str = " + ";
}
if zero {
f.write_str("0")?;
}
f.write_str(" mod Z")?;
f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str("+1)")?;
Ok(())
}
}
impl<const N: usize> fmt::Display for R<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?;
Ok(())
}
}
impl<const N: usize> fmt::Debug for R<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
const N: usize = 2;
let q: i64 = Q as i64;
// test vectors generated with SageMath
let a: [i64; N] = [q - 1, q - 1];
let b: [i64; N] = [q - 1, q - 1];
let c: [i64; N] = [0, 8589934592];
test_mul_opt::<Q, N>(a, b, c)?;
let a: [i64; N] = [1, q - 1];
let b: [i64; N] = [1, q - 1];
let c: [i64; N] = [-4294967295, 131072];
test_mul_opt::<Q, N>(a, b, c)?;
Ok(())
}
fn test_mul_opt<const Q: u64, const N: usize>(
a: [i64; N],
b: [i64; N],
expected_c: [i64; N],
) -> Result<()> {
let mut a = R::new(a);
let mut b = R::new(b);
dbg!(&a);
dbg!(&b);
let expected_c = R::new(expected_c);
let mut c = naive_mul(&mut a, &mut b);
modulus::<N>(&mut c);
dbg!(R::<N>::from_vec(c.clone()));
assert_eq!(c, expected_c.0.to_vec());
Ok(())
}
}

View File

@@ -1,14 +1,14 @@
//! Polynomial ring Z_q[X]/(X^N+1)
//!
use rand::{Rng, distributions::Distribution};
use rand::{distributions::Distribution, Rng};
use std::array;
use std::fmt;
use std::ops;
use crate::ntt::NTT;
use crate::zq::{Zq, modulus_u64};
use anyhow::{Result, anyhow};
use crate::zq::{modulus_u64, Zq};
use anyhow::{anyhow, Result};
/// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1)
/// The implementation assumes that q is prime.
@@ -231,7 +231,10 @@ impl<const Q: u64, const N: usize> Rq<Q, N> {
}
pub fn infinity_norm(&self) -> u64 {
self.coeffs().iter().map(|x| x.0).fold(0, |a, b| a.max(b))
self.coeffs()
.iter()
.map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
.fold(0, |a, b| a.max(b))
}
}
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
@@ -369,6 +372,21 @@ impl<const Q: u64, const N: usize> ops::Mul<&u64> for &Rq<Q, N> {
self.mul_by_u64(*s)
}
}
// mul by f64
impl<const Q: u64, const N: usize> ops::Mul<f64> for Rq<Q, N> {
type Output = Self;
fn mul(self, s: f64) -> Self {
self.mul_by_f64(s)
}
}
impl<const Q: u64, const N: usize> ops::Mul<&f64> for &Rq<Q, N> {
type Output = Rq<Q, N>;
fn mul(self, s: &f64) -> Self::Output {
self.mul_by_f64(*s)
}
}
impl<const Q: u64, const N: usize> ops::Neg for Rq<Q, N> {
type Output = Self;
@@ -473,22 +491,6 @@ mod tests {
);
}
fn test_mul_opt<const Q: u64, const N: usize>(
a: [u64; N],
b: [u64; N],
expected_c: [u64; N],
) -> Result<()> {
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::new(a, None);
let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::new(b, None);
let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::new(expected_c, None);
let c = mul_mut(&mut a, &mut b);
assert_eq!(c, expected_c);
Ok(())
}
#[test]
fn test_mul() -> Result<()> {
const Q: u64 = 2u64.pow(16) + 1;
@@ -508,4 +510,20 @@ mod tests {
Ok(())
}
fn test_mul_opt<const Q: u64, const N: usize>(
a: [u64; N],
b: [u64; N],
expected_c: [u64; N],
) -> Result<()> {
let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::new(a, None);
let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::new(b, None);
let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::new(expected_c, None);
let c = mul_mut(&mut a, &mut b);
assert_eq!(c, expected_c);
Ok(())
}
}

View File

@@ -44,7 +44,11 @@ impl<const Q: u64> Zq<Q> {
// Zq(e as u64)
}
pub fn from_bool(b: bool) -> Self {
if b { Zq(1) } else { Zq(0) }
if b {
Zq(1)
} else {
Zq(0)
}
}
pub fn zero() -> Self {
Zq(0u64)