|
|
@ -3,6 +3,8 @@ use std::marker::PhantomData; |
|
|
|
use itertools::izip;
|
|
|
|
use num_traits::{PrimInt, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
|
|
|
|
|
|
|
|
use crate::{utils::ShoupMul, Matrix, RowMut};
|
|
|
|
|
|
|
|
pub trait Modulus {
|
|
|
|
type Element;
|
|
|
|
/// Modulus value if it fits in Element
|
|
|
@ -113,7 +115,9 @@ pub trait ArithmeticOps { |
|
|
|
type Element;
|
|
|
|
|
|
|
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
|
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
|
|
fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
|
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
|
|
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element;
|
|
|
|
fn neg(&self, a: &Self::Element) -> Self::Element;
|
|
|
|
|
|
|
@ -122,6 +126,7 @@ pub trait ArithmeticOps { |
|
|
|
|
|
|
|
pub struct ModularOpsU64<T> {
|
|
|
|
q: u64,
|
|
|
|
q_twice: u64,
|
|
|
|
logq: usize,
|
|
|
|
barrett_mu: u128,
|
|
|
|
barrett_alpha: usize,
|
|
|
@ -146,6 +151,7 @@ where |
|
|
|
|
|
|
|
ModularOpsU64 {
|
|
|
|
q,
|
|
|
|
q_twice: q << 1,
|
|
|
|
logq: logq as usize,
|
|
|
|
barrett_alpha: alpha as usize,
|
|
|
|
barrett_mu: mu,
|
|
|
@ -166,6 +172,17 @@ impl ModularOpsU64 { |
|
|
|
o
|
|
|
|
}
|
|
|
|
|
|
|
|
fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
|
|
|
|
debug_assert!(a < self.q_twice);
|
|
|
|
debug_assert!(b < self.q_twice);
|
|
|
|
|
|
|
|
let mut o = a + b;
|
|
|
|
if o >= self.q_twice {
|
|
|
|
o -= self.q_twice;
|
|
|
|
}
|
|
|
|
o
|
|
|
|
}
|
|
|
|
|
|
|
|
fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
|
|
|
|
debug_assert!(a < self.q);
|
|
|
|
debug_assert!(b < self.q);
|
|
|
@ -177,6 +194,29 @@ impl ModularOpsU64 { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// returns (a * b) % q
|
|
|
|
///
|
|
|
|
/// - both a and b must be in range [0, 2q)
|
|
|
|
/// - output is in range [0 , 2q)
|
|
|
|
fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
|
|
|
|
debug_assert!(a < 2 * self.q);
|
|
|
|
debug_assert!(b < 2 * self.q);
|
|
|
|
|
|
|
|
let ab = a as u128 * b as u128;
|
|
|
|
|
|
|
|
// ab / (2^{n + \beta})
|
|
|
|
// note: \beta is assumed to -2
|
|
|
|
let tmp = ab >> (self.logq - 2);
|
|
|
|
|
|
|
|
// k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
|
|
|
|
let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
|
|
|
|
|
|
|
|
// ab - k*p
|
|
|
|
let tmp = k * (self.q as u128);
|
|
|
|
|
|
|
|
(ab - tmp) as u64
|
|
|
|
}
|
|
|
|
|
|
|
|
/// returns (a * b) % q
|
|
|
|
///
|
|
|
|
/// - both a and b must be in range [0, 2q)
|
|
|
@ -214,10 +254,18 @@ impl ArithmeticOps for ModularOpsU64 { |
|
|
|
self.add_mod_fast(*a, *b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
self.add_mod_fast_lazy(*a, *b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
self.mul_mod_fast(*a, *b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
self.mul_mod_fast_lazy(*a, *b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
self.sub_mod_fast(*a, *b)
|
|
|
|
}
|
|
|
@ -296,6 +344,55 @@ impl VectorOps for ModularOpsU64 { |
|
|
|
// }
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T> ModularOpsU64<T> {
|
|
|
|
/// Returns \sum a[i]b[i]
|
|
|
|
pub fn shoup_fma<M: Matrix<MatElement = u64>>(&self, out: &mut M::R, a: &M, a_shoup: &M, b: &M)
|
|
|
|
where
|
|
|
|
M::R: RowMut,
|
|
|
|
{
|
|
|
|
assert!(a.dimension() == a_shoup.dimension());
|
|
|
|
assert!(a.dimension() == b.dimension());
|
|
|
|
|
|
|
|
let q = self.q;
|
|
|
|
let q_twice = self.q << 1;
|
|
|
|
|
|
|
|
// first row (without summation)
|
|
|
|
izip!(
|
|
|
|
out.as_mut().iter_mut(),
|
|
|
|
a.get_row(0),
|
|
|
|
a_shoup.get_row(0),
|
|
|
|
b.get_row(0)
|
|
|
|
)
|
|
|
|
.for_each(|(o, a, a_shoup, b)| {
|
|
|
|
*o = ShoupMul::mul(*b, *a, *a_shoup, q);
|
|
|
|
});
|
|
|
|
|
|
|
|
izip!(
|
|
|
|
a.iter_rows().skip(1),
|
|
|
|
a_shoup.iter_rows().skip(1),
|
|
|
|
b.iter_rows().skip(1)
|
|
|
|
)
|
|
|
|
.for_each(|(a_row, a_shoup_row, b_row)| {
|
|
|
|
izip!(
|
|
|
|
out.as_mut().iter_mut(),
|
|
|
|
a_row.as_ref().iter(),
|
|
|
|
a_shoup_row.as_ref().iter(),
|
|
|
|
b_row.as_ref().iter()
|
|
|
|
)
|
|
|
|
.for_each(|(o, a0, a0_shoup, b0)| {
|
|
|
|
let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
|
|
|
|
let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
|
|
|
|
v = v.wrapping_sub(q.wrapping_mul(quotient));
|
|
|
|
|
|
|
|
if v >= q_twice {
|
|
|
|
v -= q_twice;
|
|
|
|
}
|
|
|
|
|
|
|
|
*o = v;
|
|
|
|
});
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
impl<T> GetModulus for ModularOpsU64<T>
|
|
|
|
where
|
|
|
|
T: Modulus,
|
|
|
@ -333,10 +430,18 @@ where |
|
|
|
T::Element::wrapping_add(a, b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
self.add(a, b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
T::Element::wrapping_mul(a, b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
|
|
|
|
self.mul(a, b)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn neg(&self, a: &Self::Element) -> Self::Element {
|
|
|
|
T::Element::wrapping_sub(&T::Element::zero(), a)
|
|
|
|
}
|
|
|
|