diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index e4884ef..c1526e3 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -14,6 +14,7 @@ impl RingRNS { buf: &mut PolyRNS, b: &mut PolyRNS, ) { + debug_assert!(self.level() != 0, "invalid call: self.level()=0"); debug_assert!( self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", @@ -21,10 +22,10 @@ impl RingRNS { a.level() ); debug_assert!( - b.level() >= a.level() - 1, - "invalid input b: b.level()={} < a.level()-1={}", + b.level() >= self.level() - 1, + "invalid input b: b.level()={} < self.level()-1={}", b.level(), - a.level() - 1 + self.level() - 1 ); let level = self.level(); @@ -102,28 +103,31 @@ impl RingRNS { buf: &mut PolyRNS, c: &mut PolyRNS, ) { - - println!("{:?}", buf); - debug_assert!( - self.level() <= a.level(), + nb_moduli <= self.level(), + "invalid input nb_moduli: nb_moduli={} > a.level()={}", + nb_moduli, + a.level() + ); + debug_assert!( + a.level() <= self.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level() ); debug_assert!( - c.level() >= a.level() - 1, - "invalid input b: b.level()={} < a.level()-1={}", - c.level(), + buf.level() >= self.level() - 1, + "invalid input buf: buf.level()={} < a.level()-1={}", + buf.level(), a.level() - 1 ); debug_assert!( - nb_moduli <= a.level(), - "invalid input nb_moduli: nb_moduli={} > a.level()={}", - nb_moduli, - a.level() + c.level() >= self.level() - nb_moduli, + "invalid input c: c.level()={} < c.level()-nb_moduli={}", + c.level(), + a.level() - nb_moduli ); - + if nb_moduli == 0 { if a != c { c.copy(a); @@ -140,11 +144,21 @@ impl RingRNS { }); self.at_level(self.level() - nb_moduli).ntt::(buf, c); } else { - self.div_floor_by_last_modulus::(a, buf, c); - (1..nb_moduli).for_each(|i| { + + let empty_buf: &mut PolyRNS = &mut PolyRNS::::default(); + + if nb_moduli == 1{ + self.div_floor_by_last_modulus::(a, empty_buf, c); + }else{ + self.div_floor_by_last_modulus::(a, empty_buf, buf); + } + + (1..nb_moduli-1).for_each(|i| { self.at_level(self.level() - i) - .div_floor_by_last_modulus_inplace::(buf, c) + .div_floor_by_last_modulus_inplace::(empty_buf, buf); }); + + self.at_level(self.level()-nb_moduli+1).div_floor_by_last_modulus::(buf, empty_buf, c); } } } @@ -168,17 +182,20 @@ impl RingRNS { nb_moduli, a.level() ); + if nb_moduli == 0{ + return + } if NTT { self.intt::(a, buf); (0..nb_moduli).for_each(|i| { self.at_level(self.level() - i) .div_floor_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf) }); - self.at_level(self.level() - nb_moduli).ntt::(buf, a); + self.at_level(self.level() - nb_moduli+1).ntt::(buf, a); } else { (0..nb_moduli).for_each(|i| { self.at_level(self.level() - i) - .div_floor_by_last_modulus_inplace::(buf, a) + .div_floor_by_last_modulus_inplace::(buf, a); }); } } diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs index 5dae27b..8814bc3 100644 --- a/math/tests/rescaling_rns.rs +++ b/math/tests/rescaling_rns.rs @@ -10,14 +10,14 @@ fn rescaling_rns_u64() { let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64, 0x1fffffffffb40001, 0x1fffffffff500001]; let ring_rns: RingRNS = RingRNS::new(n, moduli); - //test_div_floor_by_last_modulus::(&ring_rns); - //test_div_floor_by_last_modulus::(&ring_rns); + test_div_floor_by_last_modulus::(&ring_rns); + test_div_floor_by_last_modulus::(&ring_rns); test_div_floor_by_last_modulus_inplace::(&ring_rns); - //test_div_floor_by_last_modulus_inplace::(&ring_rns); - //test_div_floor_by_last_moduli::(&ring_rns); - //test_div_floor_by_last_moduli::(&ring_rns); - //test_div_floor_by_last_moduli_inplace::(&ring_rns); - //test_div_floor_by_last_moduli_inplace::(&ring_rns); + test_div_floor_by_last_modulus_inplace::(&ring_rns); + test_div_floor_by_last_moduli::(&ring_rns); + test_div_floor_by_last_moduli::(&ring_rns); + test_div_floor_by_last_moduli_inplace::(&ring_rns); + test_div_floor_by_last_moduli_inplace::(&ring_rns); } fn test_div_floor_by_last_modulus(ring_rns: &RingRNS) { @@ -83,8 +83,6 @@ fn test_div_floor_by_last_modulus_inplace(ring_rns: &RingRNS(&mut a); @@ -112,9 +110,6 @@ fn test_div_floor_by_last_modulus_inplace(ring_rns: &RingRNS(ring_rns: &RingRNS) { let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let nb_moduli: usize = ring_rns.level()-1; + let nb_moduli: usize = ring_rns.level(); let mut a: PolyRNS = ring_rns.new_polyrns(); let mut b: PolyRNS = ring_rns.new_polyrns(); - let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - 1).new_polyrns(); + let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - nb_moduli).new_polyrns(); // Allocates a random PolyRNS ring_rns.fill_uniform(&mut source, &mut a); @@ -156,7 +151,8 @@ fn test_div_floor_by_last_moduli(ring_rns: &RingRNS) { // Performs floor division on a let mut scalar_big = BigInt::from(1); - (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()].modulus.q)}); + (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()-i].modulus.q)}); + coeffs_a.iter_mut().for_each(|a| { // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1] *a /= &scalar_big; @@ -172,7 +168,7 @@ fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS = ring_rns.new_polyrns(); let mut b: PolyRNS = ring_rns.new_polyrns(); @@ -205,7 +201,7 @@ fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS