mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-10 16:11:30 +01:00
bench shoup_fma against normal fma
This commit is contained in:
@@ -1,14 +1,10 @@
|
||||
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
|
||||
use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand_distr::Uniform;
|
||||
|
||||
pub(crate) fn decompose_r(
|
||||
r: &[u64],
|
||||
decomp_r: &mut [Vec<u64>],
|
||||
decomposer: &DefaultDecomposer<u64>,
|
||||
) {
|
||||
fn decompose_r(r: &[u64], decomp_r: &mut [Vec<u64>], decomposer: &DefaultDecomposer<u64>) {
|
||||
let ring_size = r.len();
|
||||
// let d = decomposer.decomposition_count();
|
||||
// let mut count = 0;
|
||||
@@ -24,6 +20,16 @@ pub(crate) fn decompose_r(
|
||||
}
|
||||
}
|
||||
|
||||
fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) {
|
||||
izip!(out.iter_mut(), a[0].iter(), b[0].iter())
|
||||
.for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi)));
|
||||
|
||||
izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| {
|
||||
izip!(out.iter_mut(), a_r.iter(), b_r.iter())
|
||||
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi)));
|
||||
});
|
||||
}
|
||||
|
||||
fn benchmark_decomposer(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("decomposer");
|
||||
|
||||
@@ -69,7 +75,7 @@ fn benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("modulus");
|
||||
// 55
|
||||
for prime in [36028797017456641] {
|
||||
for ring_size in [1 << 11, 1 << 15] {
|
||||
for ring_size in [1 << 11] {
|
||||
let modop = ModularOpsU64::new(prime);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
@@ -79,12 +85,55 @@ fn benchmark(c: &mut Criterion) {
|
||||
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
||||
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
||||
|
||||
let d = 2;
|
||||
let a0_matrix = (0..d)
|
||||
.into_iter()
|
||||
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||
.collect_vec();
|
||||
// a0 in shoup representation
|
||||
let a0_shoup_matrix = a0_matrix
|
||||
.iter()
|
||||
.map(|r| {
|
||||
r.iter()
|
||||
.map(|v| {
|
||||
// $(v * 2^{\beta}) / p$
|
||||
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
|
||||
})
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec();
|
||||
let a1_matrix = (0..d)
|
||||
.into_iter()
|
||||
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||
.collect_vec();
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")),
|
||||
BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| (a0.clone(), a1.clone(), a2.clone()),
|
||||
|(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)),
|
||||
|| (vec![0u64; ring_size]),
|
||||
|(out)| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)),
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_function(
|
||||
BenchmarkId::new(
|
||||
"matrix_shoup_fma_lazy",
|
||||
format!("q={prime}/N={ring_size}/d={d}"),
|
||||
),
|
||||
|b| {
|
||||
b.iter_batched_ref(
|
||||
|| (vec![0u64; ring_size]),
|
||||
|(out)| {
|
||||
black_box(modop.shoup_fma(
|
||||
out,
|
||||
&a0_matrix,
|
||||
&a0_shoup_matrix,
|
||||
&a1_matrix,
|
||||
))
|
||||
},
|
||||
criterion::BatchSize::PerIteration,
|
||||
)
|
||||
},
|
||||
|
||||
105
src/backend.rs
105
src/backend.rs
@@ -3,6 +3,8 @@ use std::marker::PhantomData;
|
||||
use itertools::izip;
|
||||
use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
|
||||
|
||||
use crate::{utils::ShoupMul, Matrix, RowMut};
|
||||
|
||||
pub trait Modulus {
|
||||
type Element;
|
||||
/// Modulus value if it fits in Element
|
||||
@@ -113,7 +115,9 @@ pub trait ArithmeticOps {
|
||||
type Element;
|
||||
|
||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||
fn neg(&self, a: &Self::Element) -> Self::Element;
|
||||
|
||||
@@ -122,6 +126,7 @@ pub trait ArithmeticOps {
|
||||
|
||||
pub struct ModularOpsU64<T> {
|
||||
q: u64,
|
||||
q_twice: u64,
|
||||
logq: usize,
|
||||
barrett_mu: u128,
|
||||
barrett_alpha: usize,
|
||||
@@ -146,6 +151,7 @@ where
|
||||
|
||||
ModularOpsU64 {
|
||||
q,
|
||||
q_twice: q << 1,
|
||||
logq: logq as usize,
|
||||
barrett_alpha: alpha as usize,
|
||||
barrett_mu: mu,
|
||||
@@ -166,6 +172,17 @@ impl<T> ModularOpsU64<T> {
|
||||
o
|
||||
}
|
||||
|
||||
fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
|
||||
debug_assert!(a < self.q_twice);
|
||||
debug_assert!(b < self.q_twice);
|
||||
|
||||
let mut o = a + b;
|
||||
if o >= self.q_twice {
|
||||
o -= self.q_twice;
|
||||
}
|
||||
o
|
||||
}
|
||||
|
||||
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
|
||||
debug_assert!(a < self.q);
|
||||
debug_assert!(b < self.q);
|
||||
@@ -177,6 +194,29 @@ impl<T> ModularOpsU64<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// returns (a * b) % q
|
||||
///
|
||||
/// - both a and b must be in range [0, 2q)
|
||||
/// - output is in range [0 , 2q)
|
||||
fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
|
||||
debug_assert!(a < 2 * self.q);
|
||||
debug_assert!(b < 2 * self.q);
|
||||
|
||||
let ab = a as u128 * b as u128;
|
||||
|
||||
// ab / (2^{n + \beta})
|
||||
// note: \beta is assumed to -2
|
||||
let tmp = ab >> (self.logq - 2);
|
||||
|
||||
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
|
||||
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
|
||||
|
||||
// ab - k*p
|
||||
let tmp = k * (self.q as u128);
|
||||
|
||||
(ab - tmp) as u64
|
||||
}
|
||||
|
||||
/// returns (a * b) % q
|
||||
///
|
||||
/// - both a and b must be in range [0, 2q)
|
||||
@@ -214,10 +254,18 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
|
||||
self.add_mod_fast(*a, *b)
|
||||
}
|
||||
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.add_mod_fast_lazy(*a, *b)
|
||||
}
|
||||
|
||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul_mod_fast(*a, *b)
|
||||
}
|
||||
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul_mod_fast_lazy(*a, *b)
|
||||
}
|
||||
|
||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.sub_mod_fast(*a, *b)
|
||||
}
|
||||
@@ -296,6 +344,55 @@ impl<T> VectorOps for ModularOpsU64<T> {
|
||||
// }
|
||||
}
|
||||
|
||||
impl<T> ModularOpsU64<T> {
|
||||
/// Returns \sum a[i]b[i]
|
||||
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
|
||||
where
|
||||
M::R: RowMut,
|
||||
{
|
||||
assert!(a.dimension() == a_shoup.dimension());
|
||||
assert!(a.dimension() == b.dimension());
|
||||
|
||||
let q = self.q;
|
||||
let q_twice = self.q << 1;
|
||||
|
||||
// first row (without summation)
|
||||
izip!(
|
||||
out.as_mut().iter_mut(),
|
||||
a.get_row(0),
|
||||
a_shoup.get_row(0),
|
||||
b.get_row(0)
|
||||
)
|
||||
.for_each(|(o, a, a_shoup, b)| {
|
||||
*o = ShoupMul::mul(*b, *a, *a_shoup, q);
|
||||
});
|
||||
|
||||
izip!(
|
||||
a.iter_rows().skip(1),
|
||||
a_shoup.iter_rows().skip(1),
|
||||
b.iter_rows().skip(1)
|
||||
)
|
||||
.for_each(|(a_row, a_shoup_row, b_row)| {
|
||||
izip!(
|
||||
out.as_mut().iter_mut(),
|
||||
a_row.as_ref().iter(),
|
||||
a_shoup_row.as_ref().iter(),
|
||||
b_row.as_ref().iter()
|
||||
)
|
||||
.for_each(|(o, a0, a0_shoup, b0)| {
|
||||
let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
|
||||
let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
|
||||
v = v.wrapping_sub(q.wrapping_mul(quotient));
|
||||
|
||||
if v >= q_twice {
|
||||
v -= q_twice;
|
||||
}
|
||||
|
||||
*o = v;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
impl<T> GetModulus for ModularOpsU64<T>
|
||||
where
|
||||
T: Modulus,
|
||||
@@ -333,10 +430,18 @@ where
|
||||
T::Element::wrapping_add(a, b)
|
||||
}
|
||||
|
||||
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.add(a, b)
|
||||
}
|
||||
|
||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
T::Element::wrapping_mul(a, b)
|
||||
}
|
||||
|
||||
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||
self.mul(a, b)
|
||||
}
|
||||
|
||||
fn neg(&self, a: &Self::Element) -> Self::Element {
|
||||
T::Element::wrapping_sub(&T::Element::zero(), a)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ mod rgsw;
|
||||
mod shortint;
|
||||
mod utils;
|
||||
|
||||
pub use backend::{ModInit, ModularOpsU64, VectorOps};
|
||||
pub use backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps};
|
||||
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
||||
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
||||
|
||||
|
||||
116
src/utils.rs
116
src/utils.rs
@@ -194,97 +194,6 @@ impl<P: Modulus> TryConvertFrom1<[P::Element], P> for Vec<i64> {
|
||||
}
|
||||
}
|
||||
|
||||
// pub trait TryConvertFrom<T: ?Sized> {
|
||||
// type Parameters: ?Sized;
|
||||
|
||||
// fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self;
|
||||
// }
|
||||
|
||||
// impl TryConvertFrom1<[i32]> for Vec<Vec<u32>> {
|
||||
// type Parameters = u32;
|
||||
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
|
||||
// { let row0 = value
|
||||
// .iter()
|
||||
// .map(|v| {
|
||||
// let is_neg = v.is_negative();
|
||||
// let v_u32 = v.abs() as u32;
|
||||
|
||||
// assert!(v_u32 < *parameters);
|
||||
|
||||
// if is_neg {
|
||||
// parameters - v_u32
|
||||
// } else {
|
||||
// v_u32
|
||||
// }
|
||||
// })
|
||||
// .collect_vec();
|
||||
|
||||
// vec![row0]
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl TryConvertFrom1<[i32]> for Vec<Vec<u64>> {
|
||||
// type Parameters = u64;
|
||||
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
|
||||
// { let row0 = value
|
||||
// .iter()
|
||||
// .map(|v| {
|
||||
// let is_neg = v.is_negative();
|
||||
// let v_u64 = v.abs() as u64;
|
||||
|
||||
// assert!(v_u64 < *parameters);
|
||||
|
||||
// if is_neg {
|
||||
// parameters - v_u64
|
||||
// } else {
|
||||
// v_u64
|
||||
// }
|
||||
// })
|
||||
// .collect_vec();
|
||||
|
||||
// vec![row0]
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl TryConvertFrom1<[i32]> for Vec<u64> {
|
||||
// type Parameters = u64;
|
||||
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
|
||||
// { value
|
||||
// .iter()
|
||||
// .map(|v| {
|
||||
// let is_neg = v.is_negative();
|
||||
// let v_u64 = v.abs() as u64;
|
||||
|
||||
// assert!(v_u64 < *parameters);
|
||||
|
||||
// if is_neg {
|
||||
// parameters - v_u64
|
||||
// } else {
|
||||
// v_u64
|
||||
// }
|
||||
// })
|
||||
// .collect_vec()
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl TryConvertFrom1<[u64]> for Vec<i64> {
|
||||
// type Parameters = u64;
|
||||
// fn try_convert_from(value: &[u64], parameters: &Self::Parameters) -> Self
|
||||
// { let q = *parameters;
|
||||
// let qby2 = q / 2;
|
||||
// value
|
||||
// .iter()
|
||||
// .map(|v| {
|
||||
// if *v > qby2 {
|
||||
// -((q - v) as i64)
|
||||
// } else {
|
||||
// *v as i64
|
||||
// }
|
||||
// })
|
||||
// .collect_vec()
|
||||
// }
|
||||
// }
|
||||
|
||||
pub(crate) struct Stats<T> {
|
||||
pub(crate) samples: Vec<T>,
|
||||
}
|
||||
@@ -323,3 +232,28 @@ where
|
||||
self.samples.extend(values.iter());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::ShoupMul;
|
||||
|
||||
#[test]
|
||||
fn gg() {
|
||||
let mut rng = thread_rng();
|
||||
let p = 36028797018820609;
|
||||
|
||||
let a = rng.gen_range(0..p);
|
||||
let b = rng.gen_range(0..p);
|
||||
let a_shoup = ShoupMul::representation(a, p);
|
||||
|
||||
// let c = ShoupMul::mul(b, a, a_shoup, p);
|
||||
// assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
|
||||
|
||||
let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
|
||||
quotient -= 1;
|
||||
let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient));
|
||||
assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user