Browse Source

add polynomial ring (Rq) impl

gfhe-over-ring-trait
arnaucube 1 month ago
parent
commit
182fd518fe
2 changed files with 282 additions and 0 deletions
  1. +2
    -0
      arithmetic/src/lib.rs
  2. +280
    -0
      arithmetic/src/ring.rs

+ 2
- 0
arithmetic/src/lib.rs

@ -4,6 +4,8 @@
#![allow(clippy::upper_case_acronyms)]
#![allow(dead_code)] // TMP
pub mod ring;
pub mod zq;
pub use ring::PR;
pub use zq::Zq;

+ 280
- 0
arithmetic/src/ring.rs

@ -0,0 +1,280 @@
use rand::{distributions::Distribution, Rng};
use std::array;
use std::fmt;
use std::ops;
use crate::zq::Zq;
use anyhow::{anyhow, Result};
// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1)
#[derive(Clone, Copy)]
pub struct PR<const Q: u64, const N: usize> {
pub(crate) coeffs: [Zq<Q>; N],
// evals are set when doig a PRxPR multiplication, so it can be reused in future
// multiplications avoiding recomputing it
pub(crate) evals: Option<[Zq<Q>; N]>,
}
// TODO define a trait "PolynomialRingTrait" or similar, so that when other structs use it can just
// use the trait and not need to add '<Q, N>' to their params
// apply mod (X^N+1)
pub fn modulus<const Q: u64, const N: usize>(p: &mut Vec<Zq<Q>>) {
if p.len() < N {
return;
}
for i in N..p.len() {
p[i - N] = p[i - N].clone() - p[i].clone();
p[i] = Zq(0);
}
p.truncate(N);
}
// PR stands for PolynomialRing
impl<const Q: u64, const N: usize> PR<Q, N> {
pub fn coeffs(&self) -> [Zq<Q>; N] {
self.coeffs
}
pub fn from_vec(coeffs: Vec<Zq<Q>>) -> Self {
let mut p = coeffs;
modulus::<Q, N>(&mut p);
let coeffs = array::from_fn(|i| p[i]);
Self {
coeffs,
evals: None,
}
}
// this method is mostly for tests
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
let coeffs_mod_q = coeffs.iter().map(|c| Zq::new(*c)).collect();
Self::from_vec(coeffs_mod_q)
}
pub fn new(coeffs: [Zq<Q>; N], evals: Option<[Zq<Q>; N]>) -> Self {
Self { coeffs, evals }
}
pub fn rand_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self {
coeffs,
evals: None,
})
}
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
Ok(Self {
coeffs,
evals: None,
})
}
// WIP. returns random v \in {0,1}. // TODO {-1, 0, 1}
pub fn rand_bin(mut rng: impl Rng, dist: impl Distribution<bool>) -> Result<Self> {
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
Ok(PR {
coeffs,
evals: None,
})
}
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// TODO simplify
let mut str = "";
let mut zero = true;
for (i, coeff) in self.coeffs.iter().enumerate().rev() {
if coeff.0 == 0 {
continue;
}
zero = false;
f.write_str(str)?;
if coeff.0 != 1 {
f.write_str(coeff.0.to_string().as_str())?;
if i > 0 {
f.write_str("*")?;
}
}
if coeff.0 == 1 && i == 0 {
f.write_str(coeff.0.to_string().as_str())?;
}
if i == 1 {
f.write_str("x")?;
} else if i > 1 {
f.write_str("x^")?;
f.write_str(i.to_string().as_str())?;
}
str = " + ";
}
if zero {
f.write_str("0")?;
}
f.write_str(" mod Z_")?;
f.write_str(Q.to_string().as_str())?;
f.write_str("/(X^")?;
f.write_str(N.to_string().as_str())?;
f.write_str("+1)")?;
Ok(())
}
}
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
// assert_eq!(m.len(), v.len());
if m.len() != m[0].len() {
return Err(anyhow!("expected 'm' to be a square matrix"));
}
if m.len() != v.len() {
return Err(anyhow!(
"m.len: {} should be equal to v.len(): {}",
m.len(),
v.len(),
));
}
Ok(m.iter()
.map(|row| {
row.iter()
.zip(v.iter())
.map(|(&row_i, &v_i)| row_i * v_i)
.sum()
})
.collect::<Vec<Zq<Q>>>())
}
pub fn transpose<const Q: u64>(m: &[Vec<Zq<Q>>]) -> Vec<Vec<Zq<Q>>> {
// TODO case when m[0].len()=0
// TODO non square matrix
let mut r: Vec<Vec<Zq<Q>>> = vec![vec![Zq(0); m[0].len()]; m.len()];
for (i, m_row) in m.iter().enumerate() {
for (j, m_ij) in m_row.iter().enumerate() {
r[j][i] = *m_ij;
}
}
r
}
impl<const Q: u64, const N: usize> PartialEq for PR<Q, N> {
fn eq(&self, other: &Self) -> bool {
self.coeffs == other.coeffs
}
}
impl<const Q: u64, const N: usize> ops::Add<PR<Q, N>> for PR<Q, N> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
evals: None,
}
// Self {
// coeffs: self
// .coeffs
// .iter()
// .zip(rhs.coeffs)
// .map(|(a, b)| *a + b)
// .collect(),
// evals: None,
// }
// Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in +
}
}
impl<const Q: u64, const N: usize> ops::Add<&PR<Q, N>> for &PR<Q, N> {
type Output = PR<Q, N>;
fn add(self, rhs: &PR<Q, N>) -> Self::Output {
PR {
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> ops::Sub<PR<Q, N>> for PR<Q, N> {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> ops::Sub<&PR<Q, N>> for &PR<Q, N> {
type Output = PR<Q, N>;
fn sub(self, rhs: &PR<Q, N>) -> Self::Output {
PR {
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> ops::Neg for PR<Q, N> {
type Output = Self;
fn neg(self) -> Self::Output {
Self {
coeffs: array::from_fn(|i| -self.coeffs[i]),
evals: None,
}
}
}
impl<const Q: u64, const N: usize> fmt::Display for PR<Q, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?;
Ok(())
}
}
impl<const Q: u64, const N: usize> fmt::Debug for PR<Q, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt(f)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn poly_ring() {
// the test values used are generated with SageMath
const Q: u64 = 7;
const N: usize = 3;
// p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1)
let p = PR::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with coefficients bigger than Q
let p = PR::<Q, N>::from_vec_u64(vec![0u64, 1, Q + 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with other ring
let p = PR::<7, 4>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)");
let p = PR::<Q, N>::from_vec_u64(vec![0u64, 0, 0, 0, 4, 5]);
assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)");
let p = PR::<Q, N>::from_vec_u64(vec![5u64, 4, 5, 2, 1, 0]);
assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
let a = PR::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
let b = PR::<Q, N>::from_vec_u64(vec![5u64, 4, 3, 2, 1, 0]);
assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
// add
assert_eq!((a.clone() + b.clone()).to_string(), "0 mod Z_7/(X^3+1)");
assert_eq!((&a + &b).to_string(), "0 mod Z_7/(X^3+1)");
// assert_eq!((a.0.clone() + b.0.clone()).to_string(), "[0, 0, 0]"); // TODO
// sub
assert_eq!(
(a.clone() - b.clone()).to_string(),
"x^2 + x + 1 mod Z_7/(X^3+1)"
);
}
}

Loading…
Cancel
Save