diff --git a/math/src/ring.rs b/math/src/ring.rs index a335421..f8ef639 100644 --- a/math/src/ring.rs +++ b/math/src/ring.rs @@ -33,6 +33,10 @@ impl RingRNS { PolyRNS::::new(self.n(), self.level()) } + pub fn new_poly(&self) -> Poly { + Poly::::new(self.n()) + } + pub fn max_level(&self) -> usize { self.0.len() - 1 } diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index d1c9690..0de8059 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -1,6 +1,6 @@ use crate::modulus::barrett::Barrett; use crate::modulus::{BARRETT, NONE, ONCE}; -use crate::poly::PolyRNS; +use crate::poly::{Poly, PolyRNS}; use crate::ring::Ring; use crate::ring::RingRNS; use crate::scalar::ScalarRNS; @@ -8,18 +8,19 @@ extern crate test; impl RingRNS { /// Updates b to floor(a / q[b.level()]). + /// buf is unused if pub fn div_by_last_modulus( &self, a: &PolyRNS, - buf: &mut PolyRNS, + buf: &mut [Poly; 2], b: &mut PolyRNS, ) { debug_assert!(self.level() != 0, "invalid call: self.level()=0"); debug_assert!( - self.level() <= a.level(), - "invalid input a: self.level()={} > a.level()={}", - self.level(), - a.level() + a.level() >= self.level(), + "invalid input a: a.level()={} < self.level()={}", + a.level(), + self.level() ); debug_assert!( b.level() >= self.level() - 1, @@ -35,7 +36,7 @@ impl RingRNS { if ROUND { let q_level_half: u64 = r_last.modulus.q >> 1; - let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1); + let (buf_q_scaling, buf_qi_scaling) = buf.split_at_mut(1); if NTT { r_last.intt::(a.at(level), &mut buf_q_scaling[0]); @@ -76,7 +77,7 @@ impl RingRNS { } } else { if NTT { - let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); + let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.split_at_mut(1); self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]); for (i, r) in self.0[0..level].iter().enumerate() { r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); @@ -104,12 +105,12 @@ impl RingRNS { /// Expects a to be in the NTT domain. pub fn div_by_last_modulus_inplace( &self, - buf: &mut PolyRNS, + buf: &mut [Poly; 2], a: &mut PolyRNS, ) { debug_assert!( self.level() <= a.level(), - "invalid input a: self.level()={} > a.level()={}", + "invalid input a: a.level()={} < self.level()={}", self.level(), a.level() ); @@ -120,7 +121,7 @@ impl RingRNS { if ROUND { let q_level_half: u64 = r_last.modulus.q >> 1; - let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1); + let (buf_q_scaling, buf_qi_scaling) = buf.split_at_mut(1); if NTT { r_last.intt::(a.at(level), &mut buf_q_scaling[0]); @@ -152,7 +153,7 @@ impl RingRNS { } } else { if NTT { - let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); + let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.split_at_mut(1); r_last.intt::(a.at(level), &mut buf_ntt_q_scaling[0]); for (i, r) in self.0[0..level].iter().enumerate() { r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); @@ -178,104 +179,106 @@ impl RingRNS { /// Updates b to floor(a / prod_{level - nb_moduli}^{level} q[i]) pub fn div_by_last_moduli( &self, - nb_moduli: usize, + nb_moduli_dropped: usize, a: &PolyRNS, - buf: &mut PolyRNS, + buf0: &mut [Poly; 2], + buf1: &mut PolyRNS, c: &mut PolyRNS, ) { debug_assert!( - nb_moduli <= self.level(), - "invalid input nb_moduli: nb_moduli={} > a.level()={}", - nb_moduli, - a.level() + nb_moduli_dropped <= self.level(), + "invalid input nb_moduli_dropped: nb_moduli_dropped={} > self.level()={}", + nb_moduli_dropped, + self.level() ); debug_assert!( - a.level() <= self.level(), - "invalid input a: self.level()={} > a.level()={}", - self.level(), - a.level() + a.level() >= self.level(), + "invalid input a: a.level()={} < self.level()={}", + a.level(), + self.level() ); debug_assert!( - buf.level() >= self.level() - 1, - "invalid input buf: buf.level()={} < a.level()-1={}", - buf.level(), - a.level() - 1 + buf1.level() >= self.level(), + "invalid input buf: buf.level()={} < self.level()={}", + buf1.level(), + self.level() ); debug_assert!( - c.level() >= self.level() - nb_moduli, - "invalid input c: c.level()={} < c.level()-nb_moduli={}", + c.level() >= self.level() - nb_moduli_dropped, + "invalid input c: c.level()={} < self.level()-nb_moduli_dropped={}", c.level(), - a.level() - nb_moduli + self.level() - nb_moduli_dropped ); - if nb_moduli == 0 { + if nb_moduli_dropped == 0 { if a != c { c.copy(a); } } else { if NTT { - self.intt::(a, buf); - (0..nb_moduli).for_each(|i| { + self.intt::(a, buf1); + (0..nb_moduli_dropped).for_each(|i| { self.at_level(self.level() - i) - .div_by_last_modulus_inplace::( - &mut PolyRNS::::default(), - buf, - ) + .div_by_last_modulus_inplace::(buf0, buf1) }); - self.at_level(self.level() - nb_moduli).ntt::(buf, c); + self.at_level(self.level() - nb_moduli_dropped) + .ntt::(buf1, c); } else { - println!("{} {:?}", self.level(), buf.level()); - self.div_by_last_modulus::(a, buf, c); + self.div_by_last_modulus::(a, buf0, buf1); - (1..nb_moduli - 1).for_each(|i| { - println!("{} {:?}", self.level() - i, buf.level()); + (1..nb_moduli_dropped - 1).for_each(|i| { self.at_level(self.level() - i) - .div_by_last_modulus_inplace::(buf, c); + .div_by_last_modulus_inplace::(buf0, buf1); }); - self.at_level(self.level() - nb_moduli + 1) - .div_by_last_modulus_inplace::(buf, c); + self.at_level(self.level() - nb_moduli_dropped + 1) + .div_by_last_modulus::(buf1, buf0, c); } } } - /// Updates a to floor(a / prod_{level - nb_moduli}^{level} q[i]) + /// Updates a to floor(a / prod_{level - nb_moduli_dropped}^{level} q[i]) pub fn div_by_last_moduli_inplace( &self, - nb_moduli: usize, - buf: &mut PolyRNS, + nb_moduli_dropped: usize, + buf0: &mut [Poly; 2], + buf1: &mut PolyRNS, a: &mut PolyRNS, ) { debug_assert!( - self.level() <= a.level(), - "invalid input a: self.level()={} > a.level()={}", - self.level(), - a.level() + nb_moduli_dropped <= self.level(), + "invalid input nb_moduli_dropped: nb_moduli_dropped={} > self.level()={}", + nb_moduli_dropped, + self.level() ); debug_assert!( - nb_moduli <= a.level(), - "invalid input nb_moduli: nb_moduli={} > a.level()={}", - nb_moduli, - a.level() + a.level() >= self.level(), + "invalid input a: a.level()={} < self.level()={}", + a.level(), + self.level() ); - if nb_moduli == 0 { + debug_assert!( + buf1.level() >= self.level(), + "invalid input buf: buf.level()={} < self.level()={}", + buf1.level(), + self.level() + ); + if nb_moduli_dropped == 0 { return; } if NTT { - self.intt::(a, buf); - (0..nb_moduli).for_each(|i| { + self.intt::(a, buf1); + (0..nb_moduli_dropped).for_each(|i| { self.at_level(self.level() - i) - .div_by_last_modulus_inplace::( - &mut PolyRNS::::default(), - buf, - ) + .div_by_last_modulus_inplace::(buf0, buf1) }); - self.at_level(self.level() - nb_moduli).ntt::(buf, a); + self.at_level(self.level() - nb_moduli_dropped) + .ntt::(buf1, a); } else { - (0..nb_moduli).for_each(|i| { + (0..nb_moduli_dropped).for_each(|i| { self.at_level(self.level() - i) - .div_by_last_modulus_inplace::(buf, a) + .div_by_last_modulus_inplace::(buf0, a) }); } } diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs index cba05bc..35b4bce 100644 --- a/math/tests/rescaling_rns.rs +++ b/math/tests/rescaling_rns.rs @@ -1,6 +1,6 @@ use itertools::izip; use math::num_bigint::Div; -use math::poly::PolyRNS; +use math::poly::{Poly, PolyRNS}; use math::ring::RingRNS; use num_bigint::BigInt; use sampling::source::Source; @@ -44,8 +44,34 @@ fn rescaling_rns_u64() { "test_div_by_last_modulus_inplace::", || test_div_by_last_modulus_inplace::(&ring_rns), ); - - //sub_test("test_div_by_last_moduli::", ||{test_div_by_last_moduli::(&ring_rns)}); + sub_test("test_div_by_last_moduli::", || { + test_div_by_last_moduli::(&ring_rns) + }); + sub_test("test_div_by_last_moduli::", || { + test_div_by_last_moduli::(&ring_rns) + }); + sub_test("test_div_by_last_moduli::", || { + test_div_by_last_moduli::(&ring_rns) + }); + sub_test("test_div_by_last_moduli::", || { + test_div_by_last_moduli::(&ring_rns) + }); + sub_test( + "test_div_by_last_moduli_inplace::", + || test_div_by_last_moduli_inplace::(&ring_rns), + ); + sub_test( + "test_div_by_last_moduli_inplace::", + || test_div_by_last_moduli_inplace::(&ring_rns), + ); + sub_test( + "test_div_by_last_moduli_inplace::", + || test_div_by_last_moduli_inplace::(&ring_rns), + ); + sub_test( + "test_div_by_last_moduli_inplace::", + || test_div_by_last_moduli_inplace::(&ring_rns), + ); } fn sub_test(name: &str, f: F) { @@ -58,7 +84,7 @@ fn test_div_by_last_modulus(ring_rns: &RingR let mut source: Source = Source::new(seed); let mut a: PolyRNS = ring_rns.new_polyrns(); - let mut b: PolyRNS = ring_rns.new_polyrns(); + let mut buf: [Poly; 2] = [ring_rns.new_poly(), ring_rns.new_poly()]; let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - 1).new_polyrns(); // Allocates a random PolyRNS @@ -75,7 +101,7 @@ fn test_div_by_last_modulus(ring_rns: &RingR ring_rns.ntt_inplace::(&mut a); } - ring_rns.div_by_last_modulus::(&a, &mut b, &mut c); + ring_rns.div_by_last_modulus::(&a, &mut buf, &mut c); if NTT { ring_rns.at_level(c.level()).intt_inplace::(&mut c); @@ -105,7 +131,7 @@ fn test_div_by_last_modulus_inplace(ring_rns let mut source: Source = Source::new(seed); let mut a: PolyRNS = ring_rns.new_polyrns(); - let mut buf: PolyRNS = ring_rns.new_polyrns(); + let mut buf: [Poly; 2] = [ring_rns.new_poly(), ring_rns.new_poly()]; // Allocates a random PolyRNS ring_rns.fill_uniform(&mut source, &mut a); @@ -152,12 +178,13 @@ fn test_div_by_last_moduli(ring_rns: &RingRN let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let nb_moduli: usize = ring_rns.level(); + let nb_moduli_dropped: usize = ring_rns.level(); let mut a: PolyRNS = ring_rns.new_polyrns(); - let mut buf: PolyRNS = ring_rns.new_polyrns(); + let mut buf0: [Poly; 2] = [ring_rns.new_poly(), ring_rns.new_poly()]; + let mut buf1: PolyRNS = ring_rns.new_polyrns(); let mut c: PolyRNS = ring_rns - .at_level(ring_rns.level() - nb_moduli) + .at_level(ring_rns.level() - nb_moduli_dropped) .new_polyrns(); // Allocates a random PolyRNS @@ -174,7 +201,7 @@ fn test_div_by_last_moduli(ring_rns: &RingRN ring_rns.ntt_inplace::(&mut a); } - ring_rns.div_by_last_moduli::(nb_moduli, &a, &mut buf, &mut c); + ring_rns.div_by_last_moduli::(nb_moduli_dropped, &a, &mut buf0, &mut buf1, &mut c); if NTT { ring_rns.at_level(c.level()).intt_inplace::(&mut c); @@ -188,7 +215,7 @@ fn test_div_by_last_moduli(ring_rns: &RingRN // Performs floor division on a let mut scalar_big = BigInt::from(1); - (0..nb_moduli) + (0..nb_moduli_dropped) .for_each(|i| scalar_big *= BigInt::from(ring_rns.0[ring_rns.level() - i].modulus.q)); coeffs_a.iter_mut().for_each(|a| { if ROUND { @@ -201,15 +228,15 @@ fn test_div_by_last_moduli(ring_rns: &RingRN izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b)); } -/* -fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS) { +fn test_div_by_last_moduli_inplace(ring_rns: &RingRNS) { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let nb_moduli: usize = ring_rns.level(); + let nb_moduli_dropped: usize = ring_rns.level(); let mut a: PolyRNS = ring_rns.new_polyrns(); - let mut b: PolyRNS = ring_rns.new_polyrns(); + let mut buf0: [Poly; 2] = [ring_rns.new_poly(), ring_rns.new_poly()]; + let mut buf1: PolyRNS = ring_rns.new_polyrns(); // Allocates a random PolyRNS ring_rns.fill_uniform(&mut source, &mut a); @@ -225,23 +252,36 @@ fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS(&mut a); } - ring_rns.div_floor_by_last_moduli_inplace::(nb_moduli, &mut b, &mut a); + ring_rns.div_by_last_moduli_inplace::( + nb_moduli_dropped, + &mut buf0, + &mut buf1, + &mut a, + ); if NTT { - ring_rns.at_level(a.level()-nb_moduli).intt_inplace::(&mut a); + ring_rns + .at_level(a.level() - nb_moduli_dropped) + .intt_inplace::(&mut a); } // Exports c to coeffs_c let mut coeffs_c = vec![BigInt::from(0); a.n()]; ring_rns - .at_level(a.level()-nb_moduli) + .at_level(a.level() - nb_moduli_dropped) .to_bigint_inplace(&a, 1, &mut coeffs_c); // Performs floor division on a let mut scalar_big = BigInt::from(1); - (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()-i].modulus.q)}); - coeffs_a.iter_mut().for_each(|a| {a.div_floor(&scalar_big)}); + (0..nb_moduli_dropped) + .for_each(|i| scalar_big *= BigInt::from(ring_rns.0[ring_rns.level() - i].modulus.q)); + coeffs_a.iter_mut().for_each(|a| { + if ROUND { + *a = a.div_round(&scalar_big); + } else { + *a = a.div_floor(&scalar_big); + } + }); - assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_moduli_inplace"); + izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b)); } -*/