From 160e7a33da29b4e58ac39d41edb7a09092e03ecd Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Wed, 8 Jan 2025 11:07:04 +0100
Subject: [PATCH] fmt

---
 math/src/lib.rs                         |  16 +---
 math/src/modulus.rs                     |  54 ++++++++---
 math/src/modulus/impl_u64/operations.rs | 120 +++++++++++++++---------
 math/src/num_bigint.rs                  |  28 +++---
 math/src/ring/impl_u64/rescaling_rns.rs |  48 +++++-----
 math/src/ring/impl_u64/ring.rs          |  87 +++++++++++++----
 math/src/ring/impl_u64/ring_rns.rs      |  82 +++++++++++-----
 math/src/ring/impl_u64/sampling.rs      |  50 ++++++----
 math/tests/rescaling_rns.rs             | 105 ++++++++++++---------
 9 files changed, 383 insertions(+), 207 deletions(-)

diff --git a/math/src/lib.rs b/math/src/lib.rs
index 5a94a97..b1990ff 100644
--- a/math/src/lib.rs
+++ b/math/src/lib.rs
@@ -3,10 +3,10 @@
 pub mod dft;
 pub mod modulus;
+pub mod num_bigint;
 pub mod poly;
 pub mod ring;
 pub mod scalar;
-pub mod num_bigint;
 
 pub const CHUNK: usize = 8;
 
@@ -333,11 +333,7 @@ pub mod macros {
 
             match CHUNK {
                 8 => {
-                    izip!(
-                        $a.chunks_exact(8),
-                        $d.chunks_exact_mut(8)
-                    )
-                    .for_each(|(a, d)| {
+                    izip!($a.chunks_exact(8), $d.chunks_exact_mut(8)).for_each(|(a, d)| {
                         $f(&$self, &a[0], $b, $c, &mut d[0]);
                         $f(&$self, &a[1], $b, $c, &mut d[1]);
                         $f(&$self, &a[2], $b, $c, &mut d[2]);
@@ -349,11 +345,9 @@ pub mod macros {
                     });
 
                     let m = n - (n & 7);
-                    izip!($a[m..].iter(), $d[m..].iter_mut()).for_each(
-                        |(a, d)| {
-                            $f(&$self, a, $b, $c, d);
-                        },
-                    );
+                    izip!($a[m..].iter(), $d[m..].iter_mut()).for_each(|(a, d)| {
+                        $f(&$self, a, $b, $c, d);
+                    });
                 }
                 _ => {
                     izip!($a.iter(), $d.iter_mut()).for_each(|(a, d)| {
diff --git a/math/src/modulus.rs b/math/src/modulus.rs
index d9ed6a9..4d9db52 100644
--- a/math/src/modulus.rs
+++ b/math/src/modulus.rs
@@ -74,7 +74,12 @@ pub trait ScalarOperations<O> {
     fn sa_add_sb_into_sb<const REDUCE: REDUCEMOD>(&self, a: &O, b: &mut O);
 
     // Assigns a - b to c.
-    fn sa_sub_sb_into_sc<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &O, b: &O, c: &mut O);
+    fn sa_sub_sb_into_sc<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &O,
+        b: &O,
+        c: &mut O,
+    );
 
     // Assigns a - b to b.
     fn sa_sub_sb_into_sb<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &O, b: &mut O);
@@ -147,7 +152,7 @@ pub trait ScalarOperations<O> {
         &self,
         b: &u64,
         c: &barrett::Barrett<u64>,
-        a: &mut u64
+        a: &mut u64,
     );
 
     // Assigns (a + b) * c to d.
@@ -156,25 +161,25 @@ pub trait ScalarOperations<O> {
         a: &u64,
         b: &u64,
         c: &barrett::Barrett<u64>,
-        d: &mut u64
+        d: &mut u64,
     );
 
     // Assigns (a - b + c) * d to e.
-    fn sb_sub_sa_add_sc_mul_sd_into_se<const SBRANGE: u8, const REDUCE: REDUCEMOD>( 
+    fn sb_sub_sa_add_sc_mul_sd_into_se<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
         &self,
         a: &u64,
         b: &u64,
         c: &u64,
         d: &barrett::Barrett<u64>,
-        e: &mut u64
+        e: &mut u64,
     );
 
-    fn sb_sub_sa_add_sc_mul_sd_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>( 
+    fn sb_sub_sa_add_sc_mul_sd_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
        &self,
         b: &u64,
         c: &u64,
         d: &barrett::Barrett<u64>,
-        a: &mut u64
+        a: &mut u64,
     );
 }
 
@@ -206,10 +211,18 @@ pub trait VectorOperations<O> {
     fn va_add_sb_into_va<const CHUNK: usize, const REDUCE: REDUCEMOD>(&self, a: &O, b: &mut [O]);
 
     // vec(b) <- vec(a) - vec(b).
-    fn va_sub_vb_into_vb<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &[O], b: &mut [O]);
+    fn va_sub_vb_into_vb<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[O],
+        b: &mut [O],
+    );
 
     // vec(a) <- vec(a) - vec(b).
-    fn va_sub_vb_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &[O], a: &mut [O]);
+    fn va_sub_vb_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        b: &[O],
+        a: &mut [O],
+    );
 
     // vec(c) <- vec(a) - vec(b).
     fn va_sub_vb_into_vc<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
         &self,
         a: &[O],
         b: &[O],
         c: &mut [O],
     );
 
     // vec(a) <- -vec(a).
-    fn va_neg_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut [O]);
+    fn va_neg_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &mut [O],
+    );
 
     // vec(b) <- -vec(a).
-    fn va_neg_into_vb<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &[O], b: &mut [O]);
+    fn va_neg_into_vb<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[O],
+        b: &mut [O],
+    );
 
     // vec(b) <- vec(a)
     fn va_prep_mont_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
         &self,
@@ -297,7 +317,11 @@ pub trait VectorOperations<O> {
     );
 
     // vec(e) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
-    fn vb_sub_va_add_sc_mul_sd_into_ve<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+    fn vb_sub_va_add_sc_mul_sd_into_ve<
+        const CHUNK: usize,
+        const VBRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
         &self,
         va: &[u64],
         vb: &[u64],
@@ -307,7 +331,11 @@ pub trait VectorOperations<O> {
     );
 
     // vec(a) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
-    fn vb_sub_va_add_sc_mul_sd_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+    fn vb_sub_va_add_sc_mul_sd_into_va<
+        const CHUNK: usize,
+        const VBRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
         &self,
         vb: &[u64],
         sc: &u64,
diff --git a/math/src/modulus/impl_u64/operations.rs b/math/src/modulus/impl_u64/operations.rs
index e6e2a09..5be7ec8 100644
--- a/math/src/modulus/impl_u64/operations.rs
+++ b/math/src/modulus/impl_u64/operations.rs
@@ -1,9 +1,12 @@
 use crate::modulus::barrett::Barrett;
 use crate::modulus::montgomery::Montgomery;
 use crate::modulus::prime::Prime;
-use crate::modulus::{REDUCEMOD, NONE};
 use crate::modulus::{ScalarOperations, VectorOperations};
-use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv, apply_ssv, apply_vssv, apply_vvssv};
+use crate::modulus::{NONE, REDUCEMOD};
+use crate::{
+    apply_ssv, apply_sv, apply_svv, apply_v, apply_vssv, apply_vsv, apply_vv, apply_vvssv,
+    apply_vvsv, apply_vvv,
+};
 use itertools::izip;
 
 impl ScalarOperations<u64> for Prime<u64> {
@@ -32,11 +35,16 @@ impl ScalarOperations<u64> for Prime<u64> {
     }
 
     #[inline(always)]
-    fn sa_sub_sb_into_sc<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &u64, b: &u64, c: &mut u64) {
-        match SBRANGE{
-            1 =>{*c = *a + self.q - *b}
-            2 =>{*c = *a + self.two_q - *b}
-            4 =>{*c = *a + self.four_q - *b}
+    fn sa_sub_sb_into_sc<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &u64,
+        b: &u64,
+        c: &mut u64,
+    ) {
+        match SBRANGE {
+            1 => *c = *a + self.q - *b,
+            2 => *c = *a + self.two_q - *b,
+            4 => *c = *a + self.four_q - *b,
             _ => unreachable!("invalid SBRANGE argument"),
         }
         self.sa_reduce_into_sa::<REDUCE>(c)
     }
 
     #[inline(always)]
     fn sa_sub_sb_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &u64, a: &mut u64) {
-        match SBRANGE{
-            1 =>{*a = *a + self.q - *b}
-            2 =>{*a = *a + self.two_q - *b}
-            4 =>{*a = *a + self.four_q - *b}
+        match SBRANGE {
+            1 => *a = *a + self.q - *b,
+            2 => *a = *a + self.two_q - *b,
+            4 => *a = *a + self.four_q - *b,
             _ => unreachable!("invalid SBRANGE argument"),
         }
         self.sa_reduce_into_sa::<REDUCE>(a)
     }
 
     #[inline(always)]
     fn sa_sub_sb_into_sb<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut u64) {
-        match SBRANGE{
-            1 =>{*b = *a + self.q - *b}
-            2 =>{*b = *a + self.two_q - *b}
-            4 =>{*b = *a + self.four_q - *b}
+        match SBRANGE {
+            1 => *b = *a + self.q - *b,
+            2 => *b = *a + self.two_q - *b,
+            4 => *b = *a + self.four_q - *b,
             _ => unreachable!("invalid SBRANGE argument"),
         }
         self.sa_reduce_into_sa::<REDUCE>(b)
     }
 
     #[inline(always)]
     fn sa_neg_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut u64) {
-        match SBRANGE{
-            1 =>{*a = self.q - *a}
-            2 =>{*a = self.two_q - *a}
-            4 =>{*a = self.four_q - *a}
+        match SBRANGE {
+            1 => *a = self.q - *a,
+            2 => *a = self.two_q - *a,
+            4 => *a = self.four_q - *a,
             _ => unreachable!("invalid SBRANGE argument"),
         }
         self.sa_reduce_into_sa::<REDUCE>(a)
     }
 
     #[inline(always)]
     fn sa_neg_into_sb<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut u64) {
-        match SBRANGE{
-            1 =>{*b = self.q - *a}
-            2 =>{*b = self.two_q - *a}
-            4 =>{*b = self.four_q - *a}
+        match SBRANGE {
+            1 => *b = self.q - *a,
+            2 => *b = self.two_q - *a,
+            4 => *b = self.four_q - *a,
             _ => unreachable!("invalid SBRANGE argument"),
         }
         self.sa_reduce_into_sa::<REDUCE>(b)
     }
@@ -129,10 +137,10 @@ impl ScalarOperations<u64> for Prime<u64> {
         c: &Barrett<u64>,
         d: &mut u64,
     ) {
-        match VBRANGE{
-            1 =>{*d = a + self.q - b}
-            2 =>{*d = a + self.two_q - b}
-            4 =>{*d = a + self.four_q - b}
+        match VBRANGE {
+            1 => *d = a + self.q - b,
+            2 => *d = a + self.two_q - b,
+            4 => *d = a + self.four_q - b,
             _ => unreachable!("invalid VBRANGE argument"),
         }
         self.barrett.mul_external_assign::<REDUCE>(*c, d);
@@ -155,7 +163,7 @@ impl ScalarOperations<u64> for Prime<u64> {
         a: &u64,
         b: &u64,
         c: &Barrett<u64>,
-        d: &mut u64
+        d: &mut u64,
     ) {
         *d = self.barrett.mul_external::<REDUCE>(*c, *a + *b);
     }
@@ -165,19 +173,19 @@ impl ScalarOperations<u64> for Prime<u64> {
         &self,
         b: &u64,
         c: &Barrett<u64>,
-        a: &mut u64
+        a: &mut u64,
     ) {
         *a = self.barrett.mul_external::<REDUCE>(*c, *a + *b);
     }
 
-    #[inline(always)] 
+    #[inline(always)]
     fn sb_sub_sa_add_sc_mul_sd_into_se<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
         &self,
         a: &u64,
         b: &u64,
         c: &u64,
         d: &Barrett<u64>,
-        e: &mut u64
+        e: &mut u64,
     ) {
         self.sa_sub_sb_into_sc::<SBRANGE, NONE>(&(b + c), a, e);
         self.barrett.mul_external_assign::<REDUCE>(*d, e);
     }
@@ -189,12 +197,11 @@ impl ScalarOperations<u64> for Prime<u64> {
         b: &u64,
         c: &u64,
         d: &Barrett<u64>,
-        a: &mut u64
+        a: &mut u64,
     ) {
         self.sa_sub_sb_into_sb::<SBRANGE, NONE>(&(b + c), a);
         self.barrett.mul_external_assign::<REDUCE>(*d, a);
     }
-
 }
 
 impl VectorOperations<u64> for Prime<u64> {
@@ -255,7 +262,14 @@ impl VectorOperations<u64> for Prime<u64> {
         b: &[u64],
         c: &mut [u64],
     ) {
-        apply_vvv!(self, Self::sa_sub_sb_into_sc::<VBRANGE, REDUCE>, a, b, c, CHUNK);
+        apply_vvv!(
+            self,
+            Self::sa_sub_sb_into_sc::<VBRANGE, REDUCE>,
+            a,
+            b,
+            c,
+            CHUNK
+        );
     }
 
     #[inline(always)]
@@ -264,7 +278,13 @@ impl VectorOperations<u64> for Prime<u64> {
         b: &[u64],
         a: &mut [u64],
     ) {
-        apply_vv!(self, Self::sa_sub_sb_into_sa::<VBRANGE, REDUCE>, b, a, CHUNK);
+        apply_vv!(
+            self,
+            Self::sa_sub_sb_into_sa::<VBRANGE, REDUCE>,
+            b,
+            a,
+            CHUNK
+        );
     }
 
     #[inline(always)]
@@ -273,11 +293,20 @@ impl VectorOperations<u64> for Prime<u64> {
         a: &[u64],
         b: &mut [u64],
     ) {
-        apply_vv!(self, Self::sa_sub_sb_into_sb::<VBRANGE, REDUCE>, a, b, CHUNK);
+        apply_vv!(
+            self,
+            Self::sa_sub_sb_into_sb::<VBRANGE, REDUCE>,
+            a,
+            b,
+            CHUNK
+        );
     }
 
     #[inline(always)]
-    fn va_neg_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut [u64]) {
+    fn va_neg_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &mut [u64],
+    ) {
         apply_v!(self, Self::sa_neg_into_sa::<VBRANGE, REDUCE>, a, CHUNK);
     }
 
@@ -415,14 +444,18 @@ impl VectorOperations<u64> for Prime<u64> {
     }
 
     // vec(e) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
-    fn vb_sub_va_add_sc_mul_sd_into_ve<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+    fn vb_sub_va_add_sc_mul_sd_into_ve<
+        const CHUNK: usize,
+        const VBRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
         &self,
         va: &[u64],
         vb: &[u64],
         sc: &u64,
         sd: &Barrett<u64>,
         ve: &mut [u64],
-    ){
+    ) {
         apply_vvssv!(
             self,
             Self::sb_sub_sa_add_sc_mul_sd_into_se::<VBRANGE, REDUCE>,
             va,
             vb,
             sc,
             sd,
             ve,
             CHUNK
         );
     }
 
     // vec(a) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
-    fn vb_sub_va_add_sc_mul_sd_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
+    fn vb_sub_va_add_sc_mul_sd_into_va<
+        const CHUNK: usize,
+        const VBRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
         &self,
         vb: &[u64],
         sc: &u64,
         sd: &Barrett<u64>,
         va: &mut [u64],
-    ){
-
+    ) {
         apply_vssv!(
             self,
             Self::sb_sub_sa_add_sc_mul_sd_into_sa::<VBRANGE, REDUCE>,
diff --git a/math/src/num_bigint.rs b/math/src/num_bigint.rs
index dff66a3..84a2312 100644
--- a/math/src/num_bigint.rs
+++ b/math/src/num_bigint.rs
@@ -1,34 +1,32 @@
 use num_bigint::BigInt;
 use num_bigint::Sign;
 use num_integer::Integer;
-use num_traits::{Zero, One, Signed};
+use num_traits::{One, Signed, Zero};
 
-pub trait Div{
+pub trait Div {
     fn div_floor(&self, other: &Self) -> Self;
     fn div_round(&self, other: &Self) -> Self;
 }
 
-impl Div for BigInt{
-
-    fn div_floor(&self, other:&Self) -> Self{
-        let quo: BigInt = self / other;
-        if self.sign() == Sign::Minus {
-            return quo - BigInt::one()
+impl Div for BigInt {
+    fn div_floor(&self, other: &Self) -> Self {
+        // `/` truncates toward zero; step the quotient down once when the
+        // remainder and the divisor disagree in sign. This also leaves exact
+        // negative quotients untouched (e.g. -4 / 2 stays -2).
+        let (quo, rem) = self.div_rem(other);
+        if !rem.is_zero() && rem.sign() != other.sign() {
+            return quo - BigInt::one();
         }
-        return quo
+        quo
     }
 
-    fn div_round(&self, other:&Self) -> Self{
+    fn div_round(&self, other: &Self) -> Self {
         let (quo, mut rem) = self.div_rem(other);
         rem <<= 1;
-        if rem != BigInt::zero() && &rem.abs() > other{
-            if self.sign() == other.sign(){
-                return quo + BigInt::one()
-            }else{
-                return quo - BigInt::one()
+        // Round half away from zero: adjust when 2 * |rem| >= |other|.
+        if !rem.is_zero() && rem.abs() >= other.abs() {
+            if self.sign() == other.sign() {
+                return quo + BigInt::one();
+            } else {
+                return quo - BigInt::one();
             }
         }
-        return quo
+        quo
     }
 }
-
diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs
index 891390b..d1c9690 100644
--- a/math/src/ring/impl_u64/rescaling_rns.rs
+++ b/math/src/ring/impl_u64/rescaling_rns.rs
@@ -1,5 +1,5 @@
 use crate::modulus::barrett::Barrett;
-use crate::modulus::{NONE, ONCE, BARRETT};
+use crate::modulus::{BARRETT, NONE, ONCE};
 use crate::poly::PolyRNS;
 use crate::ring::Ring;
 use crate::ring::RingRNS;
@@ -31,9 +31,8 @@ impl RingRNS<u64> {
         let level = self.level();
         let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();
         let r_last: &Ring<u64> = &self.0[level];
-        
 
-        if ROUND{
+        if ROUND {
             let q_level_half: u64 = r_last.modulus.q >> 1;
 
             let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
@@ -56,7 +55,11 @@ impl RingRNS<u64> {
                 );
             }
         } else {
-            r_last.a_add_b_scalar_into_c::<NONE>(a.at(self.level()), &q_level_half, &mut buf_q_scaling[0]);
+            r_last.a_add_b_scalar_into_c::<NONE>(
+                a.at(self.level()),
+                &q_level_half,
+                &mut buf_q_scaling[0],
+            );
             for (i, r) in self.0[0..level].iter().enumerate() {
                 r_last.a_add_b_scalar_into_c::<NONE>(
                     &buf_q_scaling[0],
@@ -71,7 +74,7 @@ impl RingRNS<u64> {
                 );
             }
         }
-        }else{
+        } else {
             if NTT {
                 let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
                 self.0[level].intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
@@ -115,8 +118,7 @@ impl RingRNS<u64> {
         let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();
         let r_last: &Ring<u64> = &self.0[level];
 
-        if ROUND{
-
+        if ROUND {
             let q_level_half: u64 = r_last.modulus.q >> 1;
 
             let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
@@ -148,8 +150,7 @@ impl RingRNS<u64> {
                 );
             }
         }
-        }else{
-
+        } else {
             if NTT {
                 let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
                 r_last.intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
                 for (i, r) in self.0[0..level].iter().enumerate() {
                     r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
                         &buf_ntt_q_scaling[0],
                         &rescaling_constants.0[i],
                         a.at_mut(i),
                     );
                 }
-            }else{
+            } else {
                 let (a_i, a_level) = a.0.split_at_mut(level);
                 for (i, r) in self.0[0..level].iter().enumerate() {
                     r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
                         &a_level[0],
                         &rescaling_constants.0[i],
                         &mut a_i[i],
                     );
                 }
             }
         }
     }
-
     /// Updates b to floor(a / prod_{level - nb_moduli}^{level} q[i])
@@ -207,14 +207,13 @@ impl RingRNS<u64> {
             c.level(),
             a.level() - nb_moduli
         );
-        
+
         if nb_moduli == 0 {
             if a != c {
                 c.copy(a);
             }
         } else {
-
-            if NTT{
+            if NTT {
                 self.intt::<false>(a, buf);
                 (0..nb_moduli).for_each(|i| {
                     self.at_level(self.level() - i)
                         .div_by_last_modulus_inplace::(
                         )
                 });
                 self.at_level(self.level() - nb_moduli).ntt::<false>(buf, c);
-            }else{
-
+            } else {
                 println!("{} {:?}", self.level(), buf.level());
                 self.div_by_last_modulus::<false, ROUND>(a, buf, c);
 
-                (1..nb_moduli-1).for_each(|i| {
+                (1..nb_moduli - 1).for_each(|i| {
                     println!("{} {:?}", self.level() - i, buf.level());
                     self.at_level(self.level() - i)
                         .div_by_last_modulus_inplace::<false, ROUND>(buf, c);
                 });
-
-                self.at_level(self.level()-nb_moduli+1).div_by_last_modulus_inplace::<false, ROUND>(buf, c);
+
+                self.at_level(self.level() - nb_moduli + 1)
+                    .div_by_last_modulus_inplace::<false, ROUND>(buf, c);
             }
         }
     }
 
     /// Updates a to floor(a / prod_{level - nb_moduli}^{level} q[i])
-    pub fn div_by_last_moduli_inplace<const NTT: bool, const ROUND: bool>( 
+    pub fn div_by_last_moduli_inplace<const NTT: bool, const ROUND: bool>(
         &self,
         nb_moduli: usize,
         buf: &mut PolyRNS<u64>,
         a: &mut PolyRNS<u64>,
     ) {
         debug_assert!(
             nb_moduli <= a.level(),
             "nb_moduli={} > a.level()={}",
             nb_moduli,
             a.level()
         );
-        if nb_moduli == 0{
-            return
+        if nb_moduli == 0 {
+            return;
         }
 
         if NTT {
             self.intt::<false>(a, buf);
             (0..nb_moduli).for_each(|i| {
                 self.at_level(self.level() - i)
-                    .div_by_last_modulus_inplace::<false, ROUND>(&mut PolyRNS::<u64>::default(), buf)
+                    .div_by_last_modulus_inplace::<false, ROUND>(
+                        &mut PolyRNS::<u64>::default(),
+                        buf,
+                    )
             });
             self.at_level(self.level() - nb_moduli).ntt::<false>(buf, a);
         } else {
diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs
index 4f22de0..29bf97a 100644
--- a/math/src/ring/impl_u64/ring.rs
+++ b/math/src/ring/impl_u64/ring.rs
@@ -83,7 +83,12 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_add_b_into_c<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &Poly<u64>, c: &mut Poly<u64>) {
+    pub fn a_add_b_into_c<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &Poly<u64>,
+        c: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
         debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
@@ -98,7 +103,12 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_add_b_scalar_into_c<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
+    pub fn a_add_b_scalar_into_c<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &u64,
+        c: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
         self.modulus
@@ -106,13 +116,25 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_add_scalar_b_mul_c_scalar_barrett_into_a<const REDUCE: REDUCEMOD>(&self, b: &u64, c: &Barrett<u64>, a: &mut Poly<u64>) {
+    pub fn a_add_scalar_b_mul_c_scalar_barrett_into_a<const REDUCE: REDUCEMOD>(
+        &self,
+        b: &u64,
+        c: &Barrett<u64>,
+        a: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
-        self.modulus.va_add_sb_mul_sc_into_va::<CHUNK, REDUCE>(b, c, &mut a.0);
+        self.modulus
+            .va_add_sb_mul_sc_into_va::<CHUNK, REDUCE>(b, c, &mut a.0);
     }
 
     #[inline(always)]
-    pub fn add_scalar_then_mul_scalar_barrett<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &Barrett<u64>, d: &mut Poly<u64>) {
+    pub fn add_scalar_then_mul_scalar_barrett<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &u64,
+        c: &Barrett<u64>,
+        d: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n());
         self.modulus
@@ -120,7 +142,11 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_sub_b_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
+    pub fn a_sub_b_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
         self.modulus
@@ -128,7 +154,11 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_sub_b_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &Poly<u64>, a: &mut Poly<u64>) {
+    pub fn a_sub_b_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        b: &Poly<u64>,
+        a: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
         self.modulus
@@ -136,7 +166,12 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_sub_b_into_c<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &Poly<u64>, c: &mut Poly<u64>) {
+    pub fn a_sub_b_into_c<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &Poly<u64>,
+        c: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
         debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
@@ -145,16 +180,22 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_neg_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
+    pub fn a_neg_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
-        self.modulus.va_neg_into_vb::<CHUNK, BRANGE, REDUCE>(&a.0, &mut b.0);
+        self.modulus
+            .va_neg_into_vb::<CHUNK, BRANGE, REDUCE>(&a.0, &mut b.0);
     }
 
     #[inline(always)]
-    pub fn a_neg_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut Poly<u64>) {
+    pub fn a_neg_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut Poly<u64>) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
-        self.modulus.va_neg_into_va::<CHUNK, BRANGE, REDUCE>(&mut a.0);
+        self.modulus
+            .va_neg_into_va::<CHUNK, BRANGE, REDUCE>(&mut a.0);
     }
 
     #[inline(always)]
@@ -184,7 +225,12 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_mul_b_scalar_into_c<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
+    pub fn a_mul_b_scalar_into_c<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Poly<u64>,
+        b: &u64,
+        c: &mut Poly<u64>,
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
         self.modulus.sa_barrett_mul_vb_into_vc::<CHUNK, REDUCE>(
@@ -258,14 +304,17 @@ impl Ring<u64> {
     }
 
     #[inline(always)]
-    pub fn a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+    pub fn a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e<
+        const BRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
         &self,
         a: &Poly<u64>,
         b: &Poly<u64>,
         c: &u64,
         d: &Barrett<u64>,
         e: &mut Poly<u64>,
-    ){
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
         debug_assert!(e.n() == self.n(), "e.n()={} != n={}", e.n(), self.n());
         self.modulus
             .vb_sub_va_add_sc_mul_sd_into_ve::<CHUNK, BRANGE, REDUCE>(&a.0, &b.0, c, d, &mut e.0);
     }
 
     #[inline(always)]
-    pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+    pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a<
+        const BRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
         &self,
         b: &Poly<u64>,
         c: &u64,
         d: &Barrett<u64>,
         a: &mut Poly<u64>,
-    ){
+    ) {
         debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
         debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
         self.modulus
             .vb_sub_va_add_sc_mul_sd_into_va::<CHUNK, BRANGE, REDUCE>(&b.0, c, d, &mut a.0);
     }
-
 }
diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs
index 92c6d73..c424dd5 100644
--- a/math/src/ring/impl_u64/ring_rns.rs
+++ b/math/src/ring/impl_u64/ring_rns.rs
@@ -7,8 +7,6 @@ use crate::scalar::ScalarRNS;
 use num_bigint::BigInt;
 use std::sync::Arc;
 
-
-
 impl RingRNS<u64> {
     pub fn new(n: usize, moduli: Vec<u64>) -> Self {
         assert!(!moduli.is_empty(), "moduli cannot be empty");
@@ -198,14 +196,17 @@ impl RingRNS<u64> {
             c.level(),
             self.level()
         );
-        self.0
-            .iter()
-            .enumerate()
-            .for_each(|(i, ring)| ring.a_sub_b_into_c::<BRANGE, REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
+        self.0.iter().enumerate().for_each(|(i, ring)| {
+            ring.a_sub_b_into_c::<BRANGE, REDUCE>(&a.0[i], &b.0[i], &mut c.0[i])
+        });
     }
 
     #[inline(always)]
-    pub fn a_sub_b_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
+    pub fn a_sub_b_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &PolyRNS<u64>,
+        b: &mut PolyRNS<u64>,
+    ) {
         debug_assert!(
             a.level() >= self.level(),
             "a.level()={} < self.level()={}",
             a.level(),
             self.level()
         );
@@ -225,7 +226,11 @@ impl RingRNS<u64> {
     }
 
     #[inline(always)]
-    pub fn a_sub_b_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &PolyRNS<u64>, a: &mut PolyRNS<u64>) {
+    pub fn a_sub_b_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        b: &PolyRNS<u64>,
+        a: &mut PolyRNS<u64>,
+    ) {
         debug_assert!(
             a.level() >= self.level(),
             "a.level()={} < self.level()={}",
             a.level(),
             self.level()
         );
@@ -245,7 +250,11 @@ impl RingRNS<u64> {
     }
 
     #[inline(always)]
-    pub fn a_neg_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
+    pub fn a_neg_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &PolyRNS<u64>,
+        b: &mut PolyRNS<u64>,
+    ) {
         debug_assert!(
             a.level() >= self.level(),
             "a.level()={} < self.level()={}",
             a.level(),
             self.level()
         );
@@ -326,9 +335,10 @@ impl RingRNS<u64> {
             b.level(),
             self.level()
         );
-        self.0.iter().enumerate().for_each(|(i, ring)| {
-            ring.a_mul_b_montgomery_into_a::<REDUCE>(&a.0[i], &mut b.0[i])
-        });
+        self.0
+            .iter()
+            .enumerate()
+            .for_each(|(i, ring)| ring.a_mul_b_montgomery_into_a::<REDUCE>(&a.0[i], &mut b.0[i]));
     }
 
     #[inline(always)]
@@ -371,7 +381,17 @@ impl RingRNS<u64> {
     }
 
     #[inline(always)]
-    pub fn a_sub_b_add_scalar_mul_scalar_barrett_into_e<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &PolyRNS<u64>, c: &u64, d: &Barrett<u64>, e: &mut PolyRNS<u64>){
+    pub fn a_sub_b_add_scalar_mul_scalar_barrett_into_e<
+        const BRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
+        &self,
+        a: &PolyRNS<u64>,
+        b: &PolyRNS<u64>,
+        c: &u64,
+        d: &Barrett<u64>,
+        e: &mut PolyRNS<u64>,
+    ) {
         debug_assert!(
             a.level() >= self.level(),
             "a.level()={} < self.level()={}",
             a.level(),
             self.level()
         );
         debug_assert!(
             b.level() >= self.level(),
             "b.level()={} < self.level()={}",
             b.level(),
             self.level()
         );
         debug_assert!(
             e.level() >= self.level(),
             "e.level()={} < self.level()={}",
             e.level(),
             self.level()
         );
-        self.0
-            .iter()
-            .enumerate()
-            .for_each(|(i, ring)| ring.a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e::<BRANGE, REDUCE>(&a.0[i], &b.0[i], c, d, &mut e.0[i]));
+        self.0.iter().enumerate().for_each(|(i, ring)| {
+            ring.a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e::<BRANGE, REDUCE>(
+                &a.0[i],
+                &b.0[i],
+                c,
+                d,
+                &mut e.0[i],
+            )
+        });
     }
 
     #[inline(always)]
-    pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &PolyRNS<u64>, c: &u64, d: &Barrett<u64>, a: &mut PolyRNS<u64>){
+    pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a<
+        const BRANGE: u8,
+        const REDUCE: REDUCEMOD,
+    >(
+        &self,
+        b: &PolyRNS<u64>,
+        c: &u64,
+        d: &Barrett<u64>,
+        a: &mut PolyRNS<u64>,
+    ) {
         debug_assert!(
             a.level() >= self.level(),
             "a.level()={} < self.level()={}",
             a.level(),
             self.level()
         );
         debug_assert!(
             b.level() >= self.level(),
             "b.level()={} < self.level()={}",
             b.level(),
             self.level()
         );
-        self.0
-            .iter()
-            .enumerate()
-            .for_each(|(i, ring)| ring.b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a::<BRANGE, REDUCE>(&b.0[i], c, d, &mut a.0[i]));
+        self.0.iter().enumerate().for_each(|(i, ring)| {
+            ring.b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a::<BRANGE, REDUCE>(
+                &b.0[i],
+                c,
+                d,
+                &mut a.0[i],
+            )
+        });
     }
 }
diff --git a/math/src/ring/impl_u64/sampling.rs b/math/src/ring/impl_u64/sampling.rs
index 0aafe1a..4955a00 100644
--- a/math/src/ring/impl_u64/sampling.rs
+++ b/math/src/ring/impl_u64/sampling.rs
@@ -2,7 +2,7 @@ use crate::modulus::WordOps;
 use crate::poly::{Poly, PolyRNS};
 use crate::ring::{Ring, RingRNS};
 use num::ToPrimitive;
-use rand_distr::{Normal, Distribution};
+use rand_distr::{Distribution, Normal};
 use sampling::source::Source;
 
 impl Ring<u64> {
@@ -13,21 +13,25 @@ impl Ring<u64> {
             .for_each(|a| *a = source.next_u64n(max, mask));
     }
 
-    pub fn fill_dist_f64<T: Distribution<f64>>(&self, source: &mut Source, dist: T, bound: f64, a: &mut Poly<u64>) {
+    pub fn fill_dist_f64<T: Distribution<f64>>(
+        &self,
+        source: &mut Source,
+        dist: T,
+        bound: f64,
+        a: &mut Poly<u64>,
+    ) {
         let max: u64 = self.modulus.q;
-        a.0.iter_mut()
-            .for_each(|a| {
+        a.0.iter_mut().for_each(|a| {
+            let mut dist_f64: f64 = dist.sample(source);
 
-            let mut dist_f64: f64 = dist.sample(source);
-
-            while dist_f64.abs() > bound{
-                dist_f64 = dist.sample(source)
-            }
+            while dist_f64.abs() > bound {
+                dist_f64 = dist.sample(source)
+            }
 
-            let dist_u64: u64 = (dist_f64+0.5).abs().to_u64().unwrap();
-            let sign: u64 = dist_f64.to_bits()>>63;
+            // Round the magnitude to the nearest integer, then map negative
+            // samples to q - |x| (sign bit is 1 for negative values).
+            let dist_u64: u64 = (dist_f64.abs() + 0.5).to_u64().unwrap();
+            let sign: u64 = dist_f64.to_bits() >> 63;
 
-            *a = (dist_u64 * sign) | (max-dist_u64)*(sign^1)
+            *a = (dist_u64 * (sign ^ 1)) | (max - dist_u64) * sign
         });
     }
 }
 
impl RingRNS<u64> {
@@ -40,19 +44,25 @@ impl RingRNS<u64> {
             .for_each(|(i, r)| r.fill_uniform(source, a.at_mut(i)));
     }
 
-    pub fn fill_dist_f64<T: Distribution<f64>>(&self, source: &mut Source, dist: T, bound: f64, a: &mut PolyRNS<u64>) {
-        (0..a.n()).for_each(|j|{
+    pub fn fill_dist_f64<T: Distribution<f64>>(
+        &self,
+        source: &mut Source,
+        dist: T,
+        bound: f64,
+        a: &mut PolyRNS<u64>,
+    ) {
+        (0..a.n()).for_each(|j| {
             let mut dist_f64: f64 = dist.sample(source);
-
-            while dist_f64.abs() > bound{
+
+            while dist_f64.abs() > bound {
                 dist_f64 = dist.sample(source)
             }
 
-            let dist_u64: u64 = (dist_f64+0.5).abs().to_u64().unwrap();
-            let sign: u64 = dist_f64.to_bits()>>63;
+            let dist_u64: u64 = (dist_f64.abs() + 0.5).to_u64().unwrap();
+            let sign: u64 = dist_f64.to_bits() >> 63;
 
-            self.0.iter().enumerate().for_each(|(i, r)|{
-                a.at_mut(i).0[j] = (dist_u64 * sign) | (r.modulus.q-dist_u64)*(sign^1);
+            self.0.iter().enumerate().for_each(|(i, r)| {
+                a.at_mut(i).0[j] = (dist_u64 * (sign ^ 1)) | (r.modulus.q - dist_u64) * sign;
             })
         })
     }
diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs
index 60d684f..cba05bc 100644
--- a/math/tests/rescaling_rns.rs
+++ b/math/tests/rescaling_rns.rs
@@ -1,30 +1,49 @@
+use itertools::izip;
+use math::num_bigint::Div;
 use math::poly::PolyRNS;
 use math::ring::RingRNS;
 use num_bigint::BigInt;
-use math::num_bigint::Div;
 use sampling::source::Source;
-use itertools::izip;
 
 #[test]
 fn rescaling_rns_u64() {
     let n = 1 << 10;
-    let moduli: Vec<u64> = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64, 0x1fffffffffb40001, 0x1fffffffff500001];
+    let moduli: Vec<u64> = vec![
+        0x1fffffffffc80001u64,
+        0x1fffffffffe00001u64,
+        0x1fffffffffb40001,
+        0x1fffffffff500001,
+    ];
 
     let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
 
-    sub_test("test_div_by_last_modulus::<false, false>", ||{test_div_by_last_modulus::<false, false>(&ring_rns)});
-    sub_test("test_div_by_last_modulus::<false, true>", ||{test_div_by_last_modulus::<false, true>(&ring_rns)});
-    sub_test("test_div_by_last_modulus::<true, false>", ||{test_div_by_last_modulus::<true, false>(&ring_rns)});
-    sub_test("test_div_by_last_modulus::<true, true>", ||{test_div_by_last_modulus::<true, true>(&ring_rns)});
-    sub_test("test_div_by_last_modulus_inplace::<false, false>", ||{test_div_by_last_modulus_inplace::<false, false>(&ring_rns)});
-    sub_test("test_div_by_last_modulus_inplace::<false, true>", ||{test_div_by_last_modulus_inplace::<false, true>(&ring_rns)});
-    sub_test("test_div_by_last_modulus_inplace::<true, false>", ||{test_div_by_last_modulus_inplace::<true, false>(&ring_rns)});
-    sub_test("test_div_by_last_modulus_inplace::<true, true>", ||{test_div_by_last_modulus_inplace::<true, true>(&ring_rns)});
-
-
-
-
+    sub_test("test_div_by_last_modulus::<false, false>", || {
+        test_div_by_last_modulus::<false, false>(&ring_rns)
+    });
+    sub_test("test_div_by_last_modulus::<false, true>", || {
+        test_div_by_last_modulus::<false, true>(&ring_rns)
+    });
+    sub_test("test_div_by_last_modulus::<true, false>", || {
+        test_div_by_last_modulus::<true, false>(&ring_rns)
+    });
+    sub_test("test_div_by_last_modulus::<true, true>", || {
+        test_div_by_last_modulus::<true, true>(&ring_rns)
+    });
+    sub_test(
"test_div_by_last_modulus_inplace::", + || test_div_by_last_modulus_inplace::(&ring_rns), + ); + sub_test( + "test_div_by_last_modulus_inplace::", + || test_div_by_last_modulus_inplace::(&ring_rns), + ); + sub_test( + "test_div_by_last_modulus_inplace::", + || test_div_by_last_modulus_inplace::(&ring_rns), + ); //sub_test("test_div_by_last_moduli::", ||{test_div_by_last_moduli::(&ring_rns)}); } @@ -34,8 +53,7 @@ fn sub_test(name: &str, f: F) { f(); } -fn test_div_by_last_modulus(ring_rns: &RingRNS){ - +fn test_div_by_last_modulus(ring_rns: &RingRNS) { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -57,8 +75,7 @@ fn test_div_by_last_modulus(ring_rns: &RingRNS ring_rns.ntt_inplace::(&mut a); } - ring_rns.div_by_last_modulus::(&a, &mut b, &mut c); - + ring_rns.div_by_last_modulus::(&a, &mut b, &mut c); if NTT { ring_rns.at_level(c.level()).intt_inplace::(&mut c); @@ -73,18 +90,17 @@ fn test_div_by_last_modulus(ring_rns: &RingRNS // Performs floor division on a let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q); coeffs_a.iter_mut().for_each(|a| { - if ROUND{ + if ROUND { *a = a.div_round(&scalar_big); - }else{ + } else { *a = a.div_floor(&scalar_big); - } + } }); izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b)); } -fn test_div_by_last_modulus_inplace(ring_rns: &RingRNS) { - +fn test_div_by_last_modulus_inplace(ring_rns: &RingRNS) { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -105,34 +121,34 @@ fn test_div_by_last_modulus_inplace(ring_rns: ring_rns.ntt_inplace::(&mut a); } - ring_rns.div_by_last_modulus_inplace::(&mut buf, &mut a); + ring_rns.div_by_last_modulus_inplace::(&mut buf, &mut a); if NTT { - ring_rns.at_level(a.level()-1).intt_inplace::(&mut a); + ring_rns + .at_level(a.level() - 1) + .intt_inplace::(&mut a); } // Exports c to coeffs_c let mut coeffs_c = vec![BigInt::from(0); a.n()]; ring_rns - .at_level(a.level()-1) + .at_level(a.level() - 1) .to_bigint_inplace(&a, 1, &mut coeffs_c); // Performs floor division on a let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q); coeffs_a.iter_mut().for_each(|a| { - if ROUND{ + if ROUND { *a = a.div_round(&scalar_big); - }else{ + } else { *a = a.div_floor(&scalar_big); - } + } }); izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b)); } - -fn test_div_by_last_moduli(ring_rns: &RingRNS){ - +fn test_div_by_last_moduli(ring_rns: &RingRNS) { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -140,7 +156,9 @@ fn test_div_by_last_moduli(ring_rns: &RingRNS< let mut a: PolyRNS = ring_rns.new_polyrns(); let mut buf: PolyRNS = ring_rns.new_polyrns(); - let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - nb_moduli).new_polyrns(); + let mut c: PolyRNS = ring_rns + .at_level(ring_rns.level() - nb_moduli) + .new_polyrns(); // Allocates a random PolyRNS ring_rns.fill_uniform(&mut source, &mut a); @@ -156,7 +174,7 @@ fn test_div_by_last_moduli(ring_rns: &RingRNS< ring_rns.ntt_inplace::(&mut a); } - ring_rns.div_by_last_moduli::(nb_moduli, &a, &mut buf, &mut c); + ring_rns.div_by_last_moduli::(nb_moduli, &a, &mut buf, &mut c); if NTT { ring_rns.at_level(c.level()).intt_inplace::(&mut c); @@ -170,19 +188,20 @@ fn test_div_by_last_moduli(ring_rns: &RingRNS< // Performs floor division on a let mut scalar_big = BigInt::from(1); - (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()-i].modulus.q)}); + (0..nb_moduli) + .for_each(|i| scalar_big *= BigInt::from(ring_rns.0[ring_rns.level() - 
 
     coeffs_a.iter_mut().for_each(|a| {
-        if ROUND{
+        if ROUND {
             *a = a.div_round(&scalar_big);
-        }else{
+        } else {
             *a = a.div_floor(&scalar_big);
-        }    
+        }
     });
 
     izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b));
 }
 
-/* 
+/*
 fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS<u64>) {
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
@@ -225,4 +244,4 @@ fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS
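
For reference, a self-contained sketch (not part of the patch) of the two BigInt division modes that math/tests/rescaling_rns.rs uses as ground truth: floor division, which the plain rescaling path must match, and round-to-nearest division, which the ROUND path matches by adding q/2 before dividing. It assumes only the num-bigint, num-integer and num-traits crates; the helper names div_floor_big and div_round_big are illustrative and not part of the library.

use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Signed, Zero};

/// Largest integer <= a / b (rounds toward negative infinity).
fn div_floor_big(a: &BigInt, b: &BigInt) -> BigInt {
    // `div_rem` truncates toward zero; step down once when the remainder
    // and the divisor disagree in sign.
    let (quo, rem) = a.div_rem(b);
    if !rem.is_zero() && rem.sign() != b.sign() {
        quo - BigInt::one()
    } else {
        quo
    }
}

/// Nearest integer to a / b, halves rounded away from zero.
fn div_round_big(a: &BigInt, b: &BigInt) -> BigInt {
    let (quo, rem) = a.div_rem(b);
    // Bump the truncated quotient when 2 * |rem| >= |b|.
    if !rem.is_zero() && (rem.abs() << 1u32) >= b.abs() {
        if a.sign() == b.sign() {
            quo + BigInt::one()
        } else {
            quo - BigInt::one()
        }
    } else {
        quo
    }
}

fn main() {
    // First modulus of the test; q is odd, so an exact half never occurs.
    let q = BigInt::from(0x1fffffffffc80001u64);
    // x = 3q + (q + 1)/2: just past the halfway point between 3q and 4q.
    let x: BigInt = BigInt::from(3u8) * &q + (q.clone() >> 1u32) + BigInt::one();
    assert_eq!(div_floor_big(&x, &q), BigInt::from(3u8)); // floor keeps 3
    assert_eq!(div_round_big(&x, &q), BigInt::from(4u8)); // rounding bumps to 4
    // Floor of a negative quotient steps down, unlike truncation.
    assert_eq!(
        div_floor_big(&BigInt::from(-7), &BigInt::from(2)),
        BigInt::from(-4)
    );
    println!("ok");
}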