bench shoup_fma against normal fma

This commit is contained in:
Janmajaya Mall
2024-06-10 16:28:46 +05:30
parent 0f496a1032
commit 1eed18881f
4 changed files with 190 additions and 102 deletions

View File

@@ -1,14 +1,10 @@
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use itertools::{izip, Itertools};
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
pub(crate) fn decompose_r(
r: &[u64],
decomp_r: &mut [Vec<u64>],
decomposer: &DefaultDecomposer<u64>,
) {
fn decompose_r(r: &[u64], decomp_r: &mut [Vec<u64>], decomposer: &DefaultDecomposer<u64>) {
let ring_size = r.len();
// let d = decomposer.decomposition_count();
// let mut count = 0;
@@ -24,6 +20,16 @@ pub(crate) fn decompose_r(
}
}
fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) {
izip!(out.iter_mut(), a[0].iter(), b[0].iter())
.for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi)));
izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| {
izip!(out.iter_mut(), a_r.iter(), b_r.iter())
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi)));
});
}
fn benchmark_decomposer(c: &mut Criterion) {
let mut group = c.benchmark_group("decomposer");
@@ -69,7 +75,7 @@ fn benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("modulus");
// 55
for prime in [36028797017456641] {
for ring_size in [1 << 11, 1 << 15] {
for ring_size in [1 << 11] {
let modop = ModularOpsU64::new(prime);
let mut rng = thread_rng();
@@ -79,12 +85,55 @@ fn benchmark(c: &mut Criterion) {
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
let d = 2;
let a0_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
// a0 in shoup representation
let a0_shoup_matrix = a0_matrix
.iter()
.map(|r| {
r.iter()
.map(|v| {
// $(v * 2^{\beta}) / p$
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
})
.collect_vec()
})
.collect_vec();
let a1_matrix = (0..d)
.into_iter()
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
.collect_vec();
group.bench_function(
BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")),
BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")),
|b| {
b.iter_batched_ref(
|| (a0.clone(), a1.clone(), a2.clone()),
|(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)),
|| (vec![0u64; ring_size]),
|(out)| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)),
criterion::BatchSize::PerIteration,
)
},
);
group.bench_function(
BenchmarkId::new(
"matrix_shoup_fma_lazy",
format!("q={prime}/N={ring_size}/d={d}"),
),
|b| {
b.iter_batched_ref(
|| (vec![0u64; ring_size]),
|(out)| {
black_box(modop.shoup_fma(
out,
&a0_matrix,
&a0_shoup_matrix,
&a1_matrix,
))
},
criterion::BatchSize::PerIteration,
)
},