diff --git a/math/src/lib.rs b/math/src/lib.rs index 8ee8bd7..e98dc31 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -229,4 +229,41 @@ pub mod macros{ } }; } + + #[macro_export] + macro_rules! apply_vsv { + + ($self:expr, $f:expr, $a:expr, $c:expr, $b:expr, $CHUNK:expr) => { + + let n: usize = $a.len(); + debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n); + debug_assert!($CHUNK&($CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + + match $CHUNK{ + 8 => { + + izip!($a.chunks_exact(8), $b.chunks_exact_mut(8)).for_each(|(a, b)| { + $f(&$self, &a[0], $c, &mut b[0]); + $f(&$self, &a[1], $c, &mut b[1]); + $f(&$self, &a[2], $c, &mut b[2]); + $f(&$self, &a[3], $c, &mut b[3]); + $f(&$self, &a[4], $c, &mut b[4]); + $f(&$self, &a[5], $c, &mut b[5]); + $f(&$self, &a[6], $c, &mut b[6]); + $f(&$self, &a[7], $c, &mut b[7]); + }); + + let m = n - (n&7); + izip!($a[m..].iter(), $b[m..].iter_mut()).for_each(|(a, b)| { + $f(&$self, a, $c, b); + }); + }, + _=>{ + izip!($a.iter(), $b.iter_mut()).for_each(|(a, b)| { + $f(&$self, a, $c, b); + }); + } + } + }; + } } \ No newline at end of file diff --git a/math/src/modulus.rs b/math/src/modulus.rs index 8b14a0d..f516678 100644 --- a/math/src/modulus.rs +++ b/math/src/modulus.rs @@ -94,6 +94,9 @@ pub trait WordOperations{ // Assigns (a + 2q - b) * c to d. fn word_sum_aqqmb_prod_c_barrett_assign_d(&self, a: &O, b: &O, c: &barrett::Barrett, d: &mut O); + + // Assigns (a + 2q - b) * c to b. + fn word_sum_aqqmb_prod_c_barrett_assign_b(&self, a: &O, c: &barrett::Barrett, b: &mut O); } pub trait VecOperations{ @@ -136,6 +139,9 @@ pub trait VecOperations{ // Assigns (a[i] + 2q - b[i]) * c to d[i]. fn vec_sum_aqqmb_prod_c_scalar_barrett_assign_d(&self, a: &[u64], b: &[u64], c: &barrett::Barrett, d: &mut [u64]); + + // Assigns (a[i] + 2q - b[i]) * c to b[i]. 
+ fn vec_sum_aqqmb_prod_c_scalar_barrett_assign_b(&self, a: &[u64], c: &barrett::Barrett, b: &mut [u64]); } diff --git a/math/src/modulus/impl_u64/operations.rs b/math/src/modulus/impl_u64/operations.rs index 3acf34c..6711a75 100644 --- a/math/src/modulus/impl_u64/operations.rs +++ b/math/src/modulus/impl_u64/operations.rs @@ -5,7 +5,7 @@ use crate::modulus::ReduceOnce; use crate::modulus::montgomery::Montgomery; use crate::modulus::barrett::Barrett; use crate::modulus::REDUCEMOD; -use crate::{apply_v, apply_vv, apply_vvv, apply_sv, apply_svv, apply_vvsv}; +use crate::{apply_v, apply_vv, apply_vvv, apply_sv, apply_svv, apply_vvsv, apply_vsv}; use itertools::izip; impl WordOperations for Prime{ @@ -86,6 +86,12 @@ impl WordOperations for Prime{ *d = self.two_q.wrapping_sub(*b).wrapping_add(*a); self.barrett.mul_external_assign::(*c, d); } + + #[inline(always)] + fn word_sum_aqqmb_prod_c_barrett_assign_b(&self, a: &u64, c: &Barrett, b: &mut u64){ + *b = self.two_q.wrapping_sub(*b).wrapping_add(*a); + self.barrett.mul_external_assign::(*c, b); + } } impl VecOperations for Prime{ @@ -160,4 +166,8 @@ impl VecOperations for Prime{ fn vec_sum_aqqmb_prod_c_scalar_barrett_assign_d(&self, a: &[u64], b: &[u64], c: &Barrett, d: &mut [u64]){ apply_vvsv!(self, Self::word_sum_aqqmb_prod_c_barrett_assign_d::, a, b, c, d, CHUNK); } + + fn vec_sum_aqqmb_prod_c_scalar_barrett_assign_b(&self, a: &[u64], c: &Barrett, b: &mut [u64]){ + apply_vsv!(self, Self::word_sum_aqqmb_prod_c_barrett_assign_b::, a, c, b, CHUNK); + } } diff --git a/math/src/poly.rs b/math/src/poly.rs index d153249..2d63277 100644 --- a/math/src/poly.rs +++ b/math/src/poly.rs @@ -1,10 +1,11 @@ pub mod poly; +use std::cmp::PartialEq; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, Eq)] pub struct Poly(pub Vec); impl Polywhere - O: Default + Clone, + O: Default + Clone + Copy, { pub fn new(n: usize) -> Self{ Self(vec![O::default();n]) @@ -20,15 +21,45 @@ impl Polywhere } pub fn n(&self) -> usize{ - return 
self.0.len() + self.0.len() } + + pub fn log_n(&self) -> usize{ + (usize::BITS - self.0.len().leading_zeros() - 1) as usize + } + + pub fn resize(&mut self, n:usize){ + self.0.resize(n, O::default()); + } + + pub fn set_all(&mut self, v: &O){ + self.0.fill(*v) + } + + pub fn zero(&mut self){ + self.set_all(&O::default()) + } + + pub fn copy_from(&mut self, other: &Poly){ + if std::ptr::eq(self, other){ + return + } + self.resize(other.n()); + self.0.copy_from_slice(&other.0) + } } -#[derive(Clone, Debug, PartialEq, Eq)] +impl PartialEq for Poly { + fn eq(&self, other: &Self) -> bool { + std::ptr::eq(self, other) || (self.0 == other.0) + } +} + +#[derive(Clone, Debug, Eq)] pub struct PolyRNS(pub Vec>); impl PolyRNSwhere - O: Default + Clone, + O: Default + Clone + Copy, { pub fn new(n: usize, level: usize) -> Self{ @@ -42,6 +73,10 @@ impl PolyRNSwhere self.0[0].n() } + pub fn log_n(&self) -> usize{ + self.0[0].log_n() + } + pub fn level(&self) -> usize{ self.0.len()-1 } @@ -60,13 +95,50 @@ impl PolyRNSwhere } } + pub fn resize(&mut self, level:usize){ + self.0.resize(level+1, Poly::::new(self.n())); + } + + pub fn split_at_mut(&mut self, level:usize) -> (&mut [Poly], &mut [Poly]){ + self.0.split_at_mut(level) + } + pub fn at(&self, level:usize) -> &Poly{ + assert!(level <= self.level(), "invalid argument level: level={} > self.level()={}", level, self.level()); &self.0[level] } pub fn at_mut(&mut self, level:usize) -> &mut Poly{ &mut self.0[level] } + + pub fn set_all(&mut self, v: &O){ + (0..self.level()+1).for_each(|i| self.at_mut(i).set_all(v)) + } + + pub fn zero(&mut self){ + self.set_all(&O::default()) + } + + pub fn copy(&mut self, other: &PolyRNS){ + if std::ptr::eq(self, other){ + return + } + self.resize(other.level()); + self.copy_level(other.level(), other); + } + + pub fn copy_level(&mut self, level:usize, other: &PolyRNS){ + assert!(level <= self.level(), "invalid argument level: level={} > self.level()={}", level, self.level()); + assert!(other.level()
>= level, "invalid argument level: level={} > other.level()={}", level, other.level()); + (0..level+1).for_each(|i| self.at_mut(i).copy_from(other.at(i))) + } +} + +impl PartialEq for PolyRNS { + fn eq(&self, other: &Self) -> bool { + std::ptr::eq(self, other) || (self.0 == other.0) + } } impl Default for PolyRNS{ diff --git a/math/src/ring.rs b/math/src/ring.rs index 4a3de26..b62f9a9 100644 --- a/math/src/ring.rs +++ b/math/src/ring.rs @@ -2,6 +2,7 @@ pub mod impl_u64; use crate::modulus::prime::Prime; use crate::poly::{Poly, PolyRNS}; +use num_bigint::BigInt; use crate::dft::DFT; @@ -21,12 +22,9 @@ impl Ring{ } } - -//pub struct RingRNS<'a, O: Copy>(pub Vec>>); - pub struct RingRNS<'a, O>(& 'a [Ring]); -impl RingRNS<'_, O>{ +impl RingRNS<'_, O> { pub fn n(&self) -> usize{ self.0[0].n() @@ -39,11 +37,7 @@ impl RingRNS<'_, O>{ pub fn max_level(&self) -> usize{ self.0.len()-1 } - - pub fn modulus(&self) -> O{ - self.0[LEVEL].modulus.q - } - + pub fn level(&self) -> usize{ self.0.len()-1 } diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index 505d225..3634e00 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -1,12 +1,12 @@ use crate::ring::RingRNS; -use crate::poly::PolyRNS; +use crate::poly::{Poly, PolyRNS}; use crate::modulus::barrett::Barrett; use crate::modulus::ONCE; extern crate test; impl RingRNS<'_, u64>{ - /// Updates b to floor(b / q[b.level()]). + /// Updates b to floor(a / q[b.level()]). /// Expects a and b to be in the NTT domain. pub fn div_floor_by_last_modulus_ntt(&self, a: &PolyRNS, buf: &mut PolyRNS, b: &mut PolyRNS){ assert!(b.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", b.level(), a.level()-1); @@ -21,6 +21,19 @@ impl RingRNS<'_, u64>{ } /// Updates b to floor(b / q[b.level()]). + /// Expects b to be in the NTT domain. 
+ pub fn div_floor_by_last_modulus_ntt_inplace(&self, buf: &mut PolyRNS, b: &mut PolyRNS){ + let level = self.level(); + self.0[level].intt::(b.at(level), buf.at_mut(0)); + let rescaling_constants: Vec> = self.rescaling_constant(); + let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); + for (i, r) in self.0[0..level].iter().enumerate(){ + r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); + r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&buf_ntt_qi_scaling[0], &rescaling_constants[i], b.at_mut(i)); + } + } + + /// Updates b to floor(a / q[b.level()]). pub fn div_floor_by_last_modulus(&self, a: &PolyRNS, b: &mut PolyRNS){ assert!(b.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", b.level(), a.level()-1); let level = self.level(); @@ -29,6 +42,35 @@ impl RingRNS<'_, u64>{ r.sum_aqqmb_prod_c_scalar_barrett::(a.at(level), a.at(i), &rescaling_constants[i], b.at_mut(i)); } } + + /// Updates a to floor(b / q[b.level()]). + pub fn div_floor_by_last_modulus_inplace(&self, a: &mut PolyRNS){ + let level = self.level(); + let rescaling_constants: Vec> = self.rescaling_constant(); + let (a_i, a_level) = a.split_at_mut(level); + for (i, r) in self.0[0..level].iter().enumerate(){ + r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&a_level[0], &rescaling_constants[i], &mut a_i[i]); + } + } + + pub fn div_floor_by_last_moduli(&self, nb_moduli:usize, a: &PolyRNS, b: &mut PolyRNS){ + if nb_moduli == 0{ + if a != b{ + b.copy(a); + } + }else{ + self.div_floor_by_last_modulus(a, b); + (1..nb_moduli).for_each(|i|{self.at_level(self.level()-i).div_floor_by_last_modulus_inplace(b)}); + } + } + + pub fn div_floor_by_last_moduli_inplace(&self, nb_moduli:usize, a: &mut PolyRNS){ + (0..nb_moduli).for_each(|i|{self.at_level(self.level()-i).div_floor_by_last_modulus_inplace(a)}); + } + + pub fn div_round_by_last_modulus_ntt(&self, a: &PolyRNS, buf: &mut PolyRNS, b: &mut PolyRNS){ + let level = self.level(); + } } diff --git 
a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index ed6934d..b04d901 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -139,4 +139,11 @@ impl Ring{ debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n()); self.modulus.vec_sum_aqqmb_prod_c_scalar_barrett_assign_d::(&a.0, &b.0, c, &mut d.0); } + + #[inline(always)] + pub fn sum_aqqmb_prod_c_scalar_barrett_inplace(&self, a: &Poly, c: &Barrett, b: &mut Poly){ + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + self.modulus.vec_sum_aqqmb_prod_c_scalar_barrett_assign_b::(&a.0, c, &mut b.0); + } } \ No newline at end of file diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs index b4b6339..dc82615 100644 --- a/math/src/ring/impl_u64/ring_rns.rs +++ b/math/src/ring/impl_u64/ring_rns.rs @@ -19,6 +19,12 @@ impl<'a> RingRNS<'a, u64>{ RingRNS(rings) } + pub fn modulus(&self) -> BigInt{ + let mut modulus = BigInt::from(1); + self.0.iter().enumerate().for_each(|(_, r)|modulus *= BigInt::from(r.modulus.q)); + modulus + } + pub fn rescaling_constant(&self) -> Vec> { let level = self.level(); let q_scale: u64 = self.0[level].modulus.q; @@ -30,6 +36,32 @@ impl<'a> RingRNS<'a, u64>{ assert!(level <= a.level(), "invalid level: level={} > a.level()={}", level, a.level()); (0..level).for_each(|i|{self.0[i].from_bigint(coeffs, step, a.at_mut(i))}); } + + pub fn set_bigint_from_poly(&self, a: &PolyRNS, step: usize, coeffs: &mut [BigInt]){ + assert!(step <= a.n(), "invalid step: step={} > a.n()={}", step, a.n()); + assert!(coeffs.len() <= a.n() / step, "invalid coeffs: coeffs.len()={} > a.n()/step={}", coeffs.len(), a.n()/step); + + let mut inv_crt: Vec = vec![BigInt::default(); self.level()+1]; + let q_big: BigInt = self.modulus(); + let q_big_half: BigInt = &q_big>>1; + + inv_crt.iter_mut().enumerate().for_each(|(i, a)|{ + let qi_big = 
BigInt::from(self.0[i].modulus.q); + *a = (&q_big / &qi_big); + *a *= a.modinv(&qi_big).unwrap(); + }); + + (0..self.n()).step_by(step).take(coeffs.len()).enumerate().for_each(|(i, j)|{ + coeffs[i] = BigInt::from(a.at(0).0[j]) * &inv_crt[0]; + (1..self.level()+1).for_each(|k|{ + coeffs[i] += BigInt::from(a.at(k).0[j]) * &inv_crt[k]; + }); + coeffs[i] %= &q_big; + if &coeffs[i] >= &q_big_half{ + coeffs[i] -= &q_big; + } + }); + } } impl RingRNS<'_, u64>{