From e161b33402e26a971594819bfbb3d699adffb3d3 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 10 Jun 2024 16:50:14 +0530 Subject: [PATCH] add tests for shoup_fma --- src/backend.rs | 119 +++++++++++++++++++++++++++++++++++++------------ src/utils.rs | 24 +--------- 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index ed5519b..70f849c 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -107,21 +107,28 @@ pub trait VectorOps { b: &[Self::Element], c: &Self::Element, ); - - // fn modulus(&self) -> Self::Element; } pub trait ArithmeticOps { type Element; - fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; - fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; - fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn neg(&self, a: &Self::Element) -> Self::Element; +} + +pub trait ArithmeticLazyOps { + type Element; + fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; +} - // fn modulus(&self) -> Self::Element; +pub trait ShoupMatrixFMA +where + M::R: RowMut, +{ + /// Returns summation of row-wise product of matrix a and b. + fn shoup_matrix_fma(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M); } pub struct ModularOpsU64 { @@ -254,18 +261,10 @@ impl ArithmeticOps for ModularOpsU64 { self.add_mod_fast(*a, *b) } - fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.add_mod_fast_lazy(*a, *b) - } - fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { self.mul_mod_fast(*a, *b) } - fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.mul_mod_fast_lazy(*a, *b) - } - fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { self.sub_mod_fast(*a, *b) } @@ -279,6 +278,16 @@ impl ArithmeticOps for ModularOpsU64 { // } } +impl ArithmeticLazyOps for ModularOpsU64 { + type Element = u64; + fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.add_mod_fast_lazy(*a, *b) + } + fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.mul_mod_fast_lazy(*a, *b) + } +} + impl VectorOps for ModularOpsU64 { type Element = u64; @@ -344,12 +353,11 @@ impl VectorOps for ModularOpsU64 { // } } -impl ModularOpsU64 { - /// Returns \sum a[i]b[i] - pub fn shoup_fma>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M) - where - M::R: RowMut, - { +impl, T> ShoupMatrixFMA for ModularOpsU64 +where + M::R: RowMut, +{ + fn shoup_matrix_fma(&self, out: &mut ::R, a: &M, a_shoup: &M, b: &M) { assert!(a.dimension() == a_shoup.dimension()); assert!(a.dimension() == b.dimension()); @@ -393,6 +401,7 @@ impl ModularOpsU64 { }); } } + impl GetModulus for ModularOpsU64 where T: Modulus, @@ -430,18 +439,10 @@ where T::Element::wrapping_add(a, b) } - fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.add(a, b) - } - fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { T::Element::wrapping_mul(a, b) } - fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - self.mul(a, b) - } - fn neg(&self, a: &Self::Element) -> Self::Element { T::Element::wrapping_sub(&T::Element::zero(), a) } @@ -531,3 +532,65 @@ where &self.modulus } } + +#[cfg(test)] +mod tests { + use super::*; + use itertools::Itertools; + use rand::{thread_rng, Rng}; + use rand_distr::Uniform; + + #[test] + fn fma() { + let mut rng = thread_rng(); + let prime = 36028797017456641; + let ring_size = 1 << 3; + + let dist = Uniform::new(0, prime); + let d = 2; + let a0_matrix = (0..d) + .into_iter() + .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec()) + .collect_vec(); + // a0 in shoup representation + let a0_shoup_matrix = a0_matrix + .iter() + .map(|r| { + r.iter() + .map(|v| { + // $(v * 2^{\beta}) / p$ + ((*v as u128 * (1u128 << 64)) / prime as u128) as u64 + }) + .collect_vec() + }) + .collect_vec(); + let a1_matrix = (0..d) + .into_iter() + .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec()) + .collect_vec(); + + let modop = ModularOpsU64::new(prime); + + let mut out_shoup_fma_lazy = vec![0u64; ring_size]; + modop.shoup_matrix_fma( + &mut out_shoup_fma_lazy, + &a0_matrix, + &a0_shoup_matrix, + &a1_matrix, + ); + let out_shoup_fma = out_shoup_fma_lazy + .iter() + .map(|v| if *v >= prime { v - prime } else { *v }) + .collect_vec(); + + // expected + let mut out_expected = vec![0u64; ring_size]; + izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| { + izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| { + *o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime; + }); + }); + + assert_eq!(out_expected, out_shoup_fma); + } +} diff --git a/src/utils.rs b/src/utils.rs index 3118fdc..6fd9c08 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -234,26 +234,4 @@ where } #[cfg(test)] -mod tests { - use rand::{thread_rng, Rng}; - - use super::ShoupMul; - - #[test] - fn gg() { - let mut rng = thread_rng(); - let p = 36028797018820609; - - let a = rng.gen_range(0..p); - let b = rng.gen_range(0..p); - let a_shoup = ShoupMul::representation(a, p); - - // let c = ShoupMul::mul(b, a, a_shoup, p); - // assert!(c == ((a as u128 * b as u128) % p as u128) as u64); - - let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64; - quotient -= 1; - let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient)); - assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64); - } -} +mod tests {}