Browse Source

add tests for shoup_fma

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
e161b33402
2 changed files with 92 additions and 51 deletions
  1. +91
    -28
      src/backend.rs
  2. +1
    -23
      src/utils.rs

+ 91
- 28
src/backend.rs

@ -107,21 +107,28 @@ pub trait VectorOps {
b: &[Self::Element], b: &[Self::Element],
c: &Self::Element, c: &Self::Element,
); );
// fn modulus(&self) -> Self::Element;
} }
pub trait ArithmeticOps { pub trait ArithmeticOps {
type Element; type Element;
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn neg(&self, a: &Self::Element) -> Self::Element; fn neg(&self, a: &Self::Element) -> Self::Element;
}
pub trait ArithmeticLazyOps {
type Element;
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
}
// fn modulus(&self) -> Self::Element;
pub trait ShoupMatrixFMA<M: Matrix>
where
M::R: RowMut,
{
/// Returns summation of row-wise product of matrix a and b.
fn shoup_matrix_fma(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M);
} }
pub struct ModularOpsU64<T> { pub struct ModularOpsU64<T> {
@ -254,18 +261,10 @@ impl ArithmeticOps for ModularOpsU64 {
self.add_mod_fast(*a, *b) self.add_mod_fast(*a, *b)
} }
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add_mod_fast_lazy(*a, *b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast(*a, *b) self.mul_mod_fast(*a, *b)
} }
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast_lazy(*a, *b)
}
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.sub_mod_fast(*a, *b) self.sub_mod_fast(*a, *b)
} }
@ -279,6 +278,16 @@ impl ArithmeticOps for ModularOpsU64 {
// } // }
} }
impl<T> ArithmeticLazyOps for ModularOpsU64<T> {
type Element = u64;
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add_mod_fast_lazy(*a, *b)
}
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast_lazy(*a, *b)
}
}
impl<T> VectorOps for ModularOpsU64<T> { impl<T> VectorOps for ModularOpsU64<T> {
type Element = u64; type Element = u64;
@ -344,12 +353,11 @@ impl VectorOps for ModularOpsU64 {
// } // }
} }
impl<T> ModularOpsU64<T> {
/// Returns \sum a[i]b[i]
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
where
M::R: RowMut,
{
impl<M: Matrix<MatElement = u64>, T> ShoupMatrixFMA<M> for ModularOpsU64<T>
where
M::R: RowMut,
{
fn shoup_matrix_fma(&self, out: &mut <M as Matrix>::R, a: &M, a_shoup: &M, b: &M) {
assert!(a.dimension() == a_shoup.dimension()); assert!(a.dimension() == a_shoup.dimension());
assert!(a.dimension() == b.dimension()); assert!(a.dimension() == b.dimension());
@ -393,6 +401,7 @@ impl ModularOpsU64 {
}); });
} }
} }
impl<T> GetModulus for ModularOpsU64<T> impl<T> GetModulus for ModularOpsU64<T>
where where
T: Modulus, T: Modulus,
@ -430,18 +439,10 @@ where
T::Element::wrapping_add(a, b) T::Element::wrapping_add(a, b)
} }
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add(a, b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
T::Element::wrapping_mul(a, b) T::Element::wrapping_mul(a, b)
} }
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul(a, b)
}
fn neg(&self, a: &Self::Element) -> Self::Element { fn neg(&self, a: &Self::Element) -> Self::Element {
T::Element::wrapping_sub(&T::Element::zero(), a) T::Element::wrapping_sub(&T::Element::zero(), a)
} }
@ -531,3 +532,65 @@ where
&self.modulus &self.modulus
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use itertools::Itertools;
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
#[test]
fn fma() {
let mut rng = thread_rng();
let prime = 36028797017456641;
let ring_size = 1 << 3;
let dist = Uniform::new(0, prime);
let d = 2;
let a0_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
// a0 in shoup representation
let a0_shoup_matrix = a0_matrix
.iter()
.map(|r| {
r.iter()
.map(|v| {
// $(v * 2^{\beta}) / p$
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
})
.collect_vec()
})
.collect_vec();
let a1_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
let modop = ModularOpsU64::new(prime);
let mut out_shoup_fma_lazy = vec![0u64; ring_size];
modop.shoup_matrix_fma(
&mut out_shoup_fma_lazy,
&a0_matrix,
&a0_shoup_matrix,
&a1_matrix,
);
let out_shoup_fma = out_shoup_fma_lazy
.iter()
.map(|v| if *v >= prime { v - prime } else { *v })
.collect_vec();
// expected
let mut out_expected = vec![0u64; ring_size];
izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| {
izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| {
*o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime;
});
});
assert_eq!(out_expected, out_shoup_fma);
}
}

+ 1
- 23
src/utils.rs

@ -234,26 +234,4 @@ where
} }
#[cfg(test)] #[cfg(test)]
mod tests {
use rand::{thread_rng, Rng};
use super::ShoupMul;
#[test]
fn gg() {
let mut rng = thread_rng();
let p = 36028797018820609;
let a = rng.gen_range(0..p);
let b = rng.gen_range(0..p);
let a_shoup = ShoupMul::representation(a, p);
// let c = ShoupMul::mul(b, a, a_shoup, p);
// assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
quotient -= 1;
let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient));
assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64);
}
}
mod tests {}

Loading…
Cancel
Save