Browse Source

bench shoup_fma against normal fma

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
1eed18881f
4 changed files with 190 additions and 102 deletions
  1. +59
    -10
      benches/modulus.rs
  2. +105
    -0
      src/backend.rs
  3. +1
    -1
      src/lib.rs
  4. +25
    -91
      src/utils.rs

+ 59
- 10
benches/modulus.rs

@ -1,14 +1,10 @@
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use rand_distr::Uniform; use rand_distr::Uniform;
pub(crate) fn decompose_r(
r: &[u64],
decomp_r: &mut [Vec<u64>],
decomposer: &DefaultDecomposer<u64>,
) {
fn decompose_r(r: &[u64], decomp_r: &mut [Vec<u64>], decomposer: &DefaultDecomposer<u64>) {
let ring_size = r.len(); let ring_size = r.len();
// let d = decomposer.decomposition_count(); // let d = decomposer.decomposition_count();
// let mut count = 0; // let mut count = 0;
@ -24,6 +20,16 @@ pub(crate) fn decompose_r(
} }
} }
fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) {
izip!(out.iter_mut(), a[0].iter(), b[0].iter())
.for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi)));
izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| {
izip!(out.iter_mut(), a_r.iter(), b_r.iter())
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi)));
});
}
fn benchmark_decomposer(c: &mut Criterion) { fn benchmark_decomposer(c: &mut Criterion) {
let mut group = c.benchmark_group("decomposer"); let mut group = c.benchmark_group("decomposer");
@ -69,7 +75,7 @@ fn benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("modulus"); let mut group = c.benchmark_group("modulus");
// 55 // 55
for prime in [36028797017456641] { for prime in [36028797017456641] {
for ring_size in [1 << 11, 1 << 15] {
for ring_size in [1 << 11] {
let modop = ModularOpsU64::new(prime); let modop = ModularOpsU64::new(prime);
let mut rng = thread_rng(); let mut rng = thread_rng();
@ -79,12 +85,55 @@ fn benchmark(c: &mut Criterion) {
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let d = 2;
let a0_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
// a0 in shoup representation
let a0_shoup_matrix = a0_matrix
.iter()
.map(|r| {
r.iter()
.map(|v| {
// $(v * 2^{\beta}) / p$
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
})
.collect_vec()
})
.collect_vec();
let a1_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
group.bench_function(
BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")),
|b| {
b.iter_batched_ref(
|| (vec![0u64; ring_size]),
|(out)| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function( group.bench_function(
BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")),
BenchmarkId::new(
"matrix_shoup_fma_lazy",
format!("q={prime}/N={ring_size}/d={d}"),
),
|b| { |b| {
b.iter_batched_ref( b.iter_batched_ref(
|| (a0.clone(), a1.clone(), a2.clone()),
|(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)),
|| (vec![0u64; ring_size]),
|(out)| {
black_box(modop.shoup_fma(
out,
&a0_matrix,
&a0_shoup_matrix,
&a1_matrix,
))
},
criterion::BatchSize::PerIteration, criterion::BatchSize::PerIteration,
) )
}, },

+ 105
- 0
src/backend.rs

@ -3,6 +3,8 @@ use std::marker::PhantomData;
use itertools::izip; use itertools::izip;
use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero}; use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
use crate::{utils::ShoupMul, Matrix, RowMut};
pub trait Modulus { pub trait Modulus {
type Element; type Element;
/// Modulus value if it fits in Element /// Modulus value if it fits in Element
@ -113,7 +115,9 @@ pub trait ArithmeticOps {
type Element; type Element;
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
fn neg(&self, a: &Self::Element) -> Self::Element; fn neg(&self, a: &Self::Element) -> Self::Element;
@ -122,6 +126,7 @@ pub trait ArithmeticOps {
pub struct ModularOpsU64<T> { pub struct ModularOpsU64<T> {
q: u64, q: u64,
q_twice: u64,
logq: usize, logq: usize,
barrett_mu: u128, barrett_mu: u128,
barrett_alpha: usize, barrett_alpha: usize,
@ -146,6 +151,7 @@ where
ModularOpsU64 { ModularOpsU64 {
q, q,
q_twice: q << 1,
logq: logq as usize, logq: logq as usize,
barrett_alpha: alpha as usize, barrett_alpha: alpha as usize,
barrett_mu: mu, barrett_mu: mu,
@ -166,6 +172,17 @@ impl ModularOpsU64 {
o o
} }
fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q_twice);
debug_assert!(b < self.q_twice);
let mut o = a + b;
if o >= self.q_twice {
o -= self.q_twice;
}
o
}
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 { fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < self.q); debug_assert!(a < self.q);
debug_assert!(b < self.q); debug_assert!(b < self.q);
@ -177,6 +194,29 @@ impl ModularOpsU64 {
} }
} }
// returns (a * b) % q
///
/// - both a and b must be in range [0, 2q)
/// - output is in range [0 , 2q)
fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
debug_assert!(a < 2 * self.q);
debug_assert!(b < 2 * self.q);
let ab = a as u128 * b as u128;
// ab / (2^{n + \beta})
// note: \beta is assumed to -2
let tmp = ab >> (self.logq - 2);
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
// ab - k*p
let tmp = k * (self.q as u128);
(ab - tmp) as u64
}
/// returns (a * b) % q /// returns (a * b) % q
/// ///
/// - both a and b must be in range [0, 2q) /// - both a and b must be in range [0, 2q)
@ -214,10 +254,18 @@ impl ArithmeticOps for ModularOpsU64 {
self.add_mod_fast(*a, *b) self.add_mod_fast(*a, *b)
} }
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add_mod_fast_lazy(*a, *b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast(*a, *b) self.mul_mod_fast(*a, *b)
} }
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul_mod_fast_lazy(*a, *b)
}
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.sub_mod_fast(*a, *b) self.sub_mod_fast(*a, *b)
} }
@ -296,6 +344,55 @@ impl VectorOps for ModularOpsU64 {
// } // }
} }
impl<T> ModularOpsU64<T> {
/// Returns \sum a[i]b[i]
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
where
M::R: RowMut,
{
assert!(a.dimension() == a_shoup.dimension());
assert!(a.dimension() == b.dimension());
let q = self.q;
let q_twice = self.q << 1;
// first row (without summation)
izip!(
out.as_mut().iter_mut(),
a.get_row(0),
a_shoup.get_row(0),
b.get_row(0)
)
.for_each(|(o, a, a_shoup, b)| {
*o = ShoupMul::mul(*b, *a, *a_shoup, q);
});
izip!(
a.iter_rows().skip(1),
a_shoup.iter_rows().skip(1),
b.iter_rows().skip(1)
)
.for_each(|(a_row, a_shoup_row, b_row)| {
izip!(
out.as_mut().iter_mut(),
a_row.as_ref().iter(),
a_shoup_row.as_ref().iter(),
b_row.as_ref().iter()
)
.for_each(|(o, a0, a0_shoup, b0)| {
let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
v = v.wrapping_sub(q.wrapping_mul(quotient));
if v >= q_twice {
v -= q_twice;
}
*o = v;
});
});
}
}
impl<T> GetModulus for ModularOpsU64<T> impl<T> GetModulus for ModularOpsU64<T>
where where
T: Modulus, T: Modulus,
@ -333,10 +430,18 @@ where
T::Element::wrapping_add(a, b) T::Element::wrapping_add(a, b)
} }
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.add(a, b)
}
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
T::Element::wrapping_mul(a, b) T::Element::wrapping_mul(a, b)
} }
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
self.mul(a, b)
}
fn neg(&self, a: &Self::Element) -> Self::Element { fn neg(&self, a: &Self::Element) -> Self::Element {
T::Element::wrapping_sub(&T::Element::zero(), a) T::Element::wrapping_sub(&T::Element::zero(), a)
} }

