Browse Source

add arith::{complex, matrix} primitives

gfhe-over-ring-trait
arnaucube 1 month ago
parent
commit
267422a3b5
7 changed files with 575 additions and 1 deletions
  1. +8
    -0
      arith/Cargo.toml
  2. +329
    -0
      arith/src/complex.rs
  3. +5
    -1
      arith/src/lib.rs
  4. +187
    -0
      arith/src/matrix.rs
  5. +23
    -0
      arith/src/ring.rs
  6. +18
    -0
      arith/src/ringq.rs
  7. +5
    -0
      arith/src/zq.rs

+ 8
- 0
arith/Cargo.toml

@ -7,3 +7,11 @@ edition = "2024"
anyhow = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
# TMP: the next 4 imports are TMP, to solve systems of linear equations. Used
# for the CKKS encoding step, probably remvoed once in ckks the encoding is done
# as in 2018-1043 or 2018-1073.
num = "0.4.3"
num-complex = "0.4.6"
ndarray = "0.16.1"
ndarray-linalg = { version = "0.17.0", features = ["intel-mkl"] }

+ 329
- 0
arith/src/complex.rs

@ -0,0 +1,329 @@
//! Complex
use rand::Rng;
use std::ops::{Add, Div, Mul, Neg, Sub};
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct C<T> {
pub re: T,
pub im: T,
}
impl From<f64> for C<f64> {
fn from(v: f64) -> Self {
Self { re: v, im: 0_f64 }
}
}
impl<T> C<T>
where
T: Default + From<i32>,
{
pub fn new(re: T, im: T) -> Self {
Self { re, im }
}
pub fn rand(mut rng: impl Rng, max: u64) -> Self {
Self::new(
T::from(rng.gen_range(0..max) as i32),
T::from(rng.gen_range(0..max) as i32),
)
}
pub fn zero() -> C<T> {
Self {
re: T::from(0),
im: T::from(0),
}
}
pub fn one() -> C<T> {
Self {
re: T::from(1),
im: T::from(0),
}
}
pub fn i() -> C<T> {
Self {
re: T::from(0),
im: T::from(1),
}
}
}
impl C<f64> {
// cos & sin from Taylor series approximation, details at
// https://en.wikipedia.org/wiki/Sine_and_cosine#Series_and_polynomials
fn cos(x: f64) -> f64 {
let mut r = 1.0;
let mut term = 1.0;
let mut n = 1;
for _ in 0..10 {
term *= -(x * x) / ((2 * n - 1) * (2 * n)) as f64;
r += term;
n += 1;
}
r
}
fn sin(x: f64) -> f64 {
let mut r = x;
let mut term = x;
let mut n = 1;
for _ in 0..10 {
term *= -(x * x) / ((2 * n) * (2 * n + 1)) as f64;
r += term;
n += 1;
}
r
}
// e^(self))
pub fn exp(self) -> Self {
Self {
re: Self::cos(self.im), // TODO WIP review
im: Self::sin(self.im),
}
}
pub fn pow(self, k: u32) -> Self {
let mut k = k;
if k == 0 {
return Self::one();
}
let mut base = self.clone();
while k & 1 == 0 {
base = base.clone() * base;
k >>= 1;
}
if k == 1 {
return base;
}
let mut acc = base.clone();
while k > 1 {
k >>= 1;
base = base.clone() * base;
if k & 1 == 1 {
acc = acc * base.clone();
}
}
acc
}
pub fn modulus<const Q: u64>(self) -> Self {
let q: f64 = Q as f64;
let re = (self.re % q + q) % q;
let im = (self.im % q + q) % q;
Self { re, im }
}
pub fn modulus_centered<const Q: u64>(self) -> Self {
let re = modulus_centered_f64::<Q>(self.re);
let im = modulus_centered_f64::<Q>(self.im);
Self { re, im }
}
}
fn modulus_centered_f64<const Q: u64>(v: f64) -> f64 {
let q = Q as f64;
let mut res = v % q;
if res > q / 2.0 {
res = res - q;
}
res
}
impl<T: Default> Default for C<T> {
fn default() -> Self {
C {
re: T::default(),
im: T::default(),
}
}
}
impl<T> Add for C<T>
where
T: Add<Output = T> + Copy,
{
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
C {
re: self.re + rhs.re,
im: self.im + rhs.im,
}
}
}
impl<T> Sub for C<T>
where
T: Sub<Output = T> + Copy,
{
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
C {
re: self.re - rhs.re,
im: self.im - rhs.im,
}
}
}
impl<T> Mul for C<T>
where
T: Mul<Output = T> + Sub<Output = T> + Add<Output = T> + Copy,
{
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
C {
re: self.re * rhs.re - self.im * rhs.im,
im: self.re * rhs.im + self.im * rhs.re,
}
}
}
impl<T> Neg for C<T>
where
T: Neg<Output = T> + Copy,
{
type Output = Self;
fn neg(self) -> Self::Output {
C {
re: -self.re,
im: -self.im,
}
}
}
impl<T> Div for C<T>
where
T: Copy
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ Neg<Output = T>,
{
type Output = Self;
fn div(self, rhs: Self) -> Self {
// (a+ib)/(c+id) = (ac + bd)/(c^2 + d^2) + i* (bc -ad)/(c^2 + d^2)
let den = rhs.re * rhs.re + rhs.im * rhs.im;
C {
re: (self.re * rhs.re + self.im * rhs.im) / den,
im: (self.im * rhs.re - self.re * rhs.im) / den,
}
}
}
impl<T> C<T>
where
T: Neg<Output = T> + Copy,
{
pub fn conj(&self) -> Self {
C {
re: self.re,
im: -self.im,
}
}
}
impl<T: Add<Output = T> + Default + Copy> std::iter::Sum for C<T> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(C::default(), |acc, x| acc + x)
}
}
impl C<f64> {
pub fn abs(&self) -> f64 {
(self.re * self.re + self.im * self.im).sqrt()
}
}
// poly mul with complex coefficients
pub fn naive_poly_mul<const N: usize>(poly1: &Vec<C<f64>>, poly2: &Vec<C<f64>>) -> Vec<C<f64>> {
let mut result: Vec<C<f64>> = vec![C::<f64>::zero(); (N * 2) - 1];
for i in 0..N {
for j in 0..N {
result[i + j] = result[i + j] + poly1[i] * poly2[j];
}
}
// apply mod (X^N + 1))
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
// modulus_i128::<N>(&mut result);
// dbg!(&result);
// dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs());
// R(array::from_fn(|i| result[i] as i64))
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_i() {
assert_eq!(C::i(), C::new(0.0, 1.0));
}
#[test]
fn test_add() {
let a = C::new(2.0, 3.0);
let b = C::new(1.0, 4.0);
assert_eq!(a + b, C::new(3.0, 7.0));
}
#[test]
fn test_sub() {
let a = C::new(5.0, 7.0);
let b = C::new(2.0, 3.0);
assert_eq!(a - b, C::new(3.0, 4.0));
}
#[test]
fn test_mult() {
let a = C::new(1.0, 2.0);
let b = C::new(3.0, 4.0);
assert_eq!(a * b, C::new(-5.0, 10.0));
}
#[test]
fn test_div() {
let a: C<f64> = C::new(1.0, 2.0);
let b: C<f64> = C::new(3.0, 4.0);
let r = a / b;
let expected = C::new(0.44, 0.08);
let epsilon = 1e-2;
assert!((r.re - expected.re).abs() < epsilon);
assert!((r.im - expected.im).abs() < epsilon);
}
#[test]
fn test_conj() {
let a = C::new(3.0, -4.0);
assert_eq!(a.conj(), C::new(3.0, 4.0));
assert_eq!(a.conj().conj(), a);
}
#[test]
fn test_neg() {
let a = C::new(1.0, -2.0);
assert_eq!(-a, C::new(-1.0, 2.0));
}
#[test]
fn test_abs() {
let a = C::new(3.0, 4.0);
assert!((a.abs() - 5.0).abs() < 1e-10);
}
#[test]
fn test_exp() {
// let a = C::new(3.0, 4.0);
let pi = C::<f64>::from(std::f64::consts::PI);
let n = 4;
let a = ((C::<f64>::from(2f64) * pi * C::<f64>::i()) / C::<f64>::new(n as f64, 0f64)).exp();
dbg!(&a);
assert_eq!(a.exp(), a.exp());
}
}

