mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-10 16:11:30 +01:00
bench shoup_fma against normal fma
This commit is contained in:
@@ -1,14 +1,10 @@
|
|||||||
use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
|
use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
|
||||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||||
use itertools::{izip, Itertools};
|
use itertools::{izip, Itertools};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use rand_distr::Uniform;
|
use rand_distr::Uniform;
|
||||||
|
|
||||||
pub(crate) fn decompose_r(
|
fn decompose_r(r: &[u64], decomp_r: &mut [Vec<u64>], decomposer: &DefaultDecomposer<u64>) {
|
||||||
r: &[u64],
|
|
||||||
decomp_r: &mut [Vec<u64>],
|
|
||||||
decomposer: &DefaultDecomposer<u64>,
|
|
||||||
) {
|
|
||||||
let ring_size = r.len();
|
let ring_size = r.len();
|
||||||
// let d = decomposer.decomposition_count();
|
// let d = decomposer.decomposition_count();
|
||||||
// let mut count = 0;
|
// let mut count = 0;
|
||||||
@@ -24,6 +20,16 @@ pub(crate) fn decompose_r(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matrix_fma(out: &mut [u64], a: &Vec<Vec<u64>>, b: &Vec<Vec<u64>>, modop: &ModularOpsU64<u64>) {
|
||||||
|
izip!(out.iter_mut(), a[0].iter(), b[0].iter())
|
||||||
|
.for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi)));
|
||||||
|
|
||||||
|
izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| {
|
||||||
|
izip!(out.iter_mut(), a_r.iter(), b_r.iter())
|
||||||
|
.for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi)));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
fn benchmark_decomposer(c: &mut Criterion) {
|
fn benchmark_decomposer(c: &mut Criterion) {
|
||||||
let mut group = c.benchmark_group("decomposer");
|
let mut group = c.benchmark_group("decomposer");
|
||||||
|
|
||||||
@@ -69,7 +75,7 @@ fn benchmark(c: &mut Criterion) {
|
|||||||
let mut group = c.benchmark_group("modulus");
|
let mut group = c.benchmark_group("modulus");
|
||||||
// 55
|
// 55
|
||||||
for prime in [36028797017456641] {
|
for prime in [36028797017456641] {
|
||||||
for ring_size in [1 << 11, 1 << 15] {
|
for ring_size in [1 << 11] {
|
||||||
let modop = ModularOpsU64::new(prime);
|
let modop = ModularOpsU64::new(prime);
|
||||||
|
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
@@ -79,12 +85,55 @@ fn benchmark(c: &mut Criterion) {
|
|||||||
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
||||||
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
|
||||||
|
|
||||||
|
let d = 2;
|
||||||
|
let a0_matrix = (0..d)
|
||||||
|
.into_iter()
|
||||||
|
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||||
|
.collect_vec();
|
||||||
|
// a0 in shoup representation
|
||||||
|
let a0_shoup_matrix = a0_matrix
|
||||||
|
.iter()
|
||||||
|
.map(|r| {
|
||||||
|
r.iter()
|
||||||
|
.map(|v| {
|
||||||
|
// $(v * 2^{\beta}) / p$
|
||||||
|
((*v as u128 * (1u128 << 64)) / prime as u128) as u64
|
||||||
|
})
|
||||||
|
.collect_vec()
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
let a1_matrix = (0..d)
|
||||||
|
.into_iter()
|
||||||
|
.map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
group.bench_function(
|
group.bench_function(
|
||||||
BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")),
|
BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")),
|
||||||
|b| {
|
|b| {
|
||||||
b.iter_batched_ref(
|
b.iter_batched_ref(
|
||||||
|| (a0.clone(), a1.clone(), a2.clone()),
|
|| (vec![0u64; ring_size]),
|
||||||
|(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)),
|
|(out)| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)),
|
||||||
|
criterion::BatchSize::PerIteration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
group.bench_function(
|
||||||
|
BenchmarkId::new(
|
||||||
|
"matrix_shoup_fma_lazy",
|
||||||
|
format!("q={prime}/N={ring_size}/d={d}"),
|
||||||
|
),
|
||||||
|
|b| {
|
||||||
|
b.iter_batched_ref(
|
||||||
|
|| (vec![0u64; ring_size]),
|
||||||
|
|(out)| {
|
||||||
|
black_box(modop.shoup_fma(
|
||||||
|
out,
|
||||||
|
&a0_matrix,
|
||||||
|
&a0_shoup_matrix,
|
||||||
|
&a1_matrix,
|
||||||
|
))
|
||||||
|
},
|
||||||
criterion::BatchSize::PerIteration,
|
criterion::BatchSize::PerIteration,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
|||||||
105
src/backend.rs
105
src/backend.rs
@@ -3,6 +3,8 @@ use std::marker::PhantomData;
|
|||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
|
use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
|
||||||
|
|
||||||
|
use crate::{utils::ShoupMul, Matrix, RowMut};
|
||||||
|
|
||||||
pub trait Modulus {
|
pub trait Modulus {
|
||||||
type Element;
|
type Element;
|
||||||
/// Modulus value if it fits in Element
|
/// Modulus value if it fits in Element
|
||||||
@@ -113,7 +115,9 @@ pub trait ArithmeticOps {
|
|||||||
type Element;
|
type Element;
|
||||||
|
|
||||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
||||||
fn neg(&self, a: &Self::Element) -> Self::Element;
|
fn neg(&self, a: &Self::Element) -> Self::Element;
|
||||||
|
|
||||||
@@ -122,6 +126,7 @@ pub trait ArithmeticOps {
|
|||||||
|
|
||||||
pub struct ModularOpsU64<T> {
|
pub struct ModularOpsU64<T> {
|
||||||
q: u64,
|
q: u64,
|
||||||
|
q_twice: u64,
|
||||||
logq: usize,
|
logq: usize,
|
||||||
barrett_mu: u128,
|
barrett_mu: u128,
|
||||||
barrett_alpha: usize,
|
barrett_alpha: usize,
|
||||||
@@ -146,6 +151,7 @@ where
|
|||||||
|
|
||||||
ModularOpsU64 {
|
ModularOpsU64 {
|
||||||
q,
|
q,
|
||||||
|
q_twice: q << 1,
|
||||||
logq: logq as usize,
|
logq: logq as usize,
|
||||||
barrett_alpha: alpha as usize,
|
barrett_alpha: alpha as usize,
|
||||||
barrett_mu: mu,
|
barrett_mu: mu,
|
||||||
@@ -166,6 +172,17 @@ impl<T> ModularOpsU64<T> {
|
|||||||
o
|
o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
|
||||||
|
debug_assert!(a < self.q_twice);
|
||||||
|
debug_assert!(b < self.q_twice);
|
||||||
|
|
||||||
|
let mut o = a + b;
|
||||||
|
if o >= self.q_twice {
|
||||||
|
o -= self.q_twice;
|
||||||
|
}
|
||||||
|
o
|
||||||
|
}
|
||||||
|
|
||||||
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
|
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
|
||||||
debug_assert!(a < self.q);
|
debug_assert!(a < self.q);
|
||||||
debug_assert!(b < self.q);
|
debug_assert!(b < self.q);
|
||||||
@@ -177,6 +194,29 @@ impl<T> ModularOpsU64<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns (a * b) % q
|
||||||
|
///
|
||||||
|
/// - both a and b must be in range [0, 2q)
|
||||||
|
/// - output is in range [0 , 2q)
|
||||||
|
fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
|
||||||
|
debug_assert!(a < 2 * self.q);
|
||||||
|
debug_assert!(b < 2 * self.q);
|
||||||
|
|
||||||
|
let ab = a as u128 * b as u128;
|
||||||
|
|
||||||
|
// ab / (2^{n + \beta})
|
||||||
|
// note: \beta is assumed to -2
|
||||||
|
let tmp = ab >> (self.logq - 2);
|
||||||
|
|
||||||
|
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
|
||||||
|
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
|
||||||
|
|
||||||
|
// ab - k*p
|
||||||
|
let tmp = k * (self.q as u128);
|
||||||
|
|
||||||
|
(ab - tmp) as u64
|
||||||
|
}
|
||||||
|
|
||||||
/// returns (a * b) % q
|
/// returns (a * b) % q
|
||||||
///
|
///
|
||||||
/// - both a and b must be in range [0, 2q)
|
/// - both a and b must be in range [0, 2q)
|
||||||
@@ -214,10 +254,18 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
|
|||||||
self.add_mod_fast(*a, *b)
|
self.add_mod_fast(*a, *b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
|
self.add_mod_fast_lazy(*a, *b)
|
||||||
|
}
|
||||||
|
|
||||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
self.mul_mod_fast(*a, *b)
|
self.mul_mod_fast(*a, *b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
|
self.mul_mod_fast_lazy(*a, *b)
|
||||||
|
}
|
||||||
|
|
||||||
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
self.sub_mod_fast(*a, *b)
|
self.sub_mod_fast(*a, *b)
|
||||||
}
|
}
|
||||||
@@ -296,6 +344,55 @@ impl<T> VectorOps for ModularOpsU64<T> {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> ModularOpsU64<T> {
|
||||||
|
/// Returns \sum a[i]b[i]
|
||||||
|
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
|
||||||
|
where
|
||||||
|
M::R: RowMut,
|
||||||
|
{
|
||||||
|
assert!(a.dimension() == a_shoup.dimension());
|
||||||
|
assert!(a.dimension() == b.dimension());
|
||||||
|
|
||||||
|
let q = self.q;
|
||||||
|
let q_twice = self.q << 1;
|
||||||
|
|
||||||
|
// first row (without summation)
|
||||||
|
izip!(
|
||||||
|
out.as_mut().iter_mut(),
|
||||||
|
a.get_row(0),
|
||||||
|
a_shoup.get_row(0),
|
||||||
|
b.get_row(0)
|
||||||
|
)
|
||||||
|
.for_each(|(o, a, a_shoup, b)| {
|
||||||
|
*o = ShoupMul::mul(*b, *a, *a_shoup, q);
|
||||||
|
});
|
||||||
|
|
||||||
|
izip!(
|
||||||
|
a.iter_rows().skip(1),
|
||||||
|
a_shoup.iter_rows().skip(1),
|
||||||
|
b.iter_rows().skip(1)
|
||||||
|
)
|
||||||
|
.for_each(|(a_row, a_shoup_row, b_row)| {
|
||||||
|
izip!(
|
||||||
|
out.as_mut().iter_mut(),
|
||||||
|
a_row.as_ref().iter(),
|
||||||
|
a_shoup_row.as_ref().iter(),
|
||||||
|
b_row.as_ref().iter()
|
||||||
|
)
|
||||||
|
.for_each(|(o, a0, a0_shoup, b0)| {
|
||||||
|
let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
|
||||||
|
let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
|
||||||
|
v = v.wrapping_sub(q.wrapping_mul(quotient));
|
||||||
|
|
||||||
|
if v >= q_twice {
|
||||||
|
v -= q_twice;
|
||||||
|
}
|
||||||
|
|
||||||
|
*o = v;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
impl<T> GetModulus for ModularOpsU64<T>
|
impl<T> GetModulus for ModularOpsU64<T>
|
||||||
where
|
where
|
||||||
T: Modulus,
|
T: Modulus,
|
||||||
@@ -333,10 +430,18 @@ where
|
|||||||
T::Element::wrapping_add(a, b)
|
T::Element::wrapping_add(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
|
self.add(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
T::Element::wrapping_mul(a, b)
|
T::Element::wrapping_mul(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
||||||
|
self.mul(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
fn neg(&self, a: &Self::Element) -> Self::Element {
|
fn neg(&self, a: &Self::Element) -> Self::Element {
|
||||||
T::Element::wrapping_sub(&T::Element::zero(), a)
|
T::Element::wrapping_sub(&T::Element::zero(), a)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ mod rgsw;
|
|||||||
mod shortint;
|
mod shortint;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
pub use backend::{ModInit, ModularOpsU64, VectorOps};
|
pub use backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps};
|
||||||
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
|
||||||
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
pub use ntt::{Ntt, NttBackendU64, NttInit};
|
||||||
|
|
||||||
|
|||||||
116
src/utils.rs
116
src/utils.rs
@@ -194,97 +194,6 @@ impl<P: Modulus> TryConvertFrom1<[P::Element], P> for Vec<i64> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub trait TryConvertFrom<T: ?Sized> {
|
|
||||||
// type Parameters: ?Sized;
|
|
||||||
|
|
||||||
// fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// impl TryConvertFrom1<[i32]> for Vec<Vec<u32>> {
|
|
||||||
// type Parameters = u32;
|
|
||||||
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
|
|
||||||
// { let row0 = value
|
|
||||||
// .iter()
|
|
||||||
// .map(|v| {
|
|
||||||
// let is_neg = v.is_negative();
|
|
||||||
// let v_u32 = v.abs() as u32;
|
|
||||||
|
|
||||||
// assert!(v_u32 < *parameters);
|
|
||||||
|
|
||||||
// if is_neg {
|
|
||||||
// parameters - v_u32
|
|
||||||
// } else {
|
|
||||||
// v_u32
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// .collect_vec();
|
|
||||||
|
|
||||||
// vec![row0]
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// impl TryConvertFrom1<[i32]> for Vec<Vec<u64>> {
|
|
||||||
// type Parameters = u64;
|
|
||||||
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
|
|
||||||
// { let row0 = value
|
|
||||||
// .iter()
|
|
||||||
// .map(|v| {
|
|
||||||
// let is_neg = v.is_negative();
|
|
||||||
// let v_u64 = v.abs() as u64;
|
|
||||||
|
|
||||||
// assert!(v_u64 < *parameters);
|
|
||||||
|
|
||||||
// if is_neg {
|
|
||||||
// parameters - v_u64
|
|
||||||
// } else {
|
|
||||||
// v_u64
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// .collect_vec();
|
|
||||||
|
|
||||||
// vec![row0]
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// impl TryConvertFrom1<[i32]> for Vec<u64> {
|
|
||||||
// type Parameters = u64;
|
|
||||||
// fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
|
|
||||||
// { value
|
|
||||||
// .iter()
|
|
||||||
// .map(|v| {
|
|
||||||
// let is_neg = v.is_negative();
|
|
||||||
// let v_u64 = v.abs() as u64;
|
|
||||||
|
|
||||||
// assert!(v_u64 < *parameters);
|
|
||||||
|
|
||||||
// if is_neg {
|
|
||||||
// parameters - v_u64
|
|
||||||
// } else {
|
|
||||||
// v_u64
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// .collect_vec()
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// impl TryConvertFrom1<[u64]> for Vec<i64> {
|
|
||||||
// type Parameters = u64;
|
|
||||||
// fn try_convert_from(value: &[u64], parameters: &Self::Parameters) -> Self
|
|
||||||
// { let q = *parameters;
|
|
||||||
// let qby2 = q / 2;
|
|
||||||
// value
|
|
||||||
// .iter()
|
|
||||||
// .map(|v| {
|
|
||||||
// if *v > qby2 {
|
|
||||||
// -((q - v) as i64)
|
|
||||||
// } else {
|
|
||||||
// *v as i64
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// .collect_vec()
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub(crate) struct Stats<T> {
|
pub(crate) struct Stats<T> {
|
||||||
pub(crate) samples: Vec<T>,
|
pub(crate) samples: Vec<T>,
|
||||||
}
|
}
|
||||||
@@ -323,3 +232,28 @@ where
|
|||||||
self.samples.extend(values.iter());
|
self.samples.extend(values.iter());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
|
||||||
|
use super::ShoupMul;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gg() {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let p = 36028797018820609;
|
||||||
|
|
||||||
|
let a = rng.gen_range(0..p);
|
||||||
|
let b = rng.gen_range(0..p);
|
||||||
|
let a_shoup = ShoupMul::representation(a, p);
|
||||||
|
|
||||||
|
// let c = ShoupMul::mul(b, a, a_shoup, p);
|
||||||
|
// assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
|
||||||
|
|
||||||
|
let mut quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
|
||||||
|
quotient -= 1;
|
||||||
|
let c = (b.wrapping_mul(a)).wrapping_sub(p.wrapping_mul(quotient));
|
||||||
|
assert!(c - p == ((a as u128 * b as u128) % p as u128) as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user