mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-11 16:41:29 +01:00
add tests for shoup_fma
This commit is contained in:
117
src/backend.rs
117
src/backend.rs
@@ -107,21 +107,28 @@ pub trait VectorOps {
|
||||
b: &[Self::Element],
|
||||
c: &Self::Element,
|
||||
);
|
||||
|
||||
// fn modulus(&self) -> Self::Element;
|
||||
}
|
||||
|
||||
pub trait ArithmeticOps {
|
||||
type Element;
|
||||
|
||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn neg(&self, a: &Self::Element) -> Self::Element;
|
||||
}
|
||||
|
||||
// fn modulus(&self) -> Self::Element;
|
||||
pub trait ArithmeticLazyOps {
|
||||
type Element;
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
}
|
||||
|
||||
pub trait ShoupMatrixFMA<M: Matrix>
|
||||
where
|
||||
M::R: RowMut,
|
||||
{
|
||||
/// Returns summation of row-wise product of matrix a and b.
|
||||
fn shoup_matrix_fma(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M);
|
||||
}
|
||||
|
||||
pub struct ModularOpsU64<T> {
|
||||
@@ -254,18 +261,10 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
|
||||
self.add_mod_fast(*a, *b)
|
||||
}
|
||||
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.add_mod_fast_lazy(*a, *b)
|
||||
}
|
||||
|
||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul_mod_fast(*a, *b)
|
||||
}
|
||||
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul_mod_fast_lazy(*a, *b)
|
||||
}
|
||||
|
||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.sub_mod_fast(*a, *b)
|
||||
}
|
||||
@@ -279,6 +278,16 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
|
||||
// }
|
||||
}
|
||||
|
||||
impl<T> ArithmeticLazyOps for ModularOpsU64<T> {
|
||||
type Element = u64;
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.add_mod_fast_lazy(*a, *b)
|
||||
}
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul_mod_fast_lazy(*a, *b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> VectorOps for ModularOpsU64<T> {
|
||||
type Element = u64;
|
||||
|
||||
@@ -344,12 +353,11 @@ impl<T> VectorOps for ModularOpsU64<T> {
|
||||
// }
|
||||
}
|
||||
|
||||
impl<T> ModularOpsU64<T> {
|
||||
/// Returns \sum a[i]b[i]
|
||||
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
|
||||
where
|
||||
impl<M: Matrix<MatElement = u64>, T> ShoupMatrixFMA<M> for ModularOpsU64<T>
|
||||
where
|
||||
M::R: RowMut,
|
||||
{
|
||||
{
|
||||
fn shoup_matrix_fma(&self, out: &mut <M as Matrix>::R, a: &M, a_shoup: &M, b: &M) {
|
||||
assert!(a.dimension() == a_shoup.dimension());
|
||||
assert!(a.dimension() == b.dimension());
|
||||
|
||||
@@ -393,6 +401,7 @@ impl<T> ModularOpsU64<T> {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> GetModulus for ModularOpsU64<T>
|
||||
where
|
||||
T: Modulus,
|
||||
@@ -430,18 +439,10 @@ where
|
||||
T::Element::wrapping_add(a, b)
|
||||
}
|
||||
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.add(a, b)
|
||||
}
|
||||
|
||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
T::Element::wrapping_mul(a, b)
|
||||
}
|
||||
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul(a, b)
|
||||
}
|
||||
|
||||
fn neg(&self, a: &Self::Element) -> Self::Element {
|
||||
T::Element::wrapping_sub(&T::Element::zero(), a)
|
||||
}
|
||||
@@ -531,3 +532,65 @@ where
|
||||
&self.modulus
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use itertools::Itertools;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand_distr::Uniform;
|
||||
|
||||
#[test]
|
||||
fn fma() {
|
||||
let mut rng = thread_rng();
|
||||
let prime = 36028797017456641;
|
||||
let ring_size = 1 << 3;
|
||||
|
||||
let dist = Uniform::new(0, prime);
|
||||
let d = 2;
|
||||
let a0_matrix = (0..d)
|
||||
.into_iter()
|
||||
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||
.collect_vec();
|
||||
// a0 in shoup representation
|
||||
let a0_shoup_matrix = a0_matrix
|
||||
.iter()
|
||||
.map(|r| {
|
||||
r.iter()
|
||||
.map(|v| {
|
||||
// $(v * 2^{\beta}) / p$
|
||||
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
|
||||
})
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec();
|
||||
let a1_matrix = (0..d)
|
||||
.into_iter()
|
||||
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||
.collect_vec();
|
||||
|
||||
let modop = ModularOpsU64::new(prime);
|
||||
|
||||
let mut out_shoup_fma_lazy = vec![0u64; ring_size];
|
||||
modop.shoup_matrix_fma(
|
||||
&mut out_shoup_fma_lazy,
|
||||
&a0_matrix,
|
||||
&a0_shoup_matrix,
|
||||
&a1_matrix,
|
||||
);
|
||||
let out_shoup_fma = out_shoup_fma_lazy
|
||||
.iter()
|
||||
.map(|v| if *v >= prime { v - prime } else { *v })
|
||||
.collect_vec();
|
||||
|
||||
// expected
|
||||
let mut out_expected = vec![0u64; ring_size];
|
||||
izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| {
|
||||
izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| {
|
||||
*o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime;
|
||||
});
|
||||
});
|
||||
|
||||
assert_eq!(out_expected, out_shoup_fma);
|
||||
}
|
||||
}
|
||||
|
||||
24
src/utils.rs
24
src/utils.rs
@@ -234,26 +234,4 @@ where
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::ShoupMul;
|
||||
|
||||
#[test]
|
||||
fn gg() {
|
||||
let mut rng = thread_rng();
|
||||
let p = 36028797018820609;
|
||||
|
||||
let a = rng.gen_range(0..p);
|
||||
let b = rng.gen_range(0..p);
|
||||
let a_shoup = ShoupMul::representation(a, p);
|
||||
|
||||
// let c = ShoupMul::mul(b, a, a_shoup, p);
|
||||
// assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
|
||||
|
||||
let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
|
||||
quotient -= 1;
|
||||
let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient));
|
||||
assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64);
|
||||
}
|
||||
}
|
||||
mod tests {}
|
||||
|
||||
Reference in New Issue
Block a user