+ 1
- 1
src/lib.rs

@ -20,7 +20,7 @@ mod rgsw;
mod shortint; mod shortint;
mod utils; mod utils;
pub use backend::{ModInit, ModularOpsU64, VectorOps};
pub use backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps};
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
pub use ntt::{Ntt, NttBackendU64, NttInit}; pub use ntt::{Ntt, NttBackendU64, NttInit};

+ 25
- 91
src/utils.rs

@ -194,97 +194,6 @@ impl TryConvertFrom1<[P::Element], P> for Vec {
} }
} }
// pub trait TryConvertFrom<T: ?Sized> {
// type Parameters: ?Sized;
// fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self;
// }
// impl TryConvertFrom1<[i32]> for Vec<Vec<u32>> {
// type Parameters = u32;
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
// { let row0 = value
// .iter()
// .map(|v| {
// let is_neg = v.is_negative();
// let v_u32 = v.abs() as u32;
// assert!(v_u32 < *parameters);
// if is_neg {
// parameters - v_u32
// } else {
// v_u32
// }
// })
// .collect_vec();
// vec![row0]
// }
// }
// impl TryConvertFrom1<[i32]> for Vec<Vec<u64>> {
// type Parameters = u64;
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
// { let row0 = value
// .iter()
// .map(|v| {
// let is_neg = v.is_negative();
// let v_u64 = v.abs() as u64;
// assert!(v_u64 < *parameters);
// if is_neg {
// parameters - v_u64
// } else {
// v_u64
// }
// })
// .collect_vec();
// vec![row0]
// }
// }
// impl TryConvertFrom1<[i32]> for Vec<u64> {
// type Parameters = u64;
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
// { value
// .iter()
// .map(|v| {
// let is_neg = v.is_negative();
// let v_u64 = v.abs() as u64;
// assert!(v_u64 < *parameters);
// if is_neg {
// parameters - v_u64
// } else {
// v_u64
// }
// })
// .collect_vec()
// }
// }
// impl TryConvertFrom1<[u64]> for Vec<i64> {
// type Parameters = u64;
// fn try_convert_from(value: &[u64], parameters: &Self::Parameters) -> Self
// { let q = *parameters;
// let qby2 = q / 2;
// value
// .iter()
// .map(|v| {
// if *v > qby2 {
// -((q - v) as i64)
// } else {
// *v as i64
// }
// })
// .collect_vec()
// }
// }
pub(crate) struct Stats<T> { pub(crate) struct Stats<T> {
pub(crate) samples: Vec<T>, pub(crate) samples: Vec<T>,
} }
@ -323,3 +232,28 @@ where
self.samples.extend(values.iter()); self.samples.extend(values.iter());
} }
} }
#[cfg(test)]
mod tests {
use rand::{thread_rng, Rng};
use super::ShoupMul;
#[test]
fn gg() {
let mut rng = thread_rng();
let p = 36028797018820609;
let a = rng.gen_range(0..p);
let b = rng.gen_range(0..p);
let a_shoup = ShoupMul::representation(a, p);
// let c = ShoupMul::mul(b, a, a_shoup, p);
// assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
quotient -= 1;
let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient));
assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64);
}
}

Loading…
Cancel
Save