From 1ac43bf35bd5737ad63c26fd42fa62c6775a2ab5 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 21 Jan 2025 10:25:32 +0100 Subject: [PATCH] added unit tests for digit decomposition --- math/src/modulus.rs | 22 ++++---- math/src/modulus/impl_u64/montgomery.rs | 2 +- math/src/modulus/impl_u64/operations.rs | 26 ++++----- math/src/ring/impl_u64/ring.rs | 65 +++++++++++++++++++++ math/tests/digit_decomposition.rs | 75 +++++++++++++++++++++++++ 5 files changed, 165 insertions(+), 25 deletions(-) create mode 100644 math/tests/digit_decomposition.rs diff --git a/math/src/modulus.rs b/math/src/modulus.rs index dcb492f..ec9e2e1 100644 --- a/math/src/modulus.rs +++ b/math/src/modulus.rs @@ -13,14 +13,14 @@ pub const BARRETT: REDUCEMOD = 4; pub const BARRETTLAZY: REDUCEMOD = 5; pub trait WordOps { - fn log2(self) -> O; + fn log2(self) -> usize; fn reverse_bits_msb(self, n: u32) -> O; fn mask(self) -> O; } impl WordOps for u64 { #[inline(always)] - fn log2(self) -> u64 { + fn log2(self) -> usize { (u64::BITS - (self - 1).leading_zeros()) as _ } #[inline(always)] @@ -188,17 +188,17 @@ pub trait ScalarOperations { a: &mut u64, ); - fn sa_rsh_sb_mask_sc_into_sa(&self, c: &u64, b: &u64, a: &mut u64); + fn sa_rsh_sb_mask_sc_into_sa(&self, c: &usize, b: &u64, a: &mut u64); - fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64); + fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64); - fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64); + fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64); fn sa_signed_digit_into_sb( &self, a: &u64, base: &u64, - shift: &u64, + shift: &usize, mask: &u64, carry: &mut u64, b: &mut u64, @@ -372,13 +372,13 @@ pub trait VectorOperations { ); // vec(a) <- (vec(a)>>scalar(b)) & scalar(c). - fn va_rsh_sb_mask_sd_into_va(&self, sb: &u64, sc: &u64, va: &mut [u64]); + fn va_rsh_sb_mask_sc_into_va(&self, sb: &usize, sc: &u64, va: &mut [u64]); // vec(d) <- (vec(a)>>scalar(b)) & scalar(c). fn va_rsh_sb_mask_sc_into_vd( &self, va: &[u64], - sb: &u64, + sb: &usize, sc: &u64, vd: &mut [u64], ); @@ -387,7 +387,7 @@ pub trait VectorOperations { fn va_rsh_sb_mask_sc_add_vd_into_vd( &self, va: &[u64], - sb: &u64, + sb: &usize, sc: &u64, vd: &mut [u64], ); @@ -398,7 +398,7 @@ pub trait VectorOperations { &self, i: usize, va: &[u64], - sb: &u64, + sb: &usize, vc: &mut [u64], ); @@ -410,7 +410,7 @@ pub trait VectorOperations { &self, i: usize, va: &[u64], - sb: &u64, + sb: &usize, carry: &mut [u64], vc: &mut [u64], ); diff --git a/math/src/modulus/impl_u64/montgomery.rs b/math/src/modulus/impl_u64/montgomery.rs index 6bceebe..2cbc8f2 100644 --- a/math/src/modulus/impl_u64/montgomery.rs +++ b/math/src/modulus/impl_u64/montgomery.rs @@ -22,7 +22,7 @@ impl MontgomeryPrecomp { q_inv = q_inv.wrapping_mul(q_pow); q_pow = q_pow.wrapping_mul(q_pow); } - let mut precomp = Self { + let mut precomp: MontgomeryPrecomp = Self { q: q, two_q: q << 1, four_q: q << 2, diff --git a/math/src/modulus/impl_u64/operations.rs b/math/src/modulus/impl_u64/operations.rs index 0c412b4..8fab036 100644 --- a/math/src/modulus/impl_u64/operations.rs +++ b/math/src/modulus/impl_u64/operations.rs @@ -213,17 +213,17 @@ impl ScalarOperations for Prime { } #[inline(always)] - fn sa_rsh_sb_mask_sc_into_sa(&self, b: &u64, c: &u64, a: &mut u64) { + fn sa_rsh_sb_mask_sc_into_sa(&self, b: &usize, c: &u64, a: &mut u64) { *a = (*a >> b) & c } #[inline(always)] - fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64) { + fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64) { *d = (*a >> b) & c } #[inline(always)] - fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64) { + fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64) { *d += (*a >> b) & c } @@ -232,7 +232,7 @@ impl ScalarOperations for Prime { &self, a: &u64, base: &u64, - shift: &u64, + shift: &usize, mask: &u64, carry: &mut u64, b: &mut u64, @@ -246,7 +246,7 @@ impl ScalarOperations for Prime { let c: u64 = if BALANCED && *carry == base >> 1 { a & 1 } else { - ((*carry | (*carry << 1)) >> base) & 1 + ((*carry | (*carry << 1)) >> shift) & 1 }; *b = *carry + (self.q - base) * c; @@ -561,7 +561,7 @@ impl VectorOperations for Prime { } // vec(a) <- (vec(a)>>scalar(b)) & scalar(c). - fn va_rsh_sb_mask_sd_into_va(&self, sb: &u64, sc: &u64, va: &mut [u64]) { + fn va_rsh_sb_mask_sc_into_va(&self, sb: &usize, sc: &u64, va: &mut [u64]) { apply_ssv!(self, Self::sa_rsh_sb_mask_sc_into_sa, sb, sc, va, CHUNK); } @@ -569,7 +569,7 @@ impl VectorOperations for Prime { fn va_rsh_sb_mask_sc_into_vd( &self, va: &[u64], - sb: &u64, + sb: &usize, sc: &u64, vd: &mut [u64], ) { @@ -580,7 +580,7 @@ impl VectorOperations for Prime { fn va_rsh_sb_mask_sc_add_vd_into_vd( &self, va: &[u64], - sb: &u64, + sb: &usize, sc: &u64, vd: &mut [u64], ) { @@ -601,10 +601,10 @@ impl VectorOperations for Prime { &self, i: usize, va: &[u64], - sb: &u64, + sb: &usize, vc: &mut [u64], ) { - self.va_rsh_sb_mask_sc_into_vd::(va, &((i as u64) * sb), &((1 << sb) - 1), vc); + self.va_rsh_sb_mask_sc_into_vd::(va, &(i * sb), &((1 << sb) - 1), vc); } // vec(c) <- i-th signed digit base 2^{w} of vec(a). @@ -615,7 +615,7 @@ impl VectorOperations for Prime { &self, i: usize, va: &[u64], - sb: &u64, + sb: &usize, carry: &mut [u64], vc: &mut [u64], ) { @@ -627,7 +627,7 @@ impl VectorOperations for Prime { Self::sa_signed_digit_into_sb::, va, &base, - &(i as u64 * sb), + &(i * sb), &mask, carry, vc, @@ -639,7 +639,7 @@ impl VectorOperations for Prime { Self::sa_signed_digit_into_sb::, va, &base, - &(i as u64 * sb), + &(i * sb), &mask, carry, vc, diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index adc55af..c646fe2 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -349,4 +349,69 @@ impl Ring { self.modulus .vb_sub_va_add_sc_mul_sd_into_va::(&b.0, c, d, &mut a.0); } + + pub fn a_rsh_scalar_b_mask_scalar_c_into_d( + &self, + a: &Poly, + b: &usize, + c: &u64, + d: &mut Poly, + ) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n()); + self.modulus + .va_rsh_sb_mask_sc_into_vd::(&a.0, b, c, &mut d.0); + } + + pub fn a_rsh_scalar_b_mask_scalar_c_add_d_into_d( + &self, + a: &Poly, + b: &usize, + c: &u64, + d: &mut Poly, + ) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n()); + self.modulus + .va_rsh_sb_mask_sc_add_vd_into_vd::(&a.0, b, c, &mut d.0); + } + + pub fn a_ith_digit_unsigned_base_scalar_b_into_c( + &self, + i: usize, + a: &Poly, + b: &usize, + c: &mut Poly, + ) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); + self.modulus + .va_ith_digit_unsigned_base_sb_into_vc::(i, &a.0, b, &mut c.0); + } + + pub fn a_ith_digit_signed_base_scalar_b_into_c( + &self, + i: usize, + a: &Poly, + b: &usize, + carry: &mut Poly, + c: &mut Poly, + ) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!( + carry.n() == self.n(), + "carry.n()={} != n={}", + carry.n(), + self.n() + ); + debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); + self.modulus + .va_ith_digit_signed_base_sb_into_vc::( + i, + &a.0, + b, + &mut carry.0, + &mut c.0, + ); + } } diff --git a/math/tests/digit_decomposition.rs b/math/tests/digit_decomposition.rs new file mode 100644 index 0000000..fd821a9 --- /dev/null +++ b/math/tests/digit_decomposition.rs @@ -0,0 +1,75 @@ +use itertools::izip; +use math::modulus::{WordOps, ONCE}; +use math::poly::Poly; +use math::ring::Ring; +use sampling::source::Source; + +#[test] +fn digit_decomposition() { + let n: usize = 1 << 4; + let q_base: u64 = 65537u64; + let q_power: usize = 1usize; + let ring: Ring = Ring::new(n, q_base, q_power); + + sub_test("test_unsigned_digit_decomposition", || { + test_unsigned_digit_decomposition(&ring) + }); + + sub_test("test_signed_digit_decomposition::", || { + test_signed_digit_decomposition::(&ring) + }); + + sub_test("test_signed_digit_decomposition::", || { + test_signed_digit_decomposition::(&ring) + }); +} + +fn sub_test(name: &str, f: F) { + println!("Running {}", name); + f(); +} + +fn test_unsigned_digit_decomposition(ring: &Ring) { + let mut a: Poly = ring.new_poly(); + let mut b: Poly = ring.new_poly(); + let mut c: Poly = ring.new_poly(); + + let seed: [u8; 32] = [0; 32]; + let mut source: Source = Source::new(seed); + ring.fill_uniform(&mut source, &mut a); + + let base: usize = 8; + let log_q: usize = ring.modulus.q.log2(); + let d: usize = ((log_q + base - 1) / base) as _; + + (0..d).for_each(|i| { + ring.a_ith_digit_unsigned_base_scalar_b_into_c(i, &a, &base, &mut b); + ring.a_mul_b_scalar_into_a::(&(1 << (i * base)), &mut b); + ring.a_add_b_into_b::(&b, &mut c); + }); + + izip!(a.0, c.0).for_each(|(a, c)| assert_eq!(a, c)); +} + +fn test_signed_digit_decomposition(ring: &Ring) { + let mut a: Poly = ring.new_poly(); + let mut b: Poly = ring.new_poly(); + let mut carry: Poly = ring.new_poly(); + let mut c: Poly = ring.new_poly(); + + let seed: [u8; 32] = [0; 32]; + let mut source: Source = Source::new(seed); + ring.fill_uniform(&mut source, &mut a); + + let base: usize = 8; + let log_q: usize = ring.modulus.q.log2(); + let d: usize = ((log_q + base - 1) / base) as _; + + (0..d).for_each(|i| { + ring.a_ith_digit_signed_base_scalar_b_into_c::(i, &a, &base, &mut carry, &mut b); + ring.a_mul_b_scalar_into_a::(&(1 << (i * base)), &mut b); + ring.a_add_b_into_b::(&b, &mut c); + }); + + izip!(a.0, c.0).for_each(|(a, c)| assert_eq!(a, c)); +}