group ring params under a single struct

This commit is contained in:
2025-08-12 12:09:51 +00:00
parent 0c7e078aeb
commit 9e90f094a9
10 changed files with 295 additions and 261 deletions

View File

@@ -25,7 +25,7 @@ pub use matrix::Matrix;
// pub use torus::T64;
pub use zq::Zq;
pub use ring::Ring;
pub use ring::{Ring, RingParam};
pub use ring_n::R;
pub use ring_nq::Rq;
// pub use ring_torus::Tn;

View File

@@ -6,7 +6,11 @@
//! generics; but once using real-world parameters, the stack could not handle
//! it, so moved to use Vec instead of fixed-sized arrays, and adapted the NTT
//! implementation to that too.
use crate::{ring::Ring, ring_nq::Rq, zq::Zq};
use crate::{
ring::{Ring, RingParam},
ring_nq::Rq,
zq::Zq,
};
use std::collections::HashMap;
@@ -50,7 +54,7 @@ impl NTT {
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.1 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn ntt(a: &Rq) -> Rq {
let (q, n) = (a.q, a.n);
let (q, n) = (a.param.q, a.param.n);
let (roots_of_unity, _, _) = roots(q, n);
let mut t = n / 2;
@@ -73,8 +77,7 @@ impl NTT {
}
// Rq::from_vec((a.q, n), r)
Rq {
q,
n,
param: RingParam { q, n },
coeffs: r,
evals: None,
}
@@ -84,7 +87,7 @@ impl NTT {
/// https://eprint.iacr.org/2017/727.pdf, also some notes at section 3.2 of
/// https://github.com/arnaucube/math/blob/master/notes_ntt.pdf
pub fn intt(a: &Rq) -> Rq {
let (q, n) = (a.q, a.n);
let (q, n) = (a.param.q, a.param.n);
let (_, roots_of_unity_inv, n_inv) = roots(q, n);
let mut t = 1;
@@ -110,8 +113,7 @@ impl NTT {
}
// Rq::from_vec((a.q, n), r)
Rq {
q,
n,
param: RingParam { q, n },
coeffs: r,
evals: None,
}
@@ -202,9 +204,10 @@ mod tests {
fn test_ntt() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 4;
let param = RingParam { q, n };
let a: Vec<u64> = vec![1u64, 2, 3, 4];
let a: Rq = Rq::from_vec_u64(q, n, a);
let a: Rq = Rq::from_vec_u64(&param, a);
let a_ntt = NTT::ntt(&a);
@@ -224,13 +227,14 @@ mod tests {
fn test_ntt_loop() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 512;
let param = RingParam { q, n };
use rand::distributions::Uniform;
let mut rng = rand::thread_rng();
let dist = Uniform::new(0_f64, q as f64);
for _ in 0..10_000 {
let a: Rq = Rq::rand(&mut rng, dist, (q, n));
for _ in 0..1000 {
let a: Rq = Rq::rand(&mut rng, dist, &param);
let a_ntt = NTT::ntt(&a);
let a_intt = NTT::intt(&a_ntt);
assert_eq!(a, a_intt);

View File

@@ -3,6 +3,12 @@ use std::fmt::Debug;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RingParam {
pub q: u64, // TODO think if really needed or it's fine with coeffs[0].q
pub n: usize,
}
/// Represents a ring element. Currently implemented by ring_nq.rs#Rq and
/// ring_torus.rs#Tn. Is not a 'pure algebraic ring', but more a custom trait
/// definition which includes methods like `mod_switch`.
@@ -27,17 +33,18 @@ pub trait Ring:
{
/// C defines the coefficient type
type C: Debug + Clone;
type Params: Debug+Clone+Copy;
// type Param: Debug+Clone+Copy;
// const Q: u64;
// const N: usize;
fn param(&self) -> RingParam;
fn coeffs(&self) -> Vec<Self::C>;
fn zero(params: Self::Params) -> Self;
fn zero(param: &RingParam) -> Self;
// note/wip/warning: dist (0,q) with f64, will output more '0=q' elements than other values
fn rand(rng: impl Rng, dist: impl Distribution<f64>, params: Self::Params) -> Self;
fn rand(rng: impl Rng, dist: impl Distribution<f64>, param: &RingParam) -> Self;
fn from_vec(params: Self::Params, coeffs: Vec<Self::C>) -> Self;
fn from_vec(param: &RingParam, coeffs: Vec<Self::C>) -> Self;
fn decompose(&self, beta: u32, l: u32) -> Vec<Self>;

View File

@@ -23,7 +23,7 @@ pub struct R {
// impl<const N: usize> Ring for R<N> {
impl R {
// type C = i64;
// type Params = usize; // n
// type Param = usize; // n
// const Q: u64 = i64::MAX as u64; // WIP
// const N: usize = N;
@@ -87,7 +87,10 @@ impl R {
impl From<crate::ring_nq::Rq> for R {
fn from(rq: crate::ring_nq::Rq) -> Self {
Self::from_vec_u64(rq.n, rq.coeffs().to_vec().iter().map(|e| e.v).collect())
Self::from_vec_u64(
rq.param.n,
rq.coeffs().to_vec().iter().map(|e| e.v).collect(),
)
}
}
@@ -150,7 +153,7 @@ pub fn mul_div_round(q: u64, n: usize, v: Vec<i64>, num: u64, den: u64) -> crate
.map(|e| ((num as f64 * *e as f64) / den as f64).round())
.collect();
// dbg!(&r);
crate::Rq::from_vec_f64(q, n, r)
crate::Rq::from_vec_f64(&crate::ring::RingParam { q, n }, r)
}
// TODO rename to make it clear that is not mod q, but mod X^N+1

View File

@@ -13,7 +13,7 @@ use std::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign};
use crate::ntt::NTT;
use crate::zq::{modulus_u64, Zq};
use crate::Ring;
use crate::{Ring, RingParam};
// NOTE: currently using fixed-size arrays, but pending to see if with
// real-world parameters the stack can keep up; if not will move everything to
@@ -22,8 +22,7 @@ use crate::Ring;
/// The implementation assumes that q is prime.
#[derive(Clone)]
pub struct Rq {
pub q: u64, // TODO think if really needed or it's fine with coeffs[0].q
pub n: usize,
pub param: RingParam,
pub(crate) coeffs: Vec<Zq>,
@@ -34,42 +33,41 @@ pub struct Rq {
impl Ring for Rq {
type C = Zq;
type Params = (u64, usize);
// type Param = (u64, usize);
// type Param = Param;
fn param(&self) -> RingParam {
self.param
}
fn coeffs(&self) -> Vec<Self::C> {
self.coeffs.to_vec()
}
// fn zero(q: u64, n: usize) -> Self {
fn zero(param: (u64, usize)) -> Self {
let (q, n) = param;
// fn zero(param: (u64, usize)) -> Self {
fn zero(param: &RingParam) -> Self {
Self {
q,
n,
coeffs: vec![Zq::zero(q); n],
param: param.clone(),
coeffs: vec![Zq::zero(param.q); param.n],
evals: None,
}
}
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, params: Self::Params) -> Self {
fn rand(mut rng: impl Rng, dist: impl Distribution<f64>, param: &RingParam) -> Self {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Self::C::rand(&mut rng, &dist));
let (q, n) = params;
Self {
q,
n,
coeffs: std::iter::repeat_with(|| Self::C::rand(&mut rng, &dist, q))
.take(n)
param: param.clone(),
coeffs: std::iter::repeat_with(|| Self::C::rand(&mut rng, &dist, param.q))
.take(param.n)
.collect(),
evals: None,
}
}
fn from_vec(params: Self::Params, coeffs: Vec<Zq>) -> Self {
let (q, n) = params;
fn from_vec(param: &RingParam, coeffs: Vec<Zq>) -> Self {
let mut p = coeffs;
modulus(q, n, &mut p);
modulus(param.q, param.n, &mut p);
Self {
q,
n,
param: param.clone(),
coeffs: p,
evals: None,
}
@@ -85,7 +83,7 @@ impl Ring for Rq {
.collect();
// convert it to Rq<Q,N>
r.iter()
.map(|a_i| Self::from_vec((self.q, self.n), a_i.clone()))
.map(|a_i| Self::from_vec(&self.param, a_i.clone()))
.collect()
}
@@ -93,16 +91,26 @@ impl Ring for Rq {
// if Q<P, it just 'renames' the modulus parameter to P
// if Q>=P, it crops to mod P
fn remodule(&self, p: u64) -> Rq {
Rq::from_vec_u64(p, self.n, self.coeffs().iter().map(|m_i| m_i.v).collect())
let param = RingParam {
q: p,
n: self.param.n,
};
// Rq::from_vec_u64(p, self.n, self.coeffs().iter().map(|m_i| m_i.v).collect())
Rq::from_vec_u64(&param, self.coeffs().iter().map(|m_i| m_i.v).collect())
}
/// perform the mod switch operation from Q to Q', where Q2=Q'
// fn mod_switch<const P: u64, const M: usize>(&self) -> impl Ring {
fn mod_switch(&self, p: u64) -> Rq {
let param = RingParam {
q: p,
n: self.param.n,
};
// assert_eq!(N, M); // sanity check
Rq {
q: p,
n: self.n,
param,
// q: p,
// n: self.n,
// coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::<P>()),
coeffs: self.coeffs.iter().map(|c_i| c_i.mod_switch(p)).collect(),
evals: None,
@@ -118,7 +126,7 @@ impl Ring for Rq {
.iter()
.map(|e| ((num as f64 * e.v as f64) / den as f64).round())
.collect();
Rq::from_vec_f64(self.q, self.n, r)
Rq::from_vec_f64(&self.param, r)
}
}
@@ -128,7 +136,8 @@ impl From<(u64, crate::ring_n::R)> for Rq {
assert_eq!(r.n, r.coeffs.len());
Self::from_vec(
(q, r.n),
&RingParam { q, n: r.n },
// (q, r.n),
r.coeffs()
.iter()
.map(|e| Zq::from_f64(q, *e as f64))
@@ -170,22 +179,24 @@ impl Rq {
// }
// }
// this method is mostly for tests
pub fn from_vec_u64(q: u64, n: usize, coeffs: Vec<u64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_u64(q, *c)).collect();
Self::from_vec((q, n), coeffs_mod_q)
pub fn from_vec_u64(param: &RingParam, coeffs: Vec<u64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_u64(param.q, *c)).collect();
Self::from_vec(param, coeffs_mod_q)
}
pub fn from_vec_f64(q: u64, n: usize, coeffs: Vec<f64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_f64(q, *c)).collect();
Self::from_vec((q, n), coeffs_mod_q)
pub fn from_vec_f64(param: &RingParam, coeffs: Vec<f64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_f64(param.q, *c)).collect();
Self::from_vec(param, coeffs_mod_q)
}
pub fn from_vec_i64(q: u64, n: usize, coeffs: Vec<i64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs.iter().map(|c| Zq::from_f64(q, *c as f64)).collect();
Self::from_vec((q, n), coeffs_mod_q)
pub fn from_vec_i64(param: &RingParam, coeffs: Vec<i64>) -> Self {
let coeffs_mod_q: Vec<Zq> = coeffs
.iter()
.map(|c| Zq::from_f64(param.q, *c as f64))
.collect();
Self::from_vec(param, coeffs_mod_q)
}
pub fn new(q: u64, n: usize, coeffs: Vec<Zq>, evals: Option<Vec<Zq>>) -> Self {
pub fn new(param: &RingParam, coeffs: Vec<Zq>, evals: Option<Vec<Zq>>) -> Self {
Self {
q,
n,
param: *param,
coeffs,
evals,
}
@@ -194,15 +205,17 @@ impl Rq {
pub fn rand_abs(
mut rng: impl Rng,
dist: impl Distribution<f64>,
q: u64,
n: usize,
param: &RingParam,
// q: u64,
// n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self {
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(q, dist.sample(&mut rng).abs()))
.take(n)
param: *param,
// q,
// n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng).abs()))
.take(param.n)
.collect(),
evals: None,
})
@@ -210,15 +223,17 @@ impl Rq {
pub fn rand_f64_abs(
mut rng: impl Rng,
dist: impl Distribution<f64>,
q: u64,
n: usize,
param: &RingParam,
// q: u64,
// n: usize,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng).abs()));
Ok(Self {
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(q, dist.sample(&mut rng).abs()))
.take(n)
param: *param,
// q,
// n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng).abs()))
.take(param.n)
.collect(),
evals: None,
})
@@ -226,15 +241,13 @@ impl Rq {
pub fn rand_f64(
mut rng: impl Rng,
dist: impl Distribution<f64>,
q: u64,
n: usize,
param: &RingParam,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_f64(dist.sample(&mut rng)));
Ok(Self {
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_f64(q, dist.sample(&mut rng)))
.take(n)
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_f64(param.q, dist.sample(&mut rng)))
.take(param.n)
.collect(),
evals: None,
})
@@ -242,15 +255,13 @@ impl Rq {
pub fn rand_u64(
mut rng: impl Rng,
dist: impl Distribution<u64>,
q: u64,
n: usize,
param: &RingParam,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_u64(dist.sample(&mut rng)));
Ok(Self {
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_u64(q, dist.sample(&mut rng)))
.take(n)
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_u64(param.q, dist.sample(&mut rng)))
.take(param.n)
.collect(),
evals: None,
})
@@ -259,15 +270,13 @@ impl Rq {
pub fn rand_bin(
mut rng: impl Rng,
dist: impl Distribution<bool>,
q: u64,
n: usize,
param: &RingParam,
) -> Result<Self> {
// let coeffs: [Zq<Q>; N] = array::from_fn(|_| Zq::from_bool(dist.sample(&mut rng)));
Ok(Rq {
q,
n,
coeffs: std::iter::repeat_with(|| Zq::from_bool(q, dist.sample(&mut rng)))
.take(n)
param: *param,
coeffs: std::iter::repeat_with(|| Zq::from_bool(param.q, dist.sample(&mut rng)))
.take(param.n)
.collect(),
evals: None,
})
@@ -280,10 +289,9 @@ impl Rq {
// }
// applies mod(T) to all coefficients of self
pub fn coeffs_mod<const T: u64>(&self, q: u64, n: usize, t: u64) -> Self {
pub fn coeffs_mod(&self, param: &RingParam, t: u64) -> Self {
Rq::from_vec_u64(
q,
n,
param,
self.coeffs()
.iter()
.map(|m_i| modulus_u64(t, m_i.v))
@@ -297,18 +305,16 @@ impl Rq {
}
pub fn mul_by_zq(&self, s: &Zq) -> Self {
Self {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| self.coeffs[i] * *s),
coeffs: self.coeffs.iter().map(|c_i| *c_i * *s).collect(),
evals: None,
}
}
pub fn mul_by_u64(&self, s: u64) -> Self {
let s = Zq::from_u64(self.q, s);
let s = Zq::from_u64(self.param.q, s);
Self {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| self.coeffs[i] * s),
coeffs: self.coeffs.iter().map(|&e| e * s).collect(),
evals: None,
@@ -316,13 +322,12 @@ impl Rq {
}
pub fn mul_by_f64(&self, s: f64) -> Self {
Self {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| Zq::from_f64(self.coeffs[i].0 as f64 * s)),
coeffs: self
.coeffs
.iter()
.map(|c_i| Zq::from_f64(self.q, c_i.v as f64 * s))
.map(|c_i| Zq::from_f64(self.param.q, c_i.v as f64 * s))
.collect(),
evals: None,
}
@@ -339,7 +344,7 @@ impl Rq {
.iter()
.map(|e| (e.v as f64 / s as f64).round())
.collect();
Rq::from_vec_f64(self.q, self.n, r)
Rq::from_vec_f64(&self.param, r)
}
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -374,9 +379,9 @@ impl Rq {
}
f.write_str(" mod Z_")?;
f.write_str(self.q.to_string().as_str())?;
f.write_str(self.param.q.to_string().as_str())?;
f.write_str("/(X^")?;
f.write_str(self.n.to_string().as_str())?;
f.write_str(self.param.n.to_string().as_str())?;
f.write_str("+1)")?;
Ok(())
}
@@ -385,8 +390,8 @@ impl Rq {
self.coeffs()
.iter()
.map(|x| {
if x.v > (self.q / 2) {
self.q - x.v
if x.v > (self.param.q / 2) {
self.param.q - x.v
} else {
x.v
}
@@ -394,7 +399,7 @@ impl Rq {
.fold(0, |a, b| a.max(b))
}
pub fn mod_centered_q(&self) -> crate::ring_n::R {
self.clone().to_r().mod_centered_q(self.q)
self.clone().to_r().mod_centered_q(self.param.q)
}
}
pub fn matrix_vec_product(m: &Vec<Vec<Zq>>, v: &Vec<Zq>) -> Result<Vec<Zq>> {
@@ -437,18 +442,16 @@ pub fn transpose(m: &[Vec<Zq>]) -> Vec<Vec<Zq>> {
impl PartialEq for Rq {
fn eq(&self, other: &Self) -> bool {
self.coeffs == other.coeffs && self.q == other.q && self.n == other.n
self.coeffs == other.coeffs && self.param == other.param
}
}
impl Add<Rq> for Rq {
type Output = Self;
fn add(self, rhs: Self) -> Self {
assert_eq!(self.q, rhs.q);
assert_eq!(self.n, rhs.n);
assert_eq!(self.param, rhs.param);
Self {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l + r)
@@ -471,11 +474,9 @@ impl Add<&Rq> for &Rq {
type Output = Rq;
fn add(self, rhs: &Rq) -> Self::Output {
assert_eq!(self.q, rhs.q);
assert_eq!(self.n, rhs.n);
assert_eq!(self.param, rhs.param);
Rq {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l + r)
@@ -486,9 +487,8 @@ impl Add<&Rq> for &Rq {
}
impl AddAssign for Rq {
fn add_assign(&mut self, rhs: Self) {
debug_assert_eq!(self.q, rhs.q);
debug_assert_eq!(self.n, rhs.n);
for i in 0..self.n {
debug_assert_eq!(self.param, rhs.param);
for i in 0..self.param.n {
self.coeffs[i] += rhs.coeffs[i];
}
}
@@ -513,11 +513,9 @@ impl Sub<Rq> for Rq {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
assert_eq!(self.q, rhs.q);
assert_eq!(self.n, rhs.n);
assert_eq!(self.param, rhs.param);
Self {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs, rhs.coeffs)
.map(|(l, r)| l - r)
@@ -530,11 +528,9 @@ impl Sub<&Rq> for &Rq {
type Output = Rq;
fn sub(self, rhs: &Rq) -> Self::Output {
assert_eq!(self.q, rhs.q); // TODO replace all those with debug_assert_eq
debug_assert_eq!(self.n, rhs.n);
debug_assert_eq!(self.param, rhs.param);
Rq {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| self.coeffs[i] - rhs.coeffs[i]),
coeffs: zip_eq(self.coeffs.clone(), rhs.coeffs.clone())
.map(|(l, r)| l - r)
@@ -545,9 +541,8 @@ impl Sub<&Rq> for &Rq {
}
impl SubAssign for Rq {
fn sub_assign(&mut self, rhs: Self) {
debug_assert_eq!(self.q, rhs.q);
debug_assert_eq!(self.n, rhs.n);
for i in 0..self.n {
debug_assert_eq!(self.param, rhs.param);
for i in 0..self.param.n {
self.coeffs[i] -= rhs.coeffs[i];
}
}
@@ -619,8 +614,7 @@ impl Neg for Rq {
fn neg(self) -> Self::Output {
Self {
q: self.q,
n: self.n,
param: self.param,
// coeffs: array::from_fn(|i| -self.coeffs[i]),
// coeffs: self.coeffs.iter().map(|c_i| -c_i).collect(),
coeffs: self.coeffs.iter().map(|c_i| -*c_i).collect(),
@@ -631,9 +625,8 @@ impl Neg for Rq {
// note: this assumes that Q is prime
fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq {
assert_eq!(lhs.q, rhs.q);
assert_eq!(lhs.n, rhs.n);
let (q, n) = (lhs.q, lhs.n);
assert_eq!(lhs.param, rhs.param);
// let (q, n) = (lhs.q, lhs.n);
// reuse evaluations if already computed
if !lhs.evals.is_some() {
@@ -647,18 +640,16 @@ fn mul_mut(lhs: &mut Rq, rhs: &mut Rq) -> Rq {
// let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c_ntt: Rq = Rq::from_vec(
(q, n),
&lhs.param,
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
);
let c = NTT::intt(&c_ntt);
Rq::new(q, n, c.coeffs, Some(c_ntt.coeffs))
Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs))
}
// note: this assumes that Q is prime
// TODO impl karatsuba for non-prime Q
fn mul(lhs: &Rq, rhs: &Rq) -> Rq {
assert_eq!(lhs.q, rhs.q);
assert_eq!(lhs.n, rhs.n);
let (q, n) = (lhs.q, lhs.n);
assert_eq!(lhs.param, rhs.param);
// reuse evaluations if already computed
let lhs_evals: Vec<Zq> = if lhs.evals.is_some() {
@@ -674,11 +665,11 @@ fn mul(lhs: &Rq, rhs: &Rq) -> Rq {
// let c_ntt: [Zq<Q>; N] = array::from_fn(|i| lhs_evals[i] * rhs_evals[i]);
let c_ntt: Rq = Rq::from_vec(
(q, n),
&lhs.param,
zip_eq(lhs_evals, rhs_evals).map(|(l, r)| l * r).collect(),
);
let c = NTT::intt(&c_ntt);
Rq::new(q, n, c.coeffs, Some(c_ntt.coeffs))
Rq::new(&lhs.param, c.coeffs, Some(c_ntt.coeffs))
}
impl fmt::Display for Rq {
@@ -701,31 +692,30 @@ mod tests {
#[test]
fn test_polynomial_ring() {
// the test values used are generated with SageMath
let q: u64 = 7;
let n: usize = 3;
let param = RingParam { q: 7, n: 3 };
// p = 1x + 2x^2 + 3x^3 + 4 x^4 + 5 x^5 in R=Z_q[X]/(X^n +1)
let p = Rq::from_vec_u64(q, n, vec![0u64, 1, 2, 3, 4, 5]);
let p = Rq::from_vec_u64(&param, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with coefficients bigger than Q
let p = Rq::from_vec_u64(q, n, vec![0u64, 1, q + 2, 3, 4, 5]);
let p = Rq::from_vec_u64(&param, vec![0u64, 1, param.q + 2, 3, 4, 5]);
assert_eq!(p.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
// try with other ring
let p = Rq::from_vec_u64(7, 4, vec![0u64, 1, 2, 3, 4, 5]);
let p = Rq::from_vec_u64(&RingParam { q: 7, n: 4 }, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(p.to_string(), "3*x^3 + 2*x^2 + 3*x + 3 mod Z_7/(X^4+1)");
let p = Rq::from_vec_u64(q, n, vec![0u64, 0, 0, 0, 4, 5]);
let p = Rq::from_vec_u64(&param, vec![0u64, 0, 0, 0, 4, 5]);
assert_eq!(p.to_string(), "2*x^2 + 3*x mod Z_7/(X^3+1)");
let p = Rq::from_vec_u64(q, n, vec![5u64, 4, 5, 2, 1, 0]);
let p = Rq::from_vec_u64(&param, vec![5u64, 4, 5, 2, 1, 0]);
assert_eq!(p.to_string(), "5*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
let a = Rq::from_vec_u64(q, n, vec![0u64, 1, 2, 3, 4, 5]);
let a = Rq::from_vec_u64(&param, vec![0u64, 1, 2, 3, 4, 5]);
assert_eq!(a.to_string(), "4*x^2 + 4*x + 4 mod Z_7/(X^3+1)");
let b = Rq::from_vec_u64(q, n, vec![5u64, 4, 3, 2, 1, 0]);
let b = Rq::from_vec_u64(&param, vec![5u64, 4, 3, 2, 1, 0]);
assert_eq!(b.to_string(), "3*x^2 + 3*x + 3 mod Z_7/(X^3+1)");
// add
@@ -742,39 +732,40 @@ mod tests {
#[test]
fn test_mul() -> Result<()> {
let q: u64 = 2u64.pow(16) + 1;
let n: usize = 4;
let param = RingParam {
q: 2u64.pow(16) + 1,
n: 4,
};
let a: Vec<u64> = vec![1u64, 2, 3, 4];
let b: Vec<u64> = vec![1u64, 2, 3, 4];
let c: Vec<u64> = vec![65513, 65517, 65531, 20];
test_mul_opt(q, n, a, b, c)?;
test_mul_opt(&param, a, b, c)?;
let a: Vec<u64> = vec![0u64, 0, 0, 2];
let b: Vec<u64> = vec![0u64, 0, 0, 2];
let c: Vec<u64> = vec![0u64, 0, 65533, 0];
test_mul_opt(q, n, a, b, c)?;
test_mul_opt(&param, a, b, c)?;
// TODO more testvectors
Ok(())
}
fn test_mul_opt(
q: u64,
n: usize,
param: &RingParam,
a: Vec<u64>,
b: Vec<u64>,
expected_c: Vec<u64>,
) -> Result<()> {
assert_eq!(a.len(), n);
assert_eq!(b.len(), n);
assert_eq!(a.len(), param.n);
assert_eq!(b.len(), param.n);
// let a: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(a[i]));
let mut a = Rq::from_vec_u64(q, n, a);
let mut a = Rq::from_vec_u64(&param, a);
// let b: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(b[i]));
let mut b = Rq::from_vec_u64(q, n, b);
let mut b = Rq::from_vec_u64(&param, b);
// let expected_c: [Zq<Q>; N] = array::from_fn(|i| Zq::from_u64(expected_c[i]));
let expected_c = Rq::from_vec_u64(q, n, expected_c);
let expected_c = Rq::from_vec_u64(&param, expected_c);
let c = mul_mut(&mut a, &mut b);
assert_eq!(c, expected_c);
@@ -783,26 +774,25 @@ mod tests {
#[test]
fn test_rq_decompose() -> Result<()> {
let q: u64 = 16;
let n: usize = 4;
let param = RingParam { q: 16, n: 4 };
let beta = 4;
let l = 2;
let a = Rq::from_vec_u64(q, n, vec![7u64, 14, 3, 6]);
let a = Rq::from_vec_u64(&param, vec![7u64, 14, 3, 6]);
let d = a.decompose(beta, l);
assert_eq!(
d[0].coeffs(),
vec![1u64, 3, 0, 1]
.iter()
.map(|e| Zq::from_u64(q, *e))
.map(|e| Zq::from_u64(param.q, *e))
.collect::<Vec<_>>()
);
assert_eq!(
d[1].coeffs(),
vec![3u64, 2, 3, 2]
.iter()
.map(|e| Zq::from_u64(q, *e))
.map(|e| Zq::from_u64(param.q, *e))
.collect::<Vec<_>>()
);
Ok(())

View File

@@ -25,11 +25,14 @@ pub struct Tn {
impl Ring for Tn {
type C = T64;
type Params = usize; // n
type Param = usize; // n
// const Q: u64 = u64::MAX; // WIP
// const N: usize = N;
fn param(&self) -> Self::Param {
self.n
}
fn coeffs(&self) -> Vec<T64> {
self.coeffs.to_vec()
}

View File

@@ -16,11 +16,14 @@ pub struct T64(pub u64);
// `Tn<1>`.
impl Ring for T64 {
type C = T64;
type Params = ();
// type Param = ();
// const Q: u64 = u64::MAX; // WIP
// const N: usize = 1;
fn param(&self) -> Self::Param {
()
}
fn coeffs(&self) -> Vec<T64> {
vec![self.clone()]
}

View File

@@ -11,7 +11,7 @@ use std::{
ops::{Add, Mul, Neg, Sub},
};
use crate::Ring;
use crate::{Ring, RingParam};
/// Tuple of K Ring (Rq) elements. We use Vec<R> to allocate it in the heap,
/// since if using a fixed-size array it would overflow the stack.
@@ -28,7 +28,7 @@ impl<R: Ring> TR<R> {
assert_eq!(r.len(), k);
Self { k, r }
}
pub fn zero(k: usize, r_params: R::Params) -> Self {
pub fn zero(k: usize, r_params: &RingParam) -> Self {
Self {
k,
r: (0..k).into_iter().map(|_| R::zero(r_params)).collect(),
@@ -38,7 +38,7 @@ impl<R: Ring> TR<R> {
mut rng: impl Rng,
dist: impl Distribution<f64>,
k: usize,
r_params: R::Params,
r_params: &RingParam,
) -> Self {
Self {
k,