mirror of
https://github.com/arnaucube/fhe-study.git
synced 2026-01-23 20:23:54 +01:00
add polynomial ring (Rq) impl
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
#![allow(dead_code)] // TMP
|
||||
|
||||
pub mod ring;
|
||||
pub mod zq;
|
||||
|
||||
pub use ring::PR;
|
||||
pub use zq::Zq;
|
||||
|
||||
280
arithmetic/src/ring.rs
Normal file
280
arithmetic/src/ring.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
use rand::{distributions::Distribution, Rng};
|
||||
use std::array;
|
||||
use std::fmt;
|
||||
use std::ops;
|
||||
|
||||
use crate::zq::Zq;
|
||||
use anyhow::{anyhow, Result};
|
||||
|
||||
// PolynomialRing element, where the PolynomialRing is R = Z_q[X]/(X^n +1)
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct PR<const Q: u64, const N: usize> {
|
||||
pub(crate) coeffs: [Zq<Q>; N],
|
||||
|
||||
// evals are set when doig a PRxPR multiplication, so it can be reused in future
|
||||
// multiplications avoiding recomputing it
|
||||
pub(crate) evals: Option<[Zq<Q>; N]>,
|
||||
}
|
||||
|
||||
// TODO define a trait "PolynomialRingTrait" or similar, so that when other structs use it can just
|
||||
// use the trait and not need to add '<Q, N>' to their params
|
||||
|
||||
// apply mod (X^N+1)
|
||||
pub fn modulus<const Q: u64, const N: usize>(p: &mut Vec<Zq<Q>>) {
|
||||
if p.len() < N {
|
||||
return;
|
||||
}
|
||||
for i in N..p.len() {
|
||||
p[i - N] = p[i - N].clone() - p[i].clone();
|
||||
p[i] = Zq(0);
|
||||
}
|
||||
p.truncate(N);
|
||||
}
|
||||
|
||||
// PR stands for PolynomialRing
|
||||
impl<const Q: u64, const N: usize> PR<Q, N> {
|
||||
pub fn coeffs(&self) -> [Zq<Q>; N] {
|
||||
self.coeffs
|
||||
}
|
||||
|
||||
pub fn from_vec(coeffs: Vec<Zq<Q>>) -> Self {
|
||||
let mut p = coeffs;
|
||||
modulus::<Q, N>(&mut p);
|
||||
let coeffs = array::from_fn(|i| p[i]);
|
||||
Self {
|
||||
coeffs,
|
||||
evals: None,
|
||||
}
|
||||
}
|
||||
// this method is mostly for tests
|
||||
pub fn from_vec_u64(coeffs: Vec<u64>) -> Self {
|
||||
let coeffs_mod_q = coeffs.iter().map(|c| Zq::new(*c)).collect();
|
||||
Self::from_vec(coeffs_mod_q)
|
||||
}
|
||||
pub fn new(coeffs: [Zq<Q>; N], evals: Option<[Zq<Q>; N]>) -> Self {
|
||||
Self { coeffs, evals }
|
||||
}
|
||||
|
||||
pub fn rand_abs(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
|
||||
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
|
||||
Ok(Self {
|
||||
coeffs,
|
||||
evals: None,
|
||||
})
|
||||
}
|
||||
pub fn rand(mut rng: impl Rng, dist: impl Distribution<f64>) -> Result<Self> {
|
||||
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
|
||||
Ok(Self {
|
||||
coeffs,
|
||||
evals: None,
|
||||
})
|
||||
}
|
||||
// WIP. returns random v \in {0,1}. // TODO {-1, 0, 1}
|
||||
pub fn rand_bin(mut rng: impl Rng, dist: impl Distribution<bool>) -> Result<Self> {
|
||||
let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
|
||||
Ok(PR {
|
||||
coeffs,
|
||||
evals: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
// TODO simplify
|
||||
let mut str = "";
|
||||
let mut zero = true;
|
||||
for (i, coeff) in self.coeffs.iter().enumerate().rev() {
|
||||
if coeff.0 == 0 {
|
||||
continue;
|
||||
}
|
||||
zero = false;
|
||||
f.write_str(str)?;
|
||||
if coeff.0 != 1 {
|
||||
f.write_str(coeff.0.to_string().as_str())?;
|
||||
if i > 0 {
|
||||
f.write_str("*")?;
|
||||
}
|
||||
}
|
||||
if coeff.0 == 1 && i == 0 {
|
||||
f.write_str(coeff.0.to_string().as_str())?;
|
||||
}
|
||||
if i == 1 {
|
||||
f.write_str("x")?;
|
||||
} else if i > 1 {
|
||||
f.write_str("x^")?;
|
||||
f.write_str(i.to_string().as_str())?;
|
||||
}
|
||||
str = " + ";
|
||||
}
|
||||
if zero {
|
||||
f.write_str("0")?;
|
||||
}
|
||||
|
||||
f.write_str(" mod Z_")?;
|
||||
f.write_str(Q.to_string().as_str())?;
|
||||
f.write_str("/(X^")?;
|
||||
f.write_str(N.to_string().as_str())?;
|
||||
f.write_str("+1)")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
|
||||
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
|
||||
// assert_eq!(m.len(), v.len());
|
||||
if m.len() != m[0].len() {
|
||||
return Err(anyhow!("expected 'm' to be a square matrix"));
|
||||
}
|
||||
if m.len() != v.len() {
|
||||
return Err(anyhow!(
|
||||
"m.len: {} should be equal to v.len(): {}",
|
||||
m.len(),
|
||||
v.len(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(m.iter()
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.zip(v.iter())
|
||||
.map(|(&row_i, &v_i)| row_i * v_i)
|
||||
.sum()
|
||||
})
|
||||
.collect::<Vec<Zq<Q>>>())
|
||||
}
|
||||
pub fn transpose<const Q: u64>(m: &[Vec<Zq<Q>>]) -> Vec<Vec<Zq<Q>>> {
|
||||
// TODO case when m[0].len()=0
|
||||
// TODO non square matrix
|
||||
let mut r: Vec<Vec<Zq<Q>>> = vec![vec![Zq(0); m[0].len()]; m.len()];
|
||||
for (i, m_row) in m.iter().enumerate() {
|
||||
for (j, m_ij) in m_row.iter().enumerate() {
|
||||
r[j][i] = *m_ij;
|
||||
}
|
||||
}
|
||||
r
|
||||
}
|
||||
|
||||
impl<const Q: u64, const N: usize> PartialEq for PR<Q, N> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coeffs == other.coeffs
|
||||
}
|
||||
}
|
||||
impl<const Q: u64, const N: usize> ops::Add<PR<Q, N>> for PR<Q, N> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self {
|
||||
Self {
|
||||
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
|
||||
evals: None,
|
||||
}
|
||||
// Self {
|
||||
// coeffs: self
|
||||
// .coeffs
|
||||
// .iter()
|
||||
// .zip(rhs.coeffs)
|
||||
// .map(|(a, b)| *a + b)
|
||||
// .collect(),
|
||||
// evals: None,
|
||||
// }
|
||||
// Self(r.iter_mut().map(|e| e.r#mod()).collect()) // TODO mod should happen auto in +
|
||||
}
|
||||
}
|
||||
impl<const Q: u64, const N: usize> ops::Add<&PR<Q, N>> for &PR<Q, N> {
|
||||
type Output = PR<Q, N>;
|
||||
|
||||
fn add(self, rhs: &PR<Q, N>) -> Self::Output {
|
||||
PR {
|
||||
coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
|
||||
evals: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<const Q: u64, const N: usize> ops::Sub<PR<Q, N>> for PR<Q, N> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self {
|
||||
Self {
|
||||
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
|
||||
evals: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<const Q: u64, const N: usize> ops::Sub<&PR<Q, N>> for &PR<Q, N> {
|
||||
type Output = PR<Q, N>;
|
||||
|
||||
fn sub(self, rhs: &PR<Q, N>) -> Self::Output {
|
||||
PR {
|
||||
coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
|
||||
evals: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const Q: u64, const N: usize> ops::Neg for PR<Q, N> {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self {
|
||||
coeffs: array::from_fn(|i| -self.coeffs[i]),
|
||||
evals: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const Q: u64, const N: usize> fmt::Display for PR<Q, N> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.fmt(f)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl<const Q: u64, const N: usize> fmt::Debug for PR<Q, N> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.fmt(f)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn poly_ring() {
|
||||
// the test values used are generated with SageMath
|
||||
const Q: u64 = 7;
|
||||
const N: usize = 3;
|
||||
|
||||
// p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1)
|
||||
let p = PR::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
|
||||
|
||||
// try with coefficients bigger than Q
|
||||
let p = PR::<Q, N>::from_vec_u64(vec![0u64, 1, Q + 2, 3, 4, 5]);
|
||||
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
|
||||
|
||||
// try with other ring
|
||||
let p = PR::<7, 4>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)");
|
||||
|
||||
let p = PR::<Q, N>::from_vec_u64(vec![0u64, 0, 0, 0, 4, 5]);
|
||||
assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)");
|
||||
|
||||
let p = PR::<Q, N>::from_vec_u64(vec![5u64, 4, 5, 2, 1, 0]);
|
||||
assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
|
||||
|
||||
let a = PR::<Q, N>::from_vec_u64(vec![0u64, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
|
||||
|
||||
let b = PR::<Q, N>::from_vec_u64(vec![5u64, 4, 3, 2, 1, 0]);
|
||||
assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
|
||||
|
||||
// add
|
||||
assert_eq!((a.clone() + b.clone()).to_string(), "0 mod Z_7/(X^3+1)");
|
||||
assert_eq!((&a + &b).to_string(), "0 mod Z_7/(X^3+1)");
|
||||
// assert_eq!((a.0.clone() + b.0.clone()).to_string(), "[0, 0, 0]"); // TODO
|
||||
|
||||
// sub
|
||||
assert_eq!(
|
||||
(a.clone() - b.clone()).to_string(),
|
||||
"x^2 + x + 1 mod Z_7/(X^3+1)"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user