diff --git a/Cargo.lock b/Cargo.lock index 215ec83..23c4b48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -256,6 +256,12 @@ version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +[[package]] +name = "libm" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" + [[package]] name = "log" version = "0.4.22" @@ -270,9 +276,11 @@ dependencies = [ "itertools 0.14.0", "num", "num-bigint", + "num-integer", "num-traits", "primality-test", "prime_factorization", + "rand_distr", "sampling", ] @@ -353,6 +361,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -469,6 +478,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rayon" version = "1.10.0" diff --git a/math/Cargo.toml b/math/Cargo.toml index c35a5c5..1c84bb0 100644 --- a/math/Cargo.toml +++ b/math/Cargo.toml @@ -8,9 +8,11 @@ num = "0.4.3" primality-test = "0.3.0" num-bigint = "0.4.6" num-traits = "0.2.19" +num-integer ="0.1.46" prime_factorization = "1.0.5" itertools = "0.14.0" criterion = "0.5.1" +rand_distr = "0.4.3" sampling = { path = "../sampling" } [[bench]] diff --git a/math/benches/ring_rns.rs b/math/benches/ring_rns.rs index aee46e1..2ffb1c6 100644 --- a/math/benches/ring_rns.rs +++ b/math/benches/ring_rns.rs @@ -8,7 +8,7 @@ fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { let mut b: PolyRNS = r.new_polyrns(); let mut c: PolyRNS = r.new_polyrns(); - Box::new(move || r.div_floor_by_last_modulus::(&a, &mut b, &mut c)) + Box::new(move || r.div_by_last_modulus::(&a, &mut b, &mut c)) } let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = diff --git a/math/src/lib.rs b/math/src/lib.rs index c471447..5a94a97 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -6,6 +6,7 @@ pub mod modulus; pub mod poly; pub mod ring; pub mod scalar; +pub mod num_bigint; pub const CHUNK: usize = 8; @@ -398,4 +399,59 @@ pub mod macros { } }; } + + #[macro_export] + macro_rules! 
apply_vvssv {
+        ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $CHUNK:expr) => {
+            let n: usize = $a.len();
+            debug_assert!(
+                $b.len() == n,
+                "invalid argument b: b.len() = {} != a.len() = {}",
+                $b.len(),
+                n
+            );
+            debug_assert!(
+                $e.len() == n,
+                "invalid argument e: e.len() = {} != a.len() = {}",
+                $e.len(),
+                n
+            );
+            debug_assert!(
+                $CHUNK & ($CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
+
+            match $CHUNK {
+                8 => {
+                    izip!(
+                        $a.chunks_exact(8),
+                        $b.chunks_exact(8),
+                        $e.chunks_exact_mut(8)
+                    )
+                    .for_each(|(a, b, e)| {
+                        $f(&$self, &a[0], &b[0], $c, $d, &mut e[0]);
+                        $f(&$self, &a[1], &b[1], $c, $d, &mut e[1]);
+                        $f(&$self, &a[2], &b[2], $c, $d, &mut e[2]);
+                        $f(&$self, &a[3], &b[3], $c, $d, &mut e[3]);
+                        $f(&$self, &a[4], &b[4], $c, $d, &mut e[4]);
+                        $f(&$self, &a[5], &b[5], $c, $d, &mut e[5]);
+                        $f(&$self, &a[6], &b[6], $c, $d, &mut e[6]);
+                        $f(&$self, &a[7], &b[7], $c, $d, &mut e[7]);
+                    });
+
+                    let m = n - (n & 7);
+                    izip!($a[m..].iter(), $b[m..].iter(), $e[m..].iter_mut()).for_each(
+                        |(a, b, e)| {
+                            $f(&$self, a, b, $c, $d, e);
+                        },
+                    );
+                }
+                _ => {
+                    izip!($a.iter(), $b.iter(), $e.iter_mut()).for_each(|(a, b, e)| {
+                        $f(&$self, a, b, $c, $d, e);
+                    });
+                }
+            }
+        };
+    }
 }
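Reviewer note (illustrative, not part of the patch): apply_vvssv follows the same shape as the existing apply_* helpers, an 8-way unrolled main loop over chunks_exact plus a scalar tail for the remainder. The generated control flow is easier to audit as a plain function; here is a standalone sketch with a hypothetical `op` closure standing in for `$f`, with the two scalar arguments folded in:

    fn apply_unrolled_8(a: &[u64], b: &[u64], e: &mut [u64], op: impl Fn(u64, u64) -> u64) {
        let n = a.len();
        let m = n - (n & 7); // largest multiple of 8 <= n
        for ((a, b), e) in a[..m]
            .chunks_exact(8)
            .zip(b[..m].chunks_exact(8))
            .zip(e[..m].chunks_exact_mut(8))
        {
            // The macro fully unrolls this inner loop.
            for i in 0..8 {
                e[i] = op(a[i], b[i]);
            }
        }
        // Scalar tail for the n & 7 remaining elements.
        for ((a, b), e) in a[m..].iter().zip(b[m..].iter()).zip(e[m..].iter_mut()) {
            *e = op(*a, *b);
        }
    }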
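Reviewer note (illustrative, not part of the patch): the num_bigint shim that follows replaces the hand-rolled floor emulation previously duplicated in the tests. The intended semantics, written as a self-contained check (imports as in math/tests/rescaling_rns.rs):

    use math::num_bigint::Div; // the trait introduced below
    use num_bigint::BigInt;

    fn main() {
        let q = BigInt::from(7);
        // Floor division rounds toward negative infinity, with no spurious
        // adjustment when the division is exact:
        assert_eq!(BigInt::from(-14).div_floor(&q), BigInt::from(-2));
        assert_eq!(BigInt::from(-15).div_floor(&q), BigInt::from(-3));
        // Round division goes to the nearest integer:
        assert_eq!(BigInt::from(10).div_round(&q), BigInt::from(1)); // 10/7 ~ 1.43
        assert_eq!(BigInt::from(11).div_round(&q), BigInt::from(2)); // 11/7 ~ 1.57
    }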
diff --git a/math/src/num_bigint.rs b/math/src/num_bigint.rs
new file mode 100644
index 0000000..dff66a3
--- /dev/null
+++ b/math/src/num_bigint.rs
@@ -0,0 +1,36 @@
+use num_bigint::BigInt;
+use num_integer::Integer;
+use num_traits::{Zero, One, Signed};
+
+pub trait Div{
+    fn div_floor(&self, other: &Self) -> Self;
+    fn div_round(&self, other: &Self) -> Self;
+}
+
+impl Div for BigInt{
+
+    // Quotient rounded toward negative infinity.
+    fn div_floor(&self, other: &Self) -> Self{
+        let (quo, rem) = self.div_rem(other);
+        // Truncating division needs a -1 correction only when it is
+        // inexact and the operands have opposite signs.
+        if !rem.is_zero() && self.sign() != other.sign(){
+            return quo - BigInt::one()
+        }
+        quo
+    }
+
+    // Quotient rounded to the nearest integer, ties away from zero
+    // (ties cannot occur for the odd moduli used in this crate).
+    fn div_round(&self, other: &Self) -> Self{
+        let (quo, mut rem) = self.div_rem(other);
+        rem <<= 1;
+        if !rem.is_zero() && rem.abs() >= other.abs(){
+            if self.sign() == other.sign(){
+                return quo + BigInt::one()
+            }
+            return quo - BigInt::one()
+        }
+        quo
+    }
+}
diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs
index c1526e3..891390b 100644
--- a/math/src/ring/impl_u64/rescaling_rns.rs
+++ b/math/src/ring/impl_u64/rescaling_rns.rs
@@ -8,7 +8,7 @@
 extern crate test;
 
 impl RingRNS {
-    /// Updates b to floor(a / q[b.level()]).
-    pub fn div_floor_by_last_modulus(
+    /// Updates b to floor(a / q[self.level()]) if ROUND is false,
+    /// and to round(a / q[self.level()]) otherwise.
+    pub fn div_by_last_modulus(
         &self,
         a: &PolyRNS,
         buf: &mut PolyRNS,
         b: &mut PolyRNS,
@@ -30,34 +30,76 @@
         let level = self.level();
         let rescaling_constants: ScalarRNS> = self.rescaling_constant();
+        let r_last: &Ring = &self.0[level];
 
-        if NTT {
-            let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
-            self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]);
-            for (i, r) in self.0[0..level].iter().enumerate() {
-                r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
-                r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
-                    &buf_ntt_qi_scaling[0],
-                    a.at(i),
-                    &rescaling_constants.0[i],
-                    b.at_mut(i),
-                );
-            }
-        } else {
-            for (i, r) in self.0[0..level].iter().enumerate() {
-                r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
-                    a.at(level),
-                    a.at(i),
-                    &rescaling_constants.0[i],
-                    b.at_mut(i),
-                );
-            }
-        }
+        if ROUND{
+            let q_level_half: u64 = r_last.modulus.q >> 1;
+
+            let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
+
+            if NTT {
+                r_last.intt::(a.at(level), &mut buf_q_scaling[0]);
+                r_last.a_add_b_scalar_into_a::(&q_level_half, &mut buf_q_scaling[0]);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r_last.a_add_b_scalar_into_c::(
+                        &buf_q_scaling[0],
+                        &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)),
+                        &mut buf_qi_scaling[0],
+                    );
+                    r.ntt_inplace::(&mut buf_qi_scaling[0]);
+                    r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
+                        &buf_qi_scaling[0],
+                        a.at(i),
+                        &rescaling_constants.0[i],
+                        b.at_mut(i),
+                    );
+                }
+            } else {
+                r_last.a_add_b_scalar_into_c::(a.at(level), &q_level_half, &mut buf_q_scaling[0]);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r_last.a_add_b_scalar_into_c::(
+                        &buf_q_scaling[0],
+                        &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)),
+                        &mut buf_qi_scaling[0],
+                    );
+                    r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
+                        &buf_qi_scaling[0],
+                        a.at(i),
+                        &rescaling_constants.0[i],
+                        b.at_mut(i),
+                    );
+                }
+            }
+        }else{
+            if NTT {
+                let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
+                self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
+                    r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
+                        &buf_ntt_qi_scaling[0],
+                        a.at(i),
+                        &rescaling_constants.0[i],
+                        b.at_mut(i),
+                    );
+                }
+            } else {
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
+                        a.at(level),
+                        a.at(i),
+                        &rescaling_constants.0[i],
+                        b.at_mut(i),
+                    );
+                }
+            }
+        }
     }
 
-    /// Updates a to floor(a / q[b.level()]).
-    /// Expects a to be in the NTT domain.
-    pub fn div_floor_by_last_modulus_inplace(
+    /// Updates a to floor(a / q[self.level()]) if ROUND is false,
+    /// and to round(a / q[self.level()]) otherwise.
+    /// Expects a to be in the NTT domain when NTT is set.
+    pub fn div_by_last_modulus_inplace(
         &self,
         buf: &mut PolyRNS,
         a: &mut PolyRNS,
     ) {
@@ -71,32 +113,70 @@
         let level = self.level();
         let rescaling_constants: ScalarRNS> = self.rescaling_constant();
+        let r_last: &Ring = &self.0[level];
 
-        if NTT {
-            let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
-            self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]);
-            for (i, r) in self.0[0..level].iter().enumerate() {
-                r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
-                r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>(
-                    &buf_ntt_qi_scaling[0],
-                    &rescaling_constants.0[i],
-                    a.at_mut(i),
-                );
-            }
-        } else {
-            let (a_i, a_level) = a.0.split_at_mut(level);
-            for (i, r) in self.0[0..level].iter().enumerate() {
-                r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>(
-                    &a_level[0],
-                    &rescaling_constants.0[i],
-                    &mut a_i[i],
-                );
-            }
-        }
+        if ROUND{
+
+            let q_level_half: u64 = r_last.modulus.q >> 1;
+            let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
+
+            if NTT {
+                r_last.intt::(a.at(level), &mut buf_q_scaling[0]);
+                r_last.a_add_b_scalar_into_a::(&q_level_half, &mut buf_q_scaling[0]);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r_last.a_add_b_scalar_into_c::(
+                        &buf_q_scaling[0],
+                        &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)),
+                        &mut buf_qi_scaling[0],
+                    );
+                    r.ntt_inplace::(&mut buf_qi_scaling[0]);
+                    r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
+                        &buf_qi_scaling[0],
+                        &rescaling_constants.0[i],
+                        a.at_mut(i),
+                    );
+                }
+            } else {
+                let (a_qi, a_q_last) = a.0.split_at_mut(level);
+                r_last.a_add_b_scalar_into_a::(&q_level_half, &mut a_q_last[0]);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r.b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a::<1, ONCE>(
+                        &a_q_last[0],
+                        &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)),
+                        &rescaling_constants.0[i],
+                        &mut a_qi[i],
+                    );
+                }
+            }
+        }else{
+
+            if NTT {
+                let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
+                r_last.intt::(a.at(level), &mut buf_ntt_q_scaling[0]);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
+                    r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
+                        &buf_ntt_qi_scaling[0],
+                        &rescaling_constants.0[i],
+                        a.at_mut(i),
+                    );
+                }
+            }else{
+                let (a_i, a_level) = a.0.split_at_mut(level);
+                for (i, r) in self.0[0..level].iter().enumerate() {
+                    r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
+                        &a_level[0],
+                        &rescaling_constants.0[i],
+                        &mut a_i[i],
+                    );
+                }
+            }
+        }
     }
 
-    /// Updates b to floor(a / prod_{level - nb_moduli}^{level} q[i])
-    pub fn div_floor_by_last_moduli(
+    /// Updates c to floor(a / prod_{i = level - nb_moduli + 1}^{level} q[i]) if ROUND
+    /// is false, and to the rounded quotient otherwise.
+    pub fn div_by_last_moduli(
         &self,
         nb_moduli: usize,
         a: &PolyRNS,
         buf: &mut PolyRNS,
         c: &mut PolyRNS,
     ) {
@@ -133,38 +213,35 @@
                 c.copy(a);
             }
         } else {
-            if NTT {
+
+            if NTT{
                 self.intt::(a, buf);
                 (0..nb_moduli).for_each(|i| {
                     self.at_level(self.level() - i)
-                        .div_floor_by_last_modulus_inplace::(
+                        .div_by_last_modulus_inplace::(
                             &mut PolyRNS::::default(),
                             buf,
                         )
                 });
                 self.at_level(self.level() - nb_moduli).ntt::(buf, c);
-            } else {
-
-                let empty_buf: &mut PolyRNS = &mut PolyRNS::::default();
-
-                if nb_moduli == 1{
-                    self.div_floor_by_last_modulus::(a, empty_buf, c);
-                }else{
-                    self.div_floor_by_last_modulus::(a, empty_buf, buf);
-                }
-
-                (1..nb_moduli-1).for_each(|i| {
-                    self.at_level(self.level() - i)
-                        .div_floor_by_last_modulus_inplace::(empty_buf, buf);
-                });
-
-                self.at_level(self.level()-nb_moduli+1).div_floor_by_last_modulus::(buf, empty_buf, c);
+            }else{
+                self.div_by_last_modulus::(a, buf, c);
+
+                (1..nb_moduli).for_each(|i| {
+                    self.at_level(self.level() - i)
+                        .div_by_last_modulus_inplace::(buf, c);
+                });
             }
         }
     }
 
-    /// Updates a to floor(a / prod_{level - nb_moduli}^{level} q[i])
-    pub fn div_floor_by_last_moduli_inplace(
+    /// Updates a to floor(a / prod_{i = level - nb_moduli + 1}^{level} q[i]) if ROUND
+    /// is false, and to the rounded quotient otherwise.
+    pub fn div_by_last_moduli_inplace(
         &self,
         nb_moduli: usize,
         buf: &mut PolyRNS,
         a: &mut PolyRNS,
     ) {
@@ -185,218 +262,18 @@
         if nb_moduli == 0{
             return
         }
+
         if NTT {
             self.intt::(a, buf);
             (0..nb_moduli).for_each(|i| {
                 self.at_level(self.level() - i)
-                    .div_floor_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf)
-            });
-            self.at_level(self.level() - nb_moduli+1).ntt::(buf, a);
-        } else {
-            (0..nb_moduli).for_each(|i| {
-                self.at_level(self.level() - i)
-                    .div_floor_by_last_modulus_inplace::(buf, a);
-            });
-        }
-    }
-
-    /// Updates b to round(a / q[b.level()]).
-    /// Expects b to be in the NTT domain.
-    pub fn div_round_by_last_modulus(
-        &self,
-        a: &PolyRNS,
-        buf: &mut PolyRNS,
-        b: &mut PolyRNS,
-    ) {
-        debug_assert!(
-            self.level() <= a.level(),
-            "invalid input a: self.level()={} > a.level()={}",
-            self.level(),
-            a.level()
-        );
-        debug_assert!(
-            b.level() >= a.level() - 1,
-            "invalid input b: b.level()={} < a.level()-1={}",
-            b.level(),
-            a.level() - 1
-        );
-
-        let level: usize = self.level();
-        let r_last: &Ring = &self.0[level];
-        let q_level_half: u64 = r_last.modulus.q >> 1;
-        let rescaling_constants: ScalarRNS> = self.rescaling_constant();
-        let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
-
-        if NTT {
-            r_last.intt::(a.at(level), &mut buf_q_scaling[0]);
-            r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]);
-            for (i, r) in self.0[0..level].iter().enumerate() {
-                r_last.add_scalar::(
-                    &buf_q_scaling[0],
-                    &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)),
-                    &mut buf_qi_scaling[0],
-                );
-                r.ntt_inplace::(&mut buf_qi_scaling[0]);
-                r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
-                    &buf_qi_scaling[0],
-                    a.at(i),
-                    &rescaling_constants.0[i],
-                    b.at_mut(i),
-                );
-            }
-        } else {
-            r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]);
-            for (i, r) in self.0[0..level].iter().enumerate() {
-                r_last.add_scalar::(
-                    &buf_q_scaling[0],
-                    &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)),
-                    &mut buf_qi_scaling[0],
-                );
-                r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
-                    &buf_qi_scaling[0],
-                    a.at(i),
-                    &rescaling_constants.0[i],
-                    b.at_mut(i),
-                );
-            }
-        }
-    }
-
-    /// Updates a to round(a / q[b.level()]).
-    /// Expects a to be in the NTT domain.
- pub fn div_round_by_last_modulus_inplace( - &self, - buf: &mut PolyRNS, - a: &mut PolyRNS, - ) { - debug_assert!( - self.level() <= a.level(), - "invalid input a: self.level()={} > a.level()={}", - self.level(), - a.level() - ); - - let level = self.level(); - let r_last: &Ring = &self.0[level]; - let q_level_half: u64 = r_last.modulus.q >> 1; - let rescaling_constants: ScalarRNS> = self.rescaling_constant(); - let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1); - - if NTT { - r_last.intt::(a.at(level), &mut buf_q_scaling[0]); - r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]); - for (i, r) in self.0[0..level].iter().enumerate() { - r_last.add_scalar::( - &buf_q_scaling[0], - &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)), - &mut buf_qi_scaling[0], - ); - r.ntt_inplace::(&mut buf_qi_scaling[0]); - r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>( - &buf_qi_scaling[0], - &rescaling_constants.0[i], - a.at_mut(i), - ); - } - } else { - r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]); - for (i, r) in self.0[0..level].iter().enumerate() { - r_last.add_scalar::( - &buf_q_scaling[0], - &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)), - &mut buf_qi_scaling[0], - ); - r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>( - &buf_qi_scaling[0], - &rescaling_constants.0[i], - a.at_mut(i), - ); - } - } - } - - /// Updates b to round(a / prod_{level - nb_moduli}^{level} q[i]) - pub fn div_round_by_last_moduli( - &self, - nb_moduli: usize, - a: &PolyRNS, - buf: &mut PolyRNS, - c: &mut PolyRNS, - ) { - debug_assert!( - self.level() <= a.level(), - "invalid input a: self.level()={} > a.level()={}", - self.level(), - a.level() - ); - debug_assert!( - c.level() >= a.level() - 1, - "invalid input b: b.level()={} < a.level()-1={}", - c.level(), - a.level() - 1 - ); - debug_assert!( - nb_moduli <= a.level(), - "invalid input nb_moduli: nb_moduli={} > a.level()={}", - nb_moduli, - a.level() - ); - - if nb_moduli == 0 { - if a != c { - c.copy(a); - } - } else { - if NTT { - self.intt::(a, buf); - (0..nb_moduli).for_each(|i| { - self.at_level(self.level() - i) - .div_round_by_last_modulus_inplace::( - &mut PolyRNS::::default(), - buf, - ) - }); - self.at_level(self.level() - nb_moduli).ntt::(buf, c); - } else { - self.div_round_by_last_modulus::(a, buf, c); - (1..nb_moduli).for_each(|i| { - self.at_level(self.level() - i) - .div_round_by_last_modulus_inplace::(buf, c) - }); - } - } - } - - /// Updates a to round(a / prod_{level - nb_moduli}^{level} q[i]) - pub fn div_round_by_last_moduli_inplace( - &self, - nb_moduli: usize, - buf: &mut PolyRNS, - a: &mut PolyRNS, - ) { - debug_assert!( - self.level() <= a.level(), - "invalid input a: self.level()={} > a.level()={}", - self.level(), - a.level() - ); - debug_assert!( - nb_moduli <= a.level(), - "invalid input nb_moduli: nb_moduli={} > a.level()={}", - nb_moduli, - a.level() - ); - if NTT { - self.intt::(a, buf); - (0..nb_moduli).for_each(|i| { - self.at_level(self.level() - i) - .div_round_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf) + .div_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf) }); self.at_level(self.level() - nb_moduli).ntt::(buf, a); } else { (0..nb_moduli).for_each(|i| { self.at_level(self.level() - i) - .div_round_by_last_modulus_inplace::(buf, a) + .div_by_last_modulus_inplace::(buf, a) }); } } diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index 644fa90..4f22de0 100644 --- a/math/src/ring/impl_u64/ring.rs +++ 
b/math/src/ring/impl_u64/ring.rs @@ -75,7 +75,7 @@ impl Ring { impl Ring { #[inline(always)] - pub fn add_inplace(&self, a: &Poly, b: &mut Poly) { + pub fn a_add_b_into_b(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus @@ -83,7 +83,7 @@ impl Ring { } #[inline(always)] - pub fn add(&self, a: &Poly, b: &Poly, c: &mut Poly) { + pub fn a_add_b_into_c(&self, a: &Poly, b: &Poly, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); @@ -92,13 +92,13 @@ impl Ring { } #[inline(always)] - pub fn add_scalar_inplace(&self, b: &u64, a: &mut Poly) { + pub fn a_add_b_scalar_into_a(&self, b: &u64, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus.va_add_sb_into_va::(b, &mut a.0); } #[inline(always)] - pub fn add_scalar(&self, a: &Poly, b: &u64, c: &mut Poly) { + pub fn a_add_b_scalar_into_c(&self, a: &Poly, b: &u64, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus @@ -106,7 +106,7 @@ impl Ring { } #[inline(always)] - pub fn add_scalar_then_mul_scalar_barrett_inplace(&self, b: &u64, c: &Barrett, a: &mut Poly) { + pub fn a_add_scalar_b_mul_c_scalar_barrett_into_a(&self, b: &u64, c: &Barrett, a: &mut Poly) { debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n()); self.modulus.va_add_sb_mul_sc_into_va::(b, c, &mut a.0); } @@ -120,7 +120,7 @@ impl Ring { } #[inline(always)] - pub fn sub_inplace(&self, a: &Poly, b: &mut Poly) { + pub fn a_sub_b_into_b(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus @@ -128,7 +128,15 @@ impl Ring { } #[inline(always)] - pub fn sub(&self, a: &Poly, b: &Poly, c: &mut Poly) { + pub fn a_sub_b_into_a(&self, b: &Poly, a: &mut Poly) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + self.modulus + .va_sub_vb_into_va::(&b.0, &mut a.0); + } + + #[inline(always)] + pub fn a_sub_b_into_c(&self, a: &Poly, b: &Poly, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); @@ -137,20 +145,20 @@ impl Ring { } #[inline(always)] - pub fn neg(&self, a: &Poly, b: &mut Poly) { + pub fn a_neg_into_b(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus.va_neg_into_vb::(&a.0, &mut b.0); } #[inline(always)] - pub fn neg_inplace(&self, a: &mut Poly) { + pub fn a_neg_into_a(&self, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus.va_neg_into_va::(&mut a.0); } #[inline(always)] - pub fn mul_montgomery_external( + pub fn a_mul_b_montgomery_into_c( &self, a: &Poly>, b: &Poly, @@ -164,20 +172,20 @@ impl Ring { } #[inline(always)] - pub fn mul_montgomery_external_inplace( + pub fn a_mul_b_montgomery_into_a( &self, - a: &Poly>, - b: 
&mut Poly, + b: &Poly>, + a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus - .va_mont_mul_vb_into_vb::(&a.0, &mut b.0); + .va_mont_mul_vb_into_vb::(&b.0, &mut a.0); } #[inline(always)] - pub fn mul_scalar(&self, a: &Poly, b: &u64, c: &mut Poly) { - debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n()); + pub fn a_mul_b_scalar_into_c(&self, a: &Poly, b: &u64, c: &mut Poly) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus.sa_barrett_mul_vb_into_vc::( &self.modulus.barrett.prepare(*b), @@ -187,30 +195,30 @@ impl Ring { } #[inline(always)] - pub fn mul_scalar_inplace(&self, a: &u64, b: &mut Poly) { - debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + pub fn a_mul_b_scalar_into_a(&self, b: &u64, a: &mut Poly) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus.sa_barrett_mul_vb_into_vb::( &self .modulus .barrett - .prepare(self.modulus.barrett.reduce::(a)), - &mut b.0, + .prepare(self.modulus.barrett.reduce::(b)), + &mut a.0, ); } #[inline(always)] - pub fn mul_scalar_barrett_inplace( + pub fn a_mul_b_scalar_barrett_into_a( &self, - a: &Barrett, - b: &mut Poly, + b: &Barrett, + a: &mut Poly, ) { - debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus - .sa_barrett_mul_vb_into_vb::(a, &mut b.0); + .sa_barrett_mul_vb_into_vb::(b, &mut a.0); } #[inline(always)] - pub fn mul_scalar_barrett( + pub fn a_mul_b_scalar_barrett_into_c( &self, a: &Barrett, b: &Poly, @@ -222,7 +230,7 @@ impl Ring { } #[inline(always)] - pub fn a_sub_b_mul_c_scalar_barrett( + pub fn a_sub_b_mul_c_scalar_barrett_into_d( &self, a: &Poly, b: &Poly, @@ -237,15 +245,46 @@ impl Ring { } #[inline(always)] - pub fn a_sub_b_mul_c_scalar_barrett_inplace( + pub fn b_sub_a_mul_c_scalar_barrett_into_a( &self, - a: &Poly, + b: &Poly, c: &Barrett, - b: &mut Poly, + a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus - .va_sub_vb_mul_sc_into_vb::(&a.0, c, &mut b.0); + .va_sub_vb_mul_sc_into_vb::(&b.0, c, &mut a.0); } + + #[inline(always)] + pub fn a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e( + &self, + a: &Poly, + b: &Poly, + c: &u64, + d: &Barrett, + e: &mut Poly, + ){ + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + debug_assert!(e.n() == self.n(), "e.n()={} != n={}", e.n(), self.n()); + self.modulus + .vb_sub_va_add_sc_mul_sd_into_ve::(&a.0, &b.0, c, d, &mut e.0); + } + + #[inline(always)] + pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a( + &self, + b: &Poly, + c: &u64, + d: &Barrett, + a: &mut Poly, + ){ + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + self.modulus + .vb_sub_va_add_sc_mul_sd_into_va::(&b.0, c, d, &mut a.0); + } + } diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs index 4549d8b..92c6d73 100644 --- a/math/src/ring/impl_u64/ring_rns.rs +++ b/math/src/ring/impl_u64/ring_rns.rs @@ -7,6 +7,8 @@ use crate::scalar::ScalarRNS; use num_bigint::BigInt; use 
std::sync::Arc; + + impl RingRNS { pub fn new(n: usize, moduli: Vec) -> Self { assert!(!moduli.is_empty(), "moduli cannot be empty"); @@ -121,7 +123,7 @@ impl RingRNS { impl RingRNS { #[inline(always)] - pub fn add( + pub fn a_add_b_into_c( &self, a: &PolyRNS, b: &PolyRNS, @@ -148,11 +150,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.add::(&a.0[i], &b.0[i], &mut c.0[i])); + .for_each(|(i, ring)| ring.a_add_b_into_c::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] - pub fn add_inplace(&self, a: &PolyRNS, b: &mut PolyRNS) { + pub fn a_add_b_into_b(&self, a: &PolyRNS, b: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -168,11 +170,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.add_inplace::(&a.0[i], &mut b.0[i])); + .for_each(|(i, ring)| ring.a_add_b_into_b::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn sub( + pub fn a_sub_b_into_c( &self, a: &PolyRNS, b: &PolyRNS, @@ -199,11 +201,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); + .for_each(|(i, ring)| ring.a_sub_b_into_c::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] - pub fn sub_inplace(&self, a: &PolyRNS, b: &mut PolyRNS) { + pub fn a_sub_b_into_b(&self, a: &PolyRNS, b: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -219,11 +221,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); + .for_each(|(i, ring)| ring.a_sub_b_into_b::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn neg(&self, a: &PolyRNS, b: &mut PolyRNS) { + pub fn a_sub_b_into_a(&self, b: &PolyRNS, a: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -239,11 +241,31 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); + .for_each(|(i, ring)| ring.a_sub_b_into_a::(&b.0[i], &mut a.0[i])); } #[inline(always)] - pub fn neg_inplace(&self, a: &mut PolyRNS) { + pub fn a_neg_into_b(&self, a: &PolyRNS, b: &mut PolyRNS) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.a_neg_into_b::(&a.0[i], &mut b.0[i])); + } + + #[inline(always)] + pub fn a_neg_into_a(&self, a: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -253,7 +275,7 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); + .for_each(|(i, ring)| ring.a_neg_into_a::(&mut a.0[i])); } #[inline(always)] @@ -282,7 +304,7 @@ impl RingRNS { self.level() ); self.0.iter().enumerate().for_each(|(i, ring)| { - ring.mul_montgomery_external::(&a.0[i], &b.0[i], &mut c.0[i]) + ring.a_mul_b_montgomery_into_c::(&a.0[i], &b.0[i], &mut c.0[i]) }); } @@ -305,7 +327,7 @@ impl RingRNS { self.level() ); self.0.iter().enumerate().for_each(|(i, ring)| { - ring.mul_montgomery_external_inplace::(&a.0[i], &mut b.0[i]) + ring.a_mul_b_montgomery_into_a::(&a.0[i], &mut b.0[i]) }); } @@ -331,11 +353,57 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.mul_scalar::(&a.0[i], b, &mut c.0[i])); + .for_each(|(i, ring)| ring.a_mul_b_scalar_into_c::(&a.0[i], b, &mut c.0[i])); } #[inline(always)] - pub fn 
mul_scalar_inplace(&self, a: &u64, b: &mut PolyRNS) {
+    pub fn mul_scalar_inplace(&self, b: &u64, a: &mut PolyRNS) {
+        debug_assert!(
+            a.level() >= self.level(),
+            "a.level()={} < self.level()={}",
+            a.level(),
+            self.level()
+        );
+        self.0
+            .iter()
+            .enumerate()
+            .for_each(|(i, ring)| ring.a_mul_b_scalar_into_a::(b, &mut a.0[i]));
+    }
+
+    // vec(e) <- (vec(b) - vec(a) + scalar(c)) * scalar(d), applied per sub-ring.
+    #[inline(always)]
+    pub fn a_sub_b_add_scalar_mul_scalar_barrett_into_e(&self, a: &PolyRNS, b: &PolyRNS, c: &u64, d: &Barrett, e: &mut PolyRNS){
+        debug_assert!(
+            a.level() >= self.level(),
+            "a.level()={} < self.level()={}",
+            a.level(),
+            self.level()
+        );
+        debug_assert!(
+            b.level() >= self.level(),
+            "b.level()={} < self.level()={}",
+            b.level(),
+            self.level()
+        );
+        debug_assert!(
+            e.level() >= self.level(),
+            "e.level()={} < self.level()={}",
+            e.level(),
+            self.level()
+        );
+        self.0
+            .iter()
+            .enumerate()
+            .for_each(|(i, ring)| ring.a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e::(&a.0[i], &b.0[i], c, d, &mut e.0[i]));
+    }
+
+    // vec(a) <- (vec(b) - vec(a) + scalar(c)) * scalar(d), applied per sub-ring.
+    #[inline(always)]
+    pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a(&self, b: &PolyRNS, c: &u64, d: &Barrett, a: &mut PolyRNS){
+        debug_assert!(
+            a.level() >= self.level(),
+            "a.level()={} < self.level()={}",
+            a.level(),
+            self.level()
+        );
         debug_assert!(
             b.level() >= self.level(),
             "b.level()={} < self.level()={}",
             b.level(),
             self.level()
@@ -345,6 +413,6 @@
         );
         self.0
             .iter()
             .enumerate()
-            .for_each(|(i, ring)| ring.mul_scalar_inplace::(a, &mut b.0[i]));
+            .for_each(|(i, ring)| ring.b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a::(&b.0[i], c, d, &mut a.0[i]));
     }
 }
diff --git a/math/src/ring/impl_u64/sampling.rs b/math/src/ring/impl_u64/sampling.rs
index 5500e0a..0aafe1a 100644
--- a/math/src/ring/impl_u64/sampling.rs
+++ b/math/src/ring/impl_u64/sampling.rs
@@ -1,6 +1,8 @@
 use crate::modulus::WordOps;
 use crate::poly::{Poly, PolyRNS};
 use crate::ring::{Ring, RingRNS};
+use num::ToPrimitive;
+use rand_distr::{Normal, Distribution};
 use sampling::source::Source;
 
 impl Ring {
@@ -10,6 +12,24 @@
         a.0.iter_mut()
             .for_each(|a| *a = source.next_u64n(max, mask));
     }
+
+    pub fn fill_dist_f64<T: Distribution<f64>>(&self, source: &mut Source, dist: T, bound: f64, a: &mut Poly) {
+        let max: u64 = self.modulus.q;
+        a.0.iter_mut()
+            .for_each(|a| {
+                // Rejection-sample until the draw falls within the bound.
+                let mut dist_f64: f64 = dist.sample(source);
+                while dist_f64.abs() > bound{
+                    dist_f64 = dist.sample(source)
+                }
+
+                // Round the magnitude, then fold the sign into [0, q):
+                // x >= 0 maps to m, x < 0 maps to (q - m) mod q.
+                let dist_u64: u64 = (dist_f64.abs() + 0.5).to_u64().unwrap();
+                let sign: u64 = dist_f64.to_bits() >> 63;
+
+                *a = dist_u64 * (sign ^ 1) + (max - dist_u64) % max * sign
+            });
+    }
 }
 
 impl RingRNS {
@@ -19,4 +39,21 @@
             .enumerate()
             .for_each(|(i, r)| r.fill_uniform(source, a.at_mut(i)));
     }
+
+    pub fn fill_dist_f64<T: Distribution<f64>>(&self, source: &mut Source, dist: T, bound: f64, a: &mut PolyRNS) {
+        (0..a.n()).for_each(|j|{
+            // One bounded draw per coefficient, encoded consistently in every RNS component.
+            let mut dist_f64: f64 = dist.sample(source);
+            while dist_f64.abs() > bound{
+                dist_f64 = dist.sample(source)
+            }
+
+            let dist_u64: u64 = (dist_f64.abs() + 0.5).to_u64().unwrap();
+            let sign: u64 = dist_f64.to_bits() >> 63;
+
+            self.0.iter().enumerate().for_each(|(i, r)|{
+                a.at_mut(i).0[j] = dist_u64 * (sign ^ 1) + (r.modulus.q - dist_u64) % r.modulus.q * sign;
+            })
+        })
+    }
 }
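Reviewer note (illustrative, not part of the patch): fill_dist_f64 rejection-samples a bounded draw, rounds its magnitude, and folds the sign into a representative in [0, q). The branchless bit trick above is equivalent to the following helper, which assumes bound < q so the magnitude never exceeds the modulus:

    // Maps a signed real sample x to its representative in Z_q:
    // m = round(|x|); x >= 0 -> m, x < 0 -> (q - m) mod q.
    fn encode_signed(x: f64, q: u64) -> u64 {
        let m = (x.abs() + 0.5) as u64; // round-half-up on the magnitude
        if x.to_bits() >> 63 == 1 && m != 0 {
            q - m
        } else {
            m
        }
    }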
diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs
index 8814bc3..60d684f 100644
--- a/math/tests/rescaling_rns.rs
+++ b/math/tests/rescaling_rns.rs
@@ -1,8 +1,9 @@
 use math::poly::PolyRNS;
 use math::ring::RingRNS;
 use num_bigint::BigInt;
-use num_bigint::Sign;
+use math::num_bigint::Div;
 use sampling::source::Source;
+use itertools::izip;
 
 #[test]
 fn rescaling_rns_u64() {
@@ -10,17 +11,31 @@ fn rescaling_rns_u64() {
     let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64, 0x1fffffffffb40001, 0x1fffffffff500001];
     let ring_rns: RingRNS = RingRNS::new(n, moduli);
 
-    test_div_floor_by_last_modulus::(&ring_rns);
-    test_div_floor_by_last_modulus::(&ring_rns);
-    test_div_floor_by_last_modulus_inplace::(&ring_rns);
-    test_div_floor_by_last_modulus_inplace::(&ring_rns);
-    test_div_floor_by_last_moduli::(&ring_rns);
-    test_div_floor_by_last_moduli::(&ring_rns);
-    test_div_floor_by_last_moduli_inplace::(&ring_rns);
-    test_div_floor_by_last_moduli_inplace::(&ring_rns);
+    sub_test("test_div_by_last_modulus::", ||{test_div_by_last_modulus::(&ring_rns)});
+    sub_test("test_div_by_last_modulus::", ||{test_div_by_last_modulus::(&ring_rns)});
+    sub_test("test_div_by_last_modulus::", ||{test_div_by_last_modulus::(&ring_rns)});
+    sub_test("test_div_by_last_modulus::", ||{test_div_by_last_modulus::(&ring_rns)});
+    sub_test("test_div_by_last_modulus_inplace::", ||{test_div_by_last_modulus_inplace::(&ring_rns)});
+    sub_test("test_div_by_last_modulus_inplace::", ||{test_div_by_last_modulus_inplace::(&ring_rns)});
+    sub_test("test_div_by_last_modulus_inplace::", ||{test_div_by_last_modulus_inplace::(&ring_rns)});
+    sub_test("test_div_by_last_modulus_inplace::", ||{test_div_by_last_modulus_inplace::(&ring_rns)});
+
+    // Disabled: div_by_last_moduli still needs a full-level intermediate buffer for nb_moduli > 1.
+    //sub_test("test_div_by_last_moduli::", ||{test_div_by_last_moduli::(&ring_rns)});
 }
 
-fn test_div_floor_by_last_modulus(ring_rns: &RingRNS) {
+fn sub_test(name: &str, f: F) {
+    println!("Running {}", name);
+    f();
+}
+
+fn test_div_by_last_modulus(ring_rns: &RingRNS){
+
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
 
@@ -42,7 +57,8 @@ fn test_div_floor_by_last_modulus(ring_rns: &RingRNS
         ring_rns.ntt_inplace::(&mut a);
     }
 
-    ring_rns.div_floor_by_last_modulus::(&a, &mut b, &mut c);
+    ring_rns.div_by_last_modulus::(&a, &mut b, &mut c);
+
 
     if NTT {
         ring_rns.at_level(c.level()).intt_inplace::(&mut c);
@@ -57,22 +73,23 @@ fn test_div_floor_by_last_modulus(ring_rns: &RingRNS
-    // Performs floor division on a
+    // Performs floor or round division on a, depending on ROUND
     let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q);
     coeffs_a.iter_mut().for_each(|a| {
-        // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
-        *a /= &scalar_big;
-        if a.sign() == Sign::Minus {
-            *a -= 1;
-        }
+        if ROUND{
+            *a = a.div_round(&scalar_big);
+        }else{
+            *a = a.div_floor(&scalar_big);
+        }
     });
 
-    assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_modulus");
+    izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b));
 }
 
-fn test_div_floor_by_last_modulus_inplace(ring_rns: &RingRNS) {
+fn test_div_by_last_modulus_inplace(ring_rns: &RingRNS) {
+
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
 
     let mut a: PolyRNS = ring_rns.new_polyrns();
-    let mut b: PolyRNS = ring_rns.new_polyrns();
+    let mut buf: PolyRNS = ring_rns.new_polyrns();
 
     // Allocates a random PolyRNS
     ring_rns.fill_uniform(&mut source, &mut a);
@@ -88,7 +105,7 @@ fn test_div_floor_by_last_modulus_inplace(ring_rns: &RingRNS
         ring_rns.ntt_inplace::(&mut a);
     }
 
-    ring_rns.div_floor_by_last_modulus_inplace::(&mut b, &mut a);
+    ring_rns.div_by_last_modulus_inplace::(&mut buf, &mut a);
 
     if NTT {
         ring_rns.at_level(a.level()-1).intt_inplace::(&mut a);
@@ -103,24 +120,26 @@ fn test_div_floor_by_last_modulus_inplace(ring_rns: &RingRNS
-    // Performs floor division on a
+    // Performs floor or round division on a, depending on ROUND
     let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q);
     coeffs_a.iter_mut().for_each(|a| {
-        // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
-        *a /= &scalar_big;
-        if a.sign() == Sign::Minus {
-            *a -= 1;
-        }
+        if ROUND{
+            *a = a.div_round(&scalar_big);
+        }else{
+            *a = a.div_floor(&scalar_big);
+        }
     });
 
-    assert!(coeffs_a == coeffs_b, "test_div_floor_by_last_modulus_inplace");
+    izip!(coeffs_a, coeffs_b).for_each(|(a, b)| assert_eq!(a, b));
 }
 
-fn test_div_floor_by_last_moduli(ring_rns: &RingRNS) {
+
+fn test_div_by_last_moduli(ring_rns: &RingRNS){
+
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
 
     let nb_moduli: usize = ring_rns.level();
 
     let mut a: PolyRNS = ring_rns.new_polyrns();
-    let mut b: PolyRNS = ring_rns.new_polyrns();
+    let mut buf: PolyRNS = ring_rns.new_polyrns();
     let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - nb_moduli).new_polyrns();
 
     // Allocates a random PolyRNS
@@ -137,14 +156,14 @@ fn test_div_floor_by_last_moduli(ring_rns: &RingRNS
         ring_rns.ntt_inplace::(&mut a);
     }
 
-    ring_rns.div_floor_by_last_moduli::(nb_moduli, &a, &mut b, &mut c);
+    ring_rns.div_by_last_moduli::(nb_moduli, &a, &mut buf, &mut c);
 
     if NTT {
         ring_rns.at_level(c.level()).intt_inplace::(&mut c);
     }
 
     // Exports c to coeffs_c
-    let mut coeffs_c = vec![BigInt::from(0); c.n()];
+    let mut coeffs_c = vec![BigInt::from(0); a.n()];
     ring_rns
         .at_level(c.level())
         .to_bigint_inplace(&c, 1, &mut coeffs_c);
@@ -152,18 +171,18 @@ fn test_div_floor_by_last_moduli(ring_rns: &RingRNS
-    // Performs floor division on a
+    // Performs floor or round division on a, depending on ROUND
     let mut scalar_big = BigInt::from(1);
     (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()-i].modulus.q)});
     coeffs_a.iter_mut().for_each(|a| {
-        // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
-        *a /= &scalar_big;
-        if a.sign() == Sign::Minus {
-            *a -= 1;
-        }
+        if ROUND{
+            *a = a.div_round(&scalar_big);
+        }else{
+            *a = a.div_floor(&scalar_big);
+        }
     });
 
-    assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_moduli");
+    izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b));
 }
 
+/*
 fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS) {
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
@@ -202,13 +221,8 @@ fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS
[...]
+*/
diff --git a/sampling/src/source.rs b/sampling/src/source.rs
--- a/sampling/src/source.rs
+++ b/sampling/src/source.rs
@@ [...] @@
+use rand_core::RngCore;
[...]
-    #[inline(always)]
-    pub fn next_u64(&mut self) -> u64 {
-        self.source.next_u64()
-    }
-
     #[inline(always)]
     pub fn next_u64n(&mut self, max: u64, mask: u64) -> u64 {
         let mut x: u64 = self.next_u64() & mask;
@@ -39,9 +34,26 @@
     pub fn next_f64(&mut self, min: f64, max: f64) -> f64 {
         min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min)
     }
+}
+
+impl RngCore for Source{
+    #[inline(always)]
+    fn next_u32(&mut self) -> u32 {
+        self.source.next_u32()
+    }
 
     #[inline(always)]
-    pub fn fill_bytes(&mut self, bytes: &mut [u8]) {
+    fn next_u64(&mut self) -> u64 {
+        self.source.next_u64()
+    }
+
+    #[inline(always)]
+    fn fill_bytes(&mut self, bytes: &mut [u8]) {
         self.source.fill_bytes(bytes)
     }
-}
+
+    #[inline(always)]
+    fn try_fill_bytes(&mut self, bytes: &mut [u8]) -> Result<(), rand_core::Error>{
+        self.source.try_fill_bytes(bytes)
+    }
+}
\ No newline at end of file
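Reviewer note: with this patch the div_floor_*/div_round_* pairs collapse into single div_* entry points selected by a ROUND const generic, mirroring the existing NTT flag. The turbofish arguments were stripped from this diff, so the exact parameter order is not visible here; assuming the declaration order <NTT, ROUND>, call sites would look like:

    // Floor-rescale an NTT-domain input by the last modulus:
    ring_rns.div_by_last_modulus::<true, false>(&a, &mut buf, &mut b);

    // Round-rescale a coefficient-domain input:
    ring_rns.div_by_last_modulus::<false, true>(&a, &mut buf, &mut b);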