mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-11 16:41:29 +01:00
add tests for shoup_fma
This commit is contained in:
117
src/backend.rs
117
src/backend.rs
@@ -107,21 +107,28 @@ pub trait VectorOps {
|
|||||||
b: &[Self::Element],
|
b: &[Self::Element],
|
||||||
c: &Self::Element,
|
c: &Self::Element,
|
||||||
);
|
);
|
||||||
|
|
||||||
// fn modulus(&self) -> Self::Element;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ArithmeticOps {
|
pub trait ArithmeticOps {
|
||||||
type Element;
|
type Element;
|
||||||
|
|
||||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
||||||
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
||||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
fn neg(&self, a: &Self::Element) -> Self::Element;
|
fn neg(&self, a: &Self::Element) -> Self::Element;
|
||||||
|
}
|
||||||
|
|
||||||
// fn modulus(&self) -> Self::Element;
|
pub trait ArithmeticLazyOps {
|
||||||
|
type Element;
|
||||||
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ShoupMatrixFMA<M: Matrix>
|
||||||
|
where
|
||||||
|
M::R: RowMut,
|
||||||
|
{
|
||||||
|
/// Returns summation of row-wise product of matrix a and b.
|
||||||
|
fn shoup_matrix_fma(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ModularOpsU64<T> {
|
pub struct ModularOpsU64<T> {
|
||||||
@@ -254,18 +261,10 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
|
|||||||
self.add_mod_fast(*a, *b)
|
self.add_mod_fast(*a, *b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
||||||
self.add_mod_fast_lazy(*a, *b)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
self.mul_mod_fast(*a, *b)
|
self.mul_mod_fast(*a, *b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
||||||
self.mul_mod_fast_lazy(*a, *b)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
self.sub_mod_fast(*a, *b)
|
self.sub_mod_fast(*a, *b)
|
||||||
}
|
}
|
||||||
@@ -279,6 +278,16 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> ArithmeticLazyOps for ModularOpsU64<T> {
|
||||||
|
type Element = u64;
|
||||||
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
|
self.add_mod_fast_lazy(*a, *b)
|
||||||
|
}
|
||||||
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
|
self.mul_mod_fast_lazy(*a, *b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> VectorOps for ModularOpsU64<T> {
|
impl<T> VectorOps for ModularOpsU64<T> {
|
||||||
type Element = u64;
|
type Element = u64;
|
||||||
|
|
||||||
@@ -344,12 +353,11 @@ impl<T> VectorOps for ModularOpsU64<T> {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> ModularOpsU64<T> {
|
impl<M: Matrix<MatElement = u64>, T> ShoupMatrixFMA<M> for ModularOpsU64<T>
|
||||||
/// Returns \sum a[i]b[i]
|
where
|
||||||
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
|
|
||||||
where
|
|
||||||
M::R: RowMut,
|
M::R: RowMut,
|
||||||
{
|
{
|
||||||
|
fn shoup_matrix_fma(&self, out: &mut <M as Matrix>::R, a: &M, a_shoup: &M, b: &M) {
|
||||||
assert!(a.dimension() == a_shoup.dimension());
|
assert!(a.dimension() == a_shoup.dimension());
|
||||||
assert!(a.dimension() == b.dimension());
|
assert!(a.dimension() == b.dimension());
|
||||||
|
|
||||||
@@ -393,6 +401,7 @@ impl<T> ModularOpsU64<T> {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> GetModulus for ModularOpsU64<T>
|
impl<T> GetModulus for ModularOpsU64<T>
|
||||||
where
|
where
|
||||||
T: Modulus,
|
T: Modulus,
|
||||||
@@ -430,18 +439,10 @@ where
|
|||||||
T::Element::wrapping_add(a, b)
|
T::Element::wrapping_add(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
||||||
self.add(a, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
T::Element::wrapping_mul(a, b)
|
T::Element::wrapping_mul(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
||||||
self.mul(a, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn neg(&self, a: &Self::Element) -> Self::Element {
|
fn neg(&self, a: &Self::Element) -> Self::Element {
|
||||||
T::Element::wrapping_sub(&T::Element::zero(), a)
|
T::Element::wrapping_sub(&T::Element::zero(), a)
|
||||||
}
|
}
|
||||||
@@ -531,3 +532,65 @@ where
|
|||||||
&self.modulus
|
&self.modulus
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use itertools::Itertools;
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
use rand_distr::Uniform;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fma() {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let prime = 36028797017456641;
|
||||||
|
let ring_size = 1 << 3;
|
||||||
|
|
||||||
|
let dist = Uniform::new(0, prime);
|
||||||
|
let d = 2;
|
||||||
|
let a0_matrix = (0..d)
|
||||||
|
.into_iter()
|
||||||
|
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||||
|
.collect_vec();
|
||||||
|
// a0 in shoup representation
|
||||||
|
let a0_shoup_matrix = a0_matrix
|
||||||
|
.iter()
|
||||||
|
.map(|r| {
|
||||||
|
r.iter()
|
||||||
|
.map(|v| {
|
||||||
|
// $(v * 2^{\beta}) / p$
|
||||||
|
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
|
||||||
|
})
|
||||||
|
.collect_vec()
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
let a1_matrix = (0..d)
|
||||||
|
.into_iter()
|
||||||
|
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let modop = ModularOpsU64::new(prime);
|
||||||
|
|
||||||
|
let mut out_shoup_fma_lazy = vec![0u64; ring_size];
|
||||||
|
modop.shoup_matrix_fma(
|
||||||
|
&mut out_shoup_fma_lazy,
|
||||||
|
&a0_matrix,
|
||||||
|
&a0_shoup_matrix,
|
||||||
|
&a1_matrix,
|
||||||
|
);
|
||||||
|
let out_shoup_fma = out_shoup_fma_lazy
|
||||||
|
.iter()
|
||||||
|
.map(|v| if *v >= prime { v - prime } else { *v })
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
// expected
|
||||||
|
let mut out_expected = vec![0u64; ring_size];
|
||||||
|
izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| {
|
||||||
|
izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| {
|
||||||
|
*o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(out_expected, out_shoup_fma);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
24
src/utils.rs
24
src/utils.rs
@@ -234,26 +234,4 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {}
|
||||||
use rand::{thread_rng, Rng};
|
|
||||||
|
|
||||||
use super::ShoupMul;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn gg() {
|
|
||||||
let mut rng = thread_rng();
|
|
||||||
let p = 36028797018820609;
|
|
||||||
|
|
||||||
let a = rng.gen_range(0..p);
|
|
||||||
let b = rng.gen_range(0..p);
|
|
||||||
let a_shoup = ShoupMul::representation(a, p);
|
|
||||||
|
|
||||||
// let c = ShoupMul::mul(b, a, a_shoup, p);
|
|
||||||
// assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
|
|
||||||
|
|
||||||
let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
|
|
||||||
quotient -= 1;
|
|
||||||
let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient));
|
|
||||||
assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user