diff --git a/benches/modulus.rs b/benches/modulus.rs
index 749ffa8..56429e2 100644
--- a/benches/modulus.rs
+++ b/benches/modulus.rs
@@ -1,14 +1,10 @@
-use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
+use bin_rs::{ArithmeticOps, Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps};
 use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
 use itertools::{izip, Itertools};
 use rand::{thread_rng, Rng};
 use rand_distr::Uniform;

-pub(crate) fn decompose_r(
-    r: &[u64],
-    decomp_r: &mut [Vec<u64>],
-    decomposer: &DefaultDecomposer<u64>,
-) {
+fn decompose_r(r: &[u64], decomp_r: &mut [Vec<u64>], decomposer: &DefaultDecomposer<u64>) {
     let ring_size = r.len();
     // let d = decomposer.decomposition_count();
     // let mut count = 0;
@@ -24,6 +20,16 @@ pub(crate) fn decompose_r(
     }
 }

+/// Accumulates \sum_i a[i][j] * b[i][j] into out[j]; the result is only
+/// lazily reduced, i.e. it lands in [0, 2q)
+fn matrix_fma(out: &mut [u64], a: &[Vec<u64>], b: &[Vec<u64>], modop: &ModularOpsU64<u64>) {
+    // first row: lazy product, strict add keeps out below q
+    izip!(out.iter_mut(), a[0].iter(), b[0].iter())
+        .for_each(|(o, ai, bi)| *o = modop.add(o, &modop.mul_lazy(ai, bi)));
+
+    // remaining rows: strict product, lazy add keeps out below 2q
+    izip!(a.iter().skip(1), b.iter().skip(1)).for_each(|(a_r, b_r)| {
+        izip!(out.iter_mut(), a_r.iter(), b_r.iter())
+            .for_each(|(o, ai, bi)| *o = modop.add_lazy(o, &modop.mul(ai, bi)));
+    });
+}
+
 fn benchmark_decomposer(c: &mut Criterion) {
     let mut group = c.benchmark_group("decomposer");

@@ -69,7 +75,7 @@ fn benchmark(c: &mut Criterion) {
     let mut group = c.benchmark_group("modulus");
     // 55
     for prime in [36028797017456641] {
-        for ring_size in [1 << 11, 1 << 15] {
+        for ring_size in [1 << 11] {
             let modop = ModularOpsU64::new(prime);

             let mut rng = thread_rng();

@@ -79,12 +85,55 @@ fn benchmark(c: &mut Criterion) {
             let a1 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();
             let a2 = (&mut rng).sample_iter(dist).take(ring_size).collect_vec();

+            let d = 2;
+            let a0_matrix = (0..d)
+                .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
+                .collect_vec();
+            // a0 in Shoup representation
+            let a0_shoup_matrix = a0_matrix
+                .iter()
+                .map(|r| {
+                    r.iter()
+                        .map(|v| {
+                            // (v * 2^{64}) / p
+                            ((*v as u128 * (1u128 << 64)) / prime as u128) as u64
+                        })
+                        .collect_vec()
+                })
+                .collect_vec();
+            let a1_matrix = (0..d)
+                .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
+                .collect_vec();
+
+            group.bench_function(
+                BenchmarkId::new("matrix_fma_lazy", format!("q={prime}/N={ring_size}/d={d}")),
+                |b| {
+                    b.iter_batched_ref(
+                        || vec![0u64; ring_size],
+                        |out| black_box(matrix_fma(out, &a0_matrix, &a1_matrix, &modop)),
+                        criterion::BatchSize::PerIteration,
+                    )
+                },
+            );
+
             group.bench_function(
-                BenchmarkId::new("elwise_fma", format!("q={prime}/{ring_size}")),
+                BenchmarkId::new(
+                    "matrix_shoup_fma_lazy",
+                    format!("q={prime}/N={ring_size}/d={d}"),
+                ),
                 |b| {
                     b.iter_batched_ref(
-                        || (a0.clone(), a1.clone(), a2.clone()),
-                        |(a0, a1, a2)| black_box(modop.elwise_fma_mut(a0, a1, a2)),
+                        || vec![0u64; ring_size],
+                        |out| {
+                            black_box(modop.shoup_fma(
+                                out,
+                                &a0_matrix,
+                                &a0_shoup_matrix,
+                                &a1_matrix,
+                            ))
+                        },
                         criterion::BatchSize::PerIteration,
                     )
                },
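Note on the precomputation above: `a0_shoup_matrix` stores floor(v * 2^64 / p) for every coefficient, the fixed-point inverse Shoup multiplication uses to predict the quotient floor(a * b / p) with a single high multiply. A minimal standalone sketch of the idea (illustrative names, not the crate's `ShoupMul` API; assumes a < q < 2^63):

    // Precompute floor(a * 2^64 / q) once per reused multiplicand `a`.
    fn shoup_representation(a: u64, q: u64) -> u64 {
        (((a as u128) << 64) / q as u128) as u64
    }

    // a * b mod q with one u128 high-multiply and one conditional subtraction.
    fn shoup_mul(b: u64, a: u64, a_shoup: u64, q: u64) -> u64 {
        // quotient is floor(a * b / q) or one less
        let quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
        // exact despite wrapping: the true value lies in [0, 2q) < 2^64
        let r = a.wrapping_mul(b).wrapping_sub(q.wrapping_mul(quotient));
        if r >= q { r - q } else { r }
    }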
diff --git a/src/backend.rs b/src/backend.rs
index 04996b9..ed5519b 100644
--- a/src/backend.rs
+++ b/src/backend.rs
@@ -3,6 +3,8 @@
 use std::marker::PhantomData;

 use itertools::izip;
 use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};

+use crate::{utils::ShoupMul, Matrix, RowMut};
+
 pub trait Modulus {
     type Element;
     /// Modulus value if it fits in Element
@@ -113,7 +115,9 @@ pub trait ArithmeticOps {
     type Element;

     fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
+    fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
     fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
+    fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
     fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
     fn neg(&self, a: &Self::Element) -> Self::Element;

@@ -122,6 +126,7 @@
 pub struct ModularOpsU64<T> {
     q: u64,
+    q_twice: u64,
     logq: usize,
     barrett_mu: u128,
     barrett_alpha: usize,
@@ -146,6 +151,7 @@ where
         ModularOpsU64 {
             q,
+            q_twice: q << 1,
             logq: logq as usize,
             barrett_alpha: alpha as usize,
             barrett_mu: mu,
@@ -166,6 +172,17 @@ impl<T> ModularOpsU64<T> {
         o
     }

+    /// Lazy addition: both inputs and the output are in [0, 2q)
+    fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
+        debug_assert!(a < self.q_twice);
+        debug_assert!(b < self.q_twice);
+
+        let mut o = a + b;
+        if o >= self.q_twice {
+            o -= self.q_twice;
+        }
+        o
+    }
+
     fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
         debug_assert!(a < self.q);
         debug_assert!(b < self.q);
@@ -177,6 +194,29 @@
         }
     }

+    /// Returns (a * b) % q via lazy Barrett reduction
+    ///
+    /// - both a and b must be in range [0, 2q)
+    /// - output is in range [0, 2q)
+    fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
+        debug_assert!(a < 2 * self.q);
+        debug_assert!(b < 2 * self.q);
+
+        let ab = a as u128 * b as u128;
+
+        // ab / 2^{n + \beta}
+        // note: \beta is assumed to be -2
+        let tmp = ab >> (self.logq - 2);
+
+        // k = ((ab / 2^{n + \beta}) * \mu) / 2^{\alpha - \beta}
+        let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
+
+        // ab - k*q
+        let tmp = k * (self.q as u128);
+
+        (ab - tmp) as u64
+    }
+
     /// returns (a * b) % q
     ///
     /// - both a and b must be in range [0, 2q)
@@ -214,10 +254,18 @@ impl<T> ArithmeticOps for ModularOpsU64<T> {
         self.add_mod_fast(*a, *b)
     }

+    fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
+        self.add_mod_fast_lazy(*a, *b)
+    }
+
     fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
         self.mul_mod_fast(*a, *b)
     }

+    fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
+        self.mul_mod_fast_lazy(*a, *b)
+    }
+
     fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
         self.sub_mod_fast(*a, *b)
     }
@@ -296,6 +344,55 @@ impl<T> VectorOps for ModularOpsU64<T> {
     // }
 }

+impl<T> ModularOpsU64<T> {
+    /// Returns \sum_i a[i] * b[i] with the result lazily reduced to [0, 2q);
+    /// a_shoup must hold the Shoup representation (a << 64) / q of a
+    pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
+    where
+        M::R: RowMut,
+    {
+        assert!(a.dimension() == a_shoup.dimension());
+        assert!(a.dimension() == b.dimension());
+
+        let q = self.q;
+        let q_twice = self.q << 1;
+
+        // first row (without summation)
+        izip!(
+            out.as_mut().iter_mut(),
+            a.get_row(0),
+            a_shoup.get_row(0),
+            b.get_row(0)
+        )
+        .for_each(|(o, a, a_shoup, b)| {
+            *o = ShoupMul::mul(*b, *a, *a_shoup, q);
+        });
+
+        // remaining rows: accumulate lazily, keeping the sum in [0, 2q)
+        izip!(
+            a.iter_rows().skip(1),
+            a_shoup.iter_rows().skip(1),
+            b.iter_rows().skip(1)
+        )
+        .for_each(|(a_row, a_shoup_row, b_row)| {
+            izip!(
+                out.as_mut().iter_mut(),
+                a_row.as_ref().iter(),
+                a_shoup_row.as_ref().iter(),
+                b_row.as_ref().iter()
+            )
+            .for_each(|(o, a0, a0_shoup, b0)| {
+                // Shoup estimate of floor(a0 * b0 / q), exact or one too small
+                let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
+                let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
+                v = v.wrapping_sub(q.wrapping_mul(quotient));
+
+                if v >= q_twice {
+                    v -= q_twice;
+                }
+
+                *o = v;
+            });
+        });
+    }
+}
+
 impl<T> GetModulus for ModularOpsU64<T>
 where
     T: Modulus,
@@ -333,10 +430,18 @@ where
         T::Element::wrapping_add(a, b)
     }

+    fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
+        // the native wrapping backend has no lazy representation
+        self.add(a, b)
+    }
+
     fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
         T::Element::wrapping_mul(a, b)
     }

+    fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
+        self.mul(a, b)
+    }
+
     fn neg(&self, a: &Self::Element) -> Self::Element {
         T::Element::wrapping_sub(&T::Element::zero(), a)
     }
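The lazy variants added to `ArithmeticOps` all rely on the same invariant: operands live in [0, 2q) rather than [0, q), and a single conditional subtraction of 2q preserves that, so the strict reduction can be deferred to the very end of a long accumulation. A standalone sketch of the invariant (illustrative, not the crate API; assumes 4q < 2^64, i.e. q < 2^62):

    // Lazy add: inputs in [0, 2q), output in [0, 2q).
    fn add_mod_lazy(a: u64, b: u64, q_twice: u64) -> u64 {
        debug_assert!(a < q_twice && b < q_twice);
        let o = a + b; // < 4q, cannot overflow while q < 2^62
        if o >= q_twice { o - q_twice } else { o }
    }

    fn main() {
        let q: u64 = 36028797017456641; // the 55-bit prime from the benchmark
        let q_twice = q << 1;
        let xs = [q + 5, q_twice - 1, 7]; // lazily-reduced inputs in [0, 2q)
        // accumulate lazily, reduce strictly once at the end
        let lazy = xs.iter().fold(0u64, |acc, &x| add_mod_lazy(acc, x, q_twice)) % q;
        // reference: reduce strictly after every step
        let strict = xs.iter().fold(0u64, |acc, &x| (acc + x % q) % q);
        assert_eq!(lazy, strict);
    }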
diff --git a/src/lib.rs b/src/lib.rs
index 3bb54b2..8b24fb3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,7 +20,7 @@ mod rgsw;
 mod shortint;
 mod utils;

-pub use backend::{ModInit, ModularOpsU64, VectorOps};
+pub use backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps};
 pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer};
 pub use ntt::{Ntt, NttBackendU64, NttInit};

diff --git a/src/utils.rs b/src/utils.rs
index b448a83..3118fdc 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -194,97 +194,6 @@ impl TryConvertFrom1<[P::Element], P> for Vec<i64> {
     }
 }

-// pub trait TryConvertFrom<T> {
-//     type Parameters: ?Sized;
-
-//     fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self;
-// }
-
-// impl TryConvertFrom1<[i32]> for Vec<Vec<u32>> {
-//     type Parameters = u32;
-//     fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
-//     {
-//         let row0 = value
-//             .iter()
-//             .map(|v| {
-//                 let is_neg = v.is_negative();
-//                 let v_u32 = v.abs() as u32;
-
-//                 assert!(v_u32 < *parameters);
-
-//                 if is_neg {
-//                     parameters - v_u32
-//                 } else {
-//                     v_u32
-//                 }
-//             })
-//             .collect_vec();
-
-//         vec![row0]
-//     }
-// }
-
-// impl TryConvertFrom1<[i32]> for Vec<Vec<u64>> {
-//     type Parameters = u64;
-//     fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
-//     {
-//         let row0 = value
-//             .iter()
-//             .map(|v| {
-//                 let is_neg = v.is_negative();
-//                 let v_u64 = v.abs() as u64;
-
-//                 assert!(v_u64 < *parameters);
-
-//                 if is_neg {
-//                     parameters - v_u64
-//                 } else {
-//                     v_u64
-//                 }
-//             })
-//             .collect_vec();
-
-//         vec![row0]
-//     }
-// }
-
-// impl TryConvertFrom1<[i32]> for Vec<u64> {
-//     type Parameters = u64;
-//     fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self
-//     {
-//         value
-//             .iter()
-//             .map(|v| {
-//                 let is_neg = v.is_negative();
-//                 let v_u64 = v.abs() as u64;
-
-//                 assert!(v_u64 < *parameters);
-
-//                 if is_neg {
-//                     parameters - v_u64
-//                 } else {
-//                     v_u64
-//                 }
-//             })
-//             .collect_vec()
-//     }
-// }
-
-// impl TryConvertFrom1<[u64]> for Vec<i64> {
-//     type Parameters = u64;
-//     fn try_convert_from(value: &[u64], parameters: &Self::Parameters) -> Self
-//     {
-//         let q = *parameters;
-//         let qby2 = q / 2;
-//         value
-//             .iter()
-//             .map(|v| {
-//                 if *v > qby2 {
-//                     -((q - v) as i64)
-//                 } else {
-//                     *v as i64
-//                 }
-//             })
-//             .collect_vec()
-//     }
-// }
-
 pub(crate) struct Stats<T> {
     pub(crate) samples: Vec<T>,
 }
@@ -323,3 +232,28 @@ where
         self.samples.extend(values.iter());
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use rand::{thread_rng, Rng};
+
+    use super::ShoupMul;
+
+    #[test]
+    fn shoup_mul_lazy_reduction() {
+        let mut rng = thread_rng();
+        let p = 36028797018820609;
+
+        let a = rng.gen_range(0..p);
+        let b = rng.gen_range(0..p);
+        let a_shoup = ShoupMul::representation(a, p);
+
+        // let c = ShoupMul::mul(b, a, a_shoup, p);
+        // assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
+
+        // without the final conditional subtraction the product is only
+        // lazily reduced: c lands in [0, 2p), congruent to a * b mod p
+        let quotient = ((a_shoup as u128 * b as u128) >> 64) as u64;
+        let mut c = b.wrapping_mul(a).wrapping_sub(p.wrapping_mul(quotient));
+        assert!(c < 2 * p);
+        if c >= p {
+            c -= p;
+        }
+        assert!(c == ((a as u128 * b as u128) % p as u128) as u64);
+    }
+}
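The inner step probed by the test above, and used by `shoup_fma` for every row after the first, can be checked end-to-end against u128 arithmetic. A standalone sketch (illustrative, not the crate API; assumes q < 2^62 so each lazily-reduced step stays below 2^64):

    fn main() {
        use rand::{thread_rng, Rng};

        let q: u64 = 36028797018820609; // prime from the test above
        let q_twice = q << 1;
        let mut rng = thread_rng();

        let d = 8;
        let a: Vec<u64> = (0..d).map(|_| rng.gen_range(0..q)).collect();
        let b: Vec<u64> = (0..d).map(|_| rng.gen_range(0..q)).collect();
        // Shoup representation: floor(a * 2^64 / q)
        let a_shoup: Vec<u64> = a
            .iter()
            .map(|&v| (((v as u128) << 64) / q as u128) as u64)
            .collect();

        let mut acc = 0u64; // lazily-reduced accumulator, invariant acc < 2q
        for i in 0..d {
            let quotient = ((a_shoup[i] as u128 * b[i] as u128) >> 64) as u64;
            // a*b + acc - q*quotient is exact despite wrapping: true value < 4q < 2^64
            let mut v = a[i].wrapping_mul(b[i]).wrapping_add(acc);
            v = v.wrapping_sub(q.wrapping_mul(quotient));
            if v >= q_twice {
                v -= q_twice;
            }
            acc = v;
        }

        // strict reference in u128
        let want = a
            .iter()
            .zip(b.iter())
            .fold(0u128, |s, (&x, &y)| (s + x as u128 * y as u128) % q as u128)
            as u64;
        assert_eq!(acc % q, want);
    }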