+ 5
- 1
arith/src/lib.rs

@ -4,12 +4,16 @@
#![allow(clippy::upper_case_acronyms)]
#![allow(dead_code)] // TMP
mod naive_ntt; // TODO rm
pub mod complex;
pub mod matrix;
mod naive_ntt; // note: for dev only
pub mod ntt;
pub mod ring;
pub mod ringq;
pub mod zq;
pub use complex::C;
pub use matrix::Matrix;
pub use ntt::NTT;
pub use ring::R;
pub use ringq::Rq;

+ 187
- 0
arith/src/matrix.rs

@ -0,0 +1,187 @@
use anyhow::{anyhow, Result};
use std::ops::{Add, Mul};
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T>(pub Vec<Vec<T>>);
impl<T> Matrix<T>
where
T: Copy + Add<Output = T> + Mul<Output = T> + Default + std::fmt::Debug,
{
// TODO maybe rm this method, move it to tests only
pub fn new(rows: usize, cols: usize, value: T) -> Self {
Matrix(vec![vec![value; cols]; rows])
}
pub fn add(&self, other: &Matrix<T>) -> Result<Matrix<T>> {
if self.0.len() != other.0.len() || self.0[0].len() != other.0[0].len() {
return Err(anyhow!("dimensions don't match"));
}
let r = self
.0
.iter()
.zip(&other.0)
.map(|(row1, row2)| {
row1.iter()
.zip(row2)
.map(|(a, b)| *a + *b)
.collect::<Vec<T>>()
})
.collect::<Vec<Vec<T>>>();
Ok(Matrix(r))
}
pub fn mul(&self, other: &Matrix<T>) -> Result<Matrix<T>> {
let rows_a = self.0.len();
let cols_a = self.0[0].len();
let rows_b = other.0.len();
let cols_b = other.0[0].len();
if cols_a != rows_b {
return Err(anyhow!("self.n_cols != other.n_rows"));
}
let mut r = vec![vec![T::default(); cols_b]; rows_a];
for i in 0..rows_a {
for j in 0..cols_b {
for k in 0..cols_a {
r[i][j] = r[i][j] + self.0[i][k] * other.0[k][j];
}
}
}
Ok(Matrix(r))
}
pub fn mul_vec(&self, v: &Vec<T>) -> Result<Vec<T>> {
let rows = self.0.len();
let cols = self.0[0].len();
if cols != v.len() {
return Err(anyhow!(
"Number of columns in matrix does not match the length of the vector"
));
}
let mut r = vec![T::default(); rows];
for i in 0..rows {
for j in 0..cols {
r[i] = r[i] + self.0[i][j] * v[j];
}
}
Ok(r)
}
pub fn transpose(&self) -> Matrix<T> {
let rows = self.0.len();
let cols = self.0[0].len();
let mut r = vec![vec![T::default(); rows]; cols];
for i in 0..rows {
for j in 0..cols {
r[j][i] = self.0[i][j];
}
}
Matrix(r)
}
pub fn scalar_mul(&self, scalar: T) -> Matrix<T> {
let r = self
.0
.iter()
.map(|row| row.iter().map(|&val| val * scalar).collect::<Vec<T>>())
.collect::<Vec<Vec<T>>>();
Matrix(r)
}
}
// WIP. Currently uses ndarray, ndarray_linalg, num_complex to solve a system of
// linear equations A*x=b for x.
use crate::C;
impl Matrix<C<f64>> {
pub fn solve(&self, b: &Vec<C<f64>>) -> Result<Vec<C<f64>>> {
use ndarray::{Array1, Array2};
use ndarray_linalg::Solve;
use num_complex::Complex64;
let m: Array2<Complex64> = Array2::from_shape_vec(
(self.0.len(), self.0[0].len()),
self.0
.clone()
.into_iter()
.flatten()
.map(|e| Complex64::new(e.re, e.im))
.collect(),
)
.unwrap();
let v: Array1<Complex64> = Array1::from_shape_vec(
b.len(),
b.iter().map(|e| Complex64::new(e.re, e.im)).collect(),
)
.unwrap();
let r = m.solve(&v)?;
let r: Vec<C<f64>> = r.into_iter().map(|e| C::<f64>::new(e.re, e.im)).collect();
Ok(r)
}
}
impl Matrix<f64> {
// tmp (rm)
pub fn solve(&self, b: Vec<f64>) -> Result<Vec<f64>> {
use ndarray::{Array1, Array2};
use ndarray_linalg::Solve;
let m: Array2<f64> = Array2::from_shape_vec(
(self.0.len(), self.0[0].len()),
self.0.clone().into_iter().flatten().collect(),
)
.unwrap();
let v: Array1<f64> = Array1::from_shape_vec(b.len(), b).unwrap();
let r = m.solve(&v)?;
let r: Vec<f64> = r.into_iter().map(|e| e).collect();
Ok(r)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_add() -> Result<()> {
let a = Matrix::new(2, 3, 1);
let b = Matrix::new(2, 3, 2);
let expected = Matrix::new(2, 3, 3);
assert_eq!(a.add(&b).unwrap(), expected);
Ok(())
}
#[test]
fn test_mul() -> Result<()> {
let a = Matrix::new(2, 3, 1);
let b = Matrix::new(3, 2, 1);
let expected = Matrix::new(2, 2, 3); // 2x3 * 3x2 = 2x2 matrix (with all values 3)
assert_eq!(a.mul(&b).unwrap(), expected);
Ok(())
}
#[test]
fn test_transpose() -> Result<()> {
let a = Matrix::new(2, 3, 1);
let expected = Matrix::new(3, 2, 1);
assert_eq!(a.transpose(), expected);
Ok(())
}
#[test]
fn test_scalar_mul() -> Result<()> {
let a = Matrix::new(2, 3, 1);
let expected = Matrix::new(2, 3, 3);
assert_eq!(a.scalar_mul(3), expected);
Ok(())
}
}

+ 23
- 0
arith/src/ring.rs

@ -53,6 +53,29 @@ impl R {
.collect();
crate::Rq::<Q, N>::from_vec_f64(r)
}
pub fn infinity_norm(&self) -> u64 {
self.coeffs()
.iter()
// .map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
.map(|x| x.abs() as u64)
.fold(0, |a, b| a.max(b))
}
pub fn mod_centered_q<const Q: u64>(&self) -> R<N> {
let q = Q as i64;
let r = self
.0
.iter()
.map(|v| {
let mut res = v % q;
if res > q / 2 {
res = res - q;
}
res
})
.collect::<Vec<i64>>();
R::<N>::from_vec(r)
}
}
pub fn mul_div_round<const Q: u64, const N: usize>(

+ 18
- 0
arith/src/ringq.rs

@ -52,6 +52,9 @@ impl Rq {
pub fn coeffs(&self) -> [Zq<Q>; N] {
self.coeffs
}
pub fn compute_evals(&mut self) {
self.evals = Some(NTT::<Q, N>::ntt(self.coeffs));
}
pub fn to_r(self) -> crate::R<N> {
crate::R::<N>::from(self)
}
@ -131,6 +134,15 @@ impl Rq {
pub fn remodule<const P: u64>(&self) -> Rq<P, N> {
Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
}
/// perform the mod switch operation from Q to Q', where Q2=Q'
fn mod_switch<const Q2: u64>(&self) -> Rq<Q2, N> {
Rq::<Q2, N> {
coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::<Q2>()),
evals: None,
}
}
// applies mod(T) to all coefficients of self
pub fn coeffs_mod<const T: u64>(&self) -> Self {
Rq::<Q, N>::from_vec_u64(
@ -236,6 +248,9 @@ impl Rq {
.map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
.fold(0, |a, b| a.max(b))
}
pub fn mod_centered_q(&self) -> crate::ring::R<N> {
self.to_r().mod_centered_q::<Q>()
}
}
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
@ -399,6 +414,7 @@ impl ops::Neg for Rq {
}
}
// note: this assumes that Q is prime
fn mul_mut<const Q: u64, const N: usize>(lhs: &mut Rq<Q, N>, rhs: &mut Rq<Q, N>) -> Rq<Q, N> {
// reuse evaluations if already computed
if !lhs.evals.is_some() {
@ -414,6 +430,8 @@ fn mul_mut(lhs: &mut Rq, rhs: &mut Rq)
let c = NTT::<Q, { N }>::intt(c_ntt);
Rq::new(c, Some(c_ntt))
}
// note: this assumes that Q is prime
// TODO impl karatsuba for non-prime Q
fn mul<const Q: u64, const N: usize>(lhs: &Rq<Q, N>, rhs: &Rq<Q, N>) -> Rq<Q, N> {
// reuse evaluations if already computed
let lhs_evals = if lhs.evals.is_some() {

+ 5
- 0
arith/src/zq.rs

@ -117,6 +117,11 @@ impl Zq {
(g, y - (b / a) * x, x)
}
}
/// perform the mod switch operation from Q to Q', where Q2=Q'
pub fn mod_switch<const Q2: u64>(&self) -> Zq<Q2> {
Zq::<Q2>::from_u64(((self.0 as f64 * Q2 as f64) / Q as f64).round() as u64)
}
}
impl<const Q: u64> Zq<Q> {

Loading…
Cancel
Save