added unit tests for digit decomposition

Author: Jean-Philippe Bossuat
Date: 2025-01-21 10:25:32 +01:00
parent 2888b9128d
commit 1ac43bf35b
5 changed files with 165 additions and 25 deletions
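
The new tests exercise digit decomposition: each coefficient is split into base-2^w digits and the digits are re-accumulated to recover the original value. As a minimal, self-contained sketch of that round trip (plain Rust, independent of this crate's `Ring`/`Poly` API; the helper names below are illustrative only):

// Minimal sketch: unsigned base-2^w digit decomposition and reconstruction,
// the per-coefficient identity the new tests assert.
fn unsigned_digits(a: u64, w: usize, count: usize) -> Vec<u64> {
    let mask = (1u64 << w) - 1;
    (0..count).map(|i| (a >> (i * w)) & mask).collect()
}

fn reassemble(digits: &[u64], w: usize) -> u64 {
    digits
        .iter()
        .enumerate()
        .fold(0u64, |acc, (i, d)| acc + ((*d) << (i * w)))
}

fn main() {
    let a: u64 = 65_011; // fits in 17 bits, like a coefficient mod 65537
    let digits = unsigned_digits(a, 8, 3); // ceil(17 / 8) = 3 digits of 8 bits
    assert_eq!(reassemble(&digits, 8), a);
    println!("{a} -> {digits:?}");
}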

View File

@@ -13,14 +13,14 @@ pub const BARRETT: REDUCEMOD = 4;
 pub const BARRETTLAZY: REDUCEMOD = 5;
 pub trait WordOps<O> {
-    fn log2(self) -> O;
+    fn log2(self) -> usize;
     fn reverse_bits_msb(self, n: u32) -> O;
     fn mask(self) -> O;
 }
 impl WordOps<u64> for u64 {
     #[inline(always)]
-    fn log2(self) -> u64 {
+    fn log2(self) -> usize {
         (u64::BITS - (self - 1).leading_zeros()) as _
     }
     #[inline(always)]
@@ -188,17 +188,17 @@ pub trait ScalarOperations<O> {
         a: &mut u64,
     );
-    fn sa_rsh_sb_mask_sc_into_sa(&self, c: &u64, b: &u64, a: &mut u64);
-    fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64);
-    fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64);
+    fn sa_rsh_sb_mask_sc_into_sa(&self, c: &usize, b: &u64, a: &mut u64);
+    fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64);
+    fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64);
     fn sa_signed_digit_into_sb<const CARRYOVERWRITE: bool, const BALANCED: bool>(
         &self,
         a: &u64,
         base: &u64,
-        shift: &u64,
+        shift: &usize,
         mask: &u64,
         carry: &mut u64,
         b: &mut u64,
@@ -372,13 +372,13 @@ pub trait VectorOperations<O> {
     );
     // vec(a) <- (vec(a)>>scalar(b)) & scalar(c).
-    fn va_rsh_sb_mask_sd_into_va<const CHUNK: usize>(&self, sb: &u64, sc: &u64, va: &mut [u64]);
+    fn va_rsh_sb_mask_sc_into_va<const CHUNK: usize>(&self, sb: &usize, sc: &u64, va: &mut [u64]);
     // vec(d) <- (vec(a)>>scalar(b)) & scalar(c).
     fn va_rsh_sb_mask_sc_into_vd<const CHUNK: usize>(
         &self,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         sc: &u64,
         vd: &mut [u64],
     );
@@ -387,7 +387,7 @@ pub trait VectorOperations<O> {
     fn va_rsh_sb_mask_sc_add_vd_into_vd<const CHUNK: usize>(
         &self,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         sc: &u64,
         vd: &mut [u64],
     );
@@ -398,7 +398,7 @@ pub trait VectorOperations<O> {
         &self,
         i: usize,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         vc: &mut [u64],
     );
@@ -410,7 +410,7 @@ pub trait VectorOperations<O> {
         &self,
         i: usize,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         carry: &mut [u64],
         vc: &mut [u64],
     );
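
A note on the naming scheme in the trait above: the inline comments give the contract, e.g. `va_rsh_sb_mask_sc_into_vd` computes vec(d) <- (vec(a) >> scalar(b)) & scalar(c) elementwise, and the `sb` shift amount is what this commit changes from `&u64` to `&usize`. A plain-loop sketch of that contract follows; it deliberately ignores the crate's `CHUNK` unrolling and `apply_ssv!` machinery, so it is an illustration, not the actual implementation:

// Sketch of the documented contract: vd[i] = (va[i] >> sb) & sc for every index.
fn va_rsh_sb_mask_sc_into_vd(va: &[u64], sb: usize, sc: u64, vd: &mut [u64]) {
    debug_assert_eq!(va.len(), vd.len());
    for (d, a) in vd.iter_mut().zip(va.iter()) {
        *d = (*a >> sb) & sc;
    }
}

With sb = i * w and sc = 2^w - 1 this extracts the i-th unsigned base-2^w digit of every coefficient, which is how `va_ith_digit_unsigned_base_sb_into_vc` is built further down in the diff.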

View File

@@ -22,7 +22,7 @@ impl MontgomeryPrecomp<u64> {
             q_inv = q_inv.wrapping_mul(q_pow);
             q_pow = q_pow.wrapping_mul(q_pow);
         }
-        let mut precomp = Self {
+        let mut precomp: MontgomeryPrecomp<u64> = Self {
             q: q,
             two_q: q << 1,
             four_q: q << 2,

View File

@@ -213,17 +213,17 @@ impl ScalarOperations<u64> for Prime<u64> {
     }
     #[inline(always)]
-    fn sa_rsh_sb_mask_sc_into_sa(&self, b: &u64, c: &u64, a: &mut u64) {
+    fn sa_rsh_sb_mask_sc_into_sa(&self, b: &usize, c: &u64, a: &mut u64) {
         *a = (*a >> b) & c
     }
     #[inline(always)]
-    fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64) {
+    fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64) {
         *d = (*a >> b) & c
     }
     #[inline(always)]
-    fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64) {
+    fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &usize, c: &u64, d: &mut u64) {
         *d += (*a >> b) & c
     }
@@ -232,7 +232,7 @@ impl ScalarOperations<u64> for Prime<u64> {
         &self,
         a: &u64,
         base: &u64,
-        shift: &u64,
+        shift: &usize,
         mask: &u64,
         carry: &mut u64,
         b: &mut u64,
@@ -246,7 +246,7 @@ impl ScalarOperations<u64> for Prime<u64> {
         let c: u64 = if BALANCED && *carry == base >> 1 {
             a & 1
         } else {
-            ((*carry | (*carry << 1)) >> base) & 1
+            ((*carry | (*carry << 1)) >> shift) & 1
         };
         *b = *carry + (self.q - base) * c;
@@ -561,7 +561,7 @@ impl VectorOperations<u64> for Prime<u64> {
     }
     // vec(a) <- (vec(a)>>scalar(b)) & scalar(c).
-    fn va_rsh_sb_mask_sd_into_va<const CHUNK: usize>(&self, sb: &u64, sc: &u64, va: &mut [u64]) {
+    fn va_rsh_sb_mask_sc_into_va<const CHUNK: usize>(&self, sb: &usize, sc: &u64, va: &mut [u64]) {
         apply_ssv!(self, Self::sa_rsh_sb_mask_sc_into_sa, sb, sc, va, CHUNK);
     }
@@ -569,7 +569,7 @@ impl VectorOperations<u64> for Prime<u64> {
     fn va_rsh_sb_mask_sc_into_vd<const CHUNK: usize>(
         &self,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         sc: &u64,
         vd: &mut [u64],
     ) {
@@ -580,7 +580,7 @@ impl VectorOperations<u64> for Prime<u64> {
     fn va_rsh_sb_mask_sc_add_vd_into_vd<const CHUNK: usize>(
         &self,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         sc: &u64,
         vd: &mut [u64],
     ) {
@@ -601,10 +601,10 @@ impl VectorOperations<u64> for Prime<u64> {
         &self,
         i: usize,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         vc: &mut [u64],
     ) {
-        self.va_rsh_sb_mask_sc_into_vd::<CHUNK>(va, &((i as u64) * sb), &((1 << sb) - 1), vc);
+        self.va_rsh_sb_mask_sc_into_vd::<CHUNK>(va, &(i * sb), &((1 << sb) - 1), vc);
     }
     // vec(c) <- i-th signed digit base 2^{w} of vec(a).
@@ -615,7 +615,7 @@ impl VectorOperations<u64> for Prime<u64> {
         &self,
         i: usize,
         va: &[u64],
-        sb: &u64,
+        sb: &usize,
         carry: &mut [u64],
         vc: &mut [u64],
     ) {
@@ -627,7 +627,7 @@ impl VectorOperations<u64> for Prime<u64> {
             Self::sa_signed_digit_into_sb::<true, BALANCED>,
             va,
             &base,
-            &(i as u64 * sb),
+            &(i * sb),
             &mask,
             carry,
             vc,
@@ -639,7 +639,7 @@ impl VectorOperations<u64> for Prime<u64> {
             Self::sa_signed_digit_into_sb::<false, BALANCED>,
             va,
             &base,
-            &(i as u64 * sb),
+            &(i * sb),
             &mask,
             carry,
             vc,
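
Stepping outside the diff for a moment: the signed variant above keeps each digit in a window centered on zero by pushing a carry into the next digit, and the `BALANCED` parameter appears to control the tie-break when a digit equals exactly base/2 (the `*carry == base >> 1` branch). Below is a self-contained sketch of that general technique, assuming the usual balanced representation rather than this crate's exact carry encoding:

// Sketch of balanced signed-digit decomposition for base B = 2^w:
// keep each digit in [-B/2, B/2] by borrowing 1 from the next digit
// whenever the unsigned digit exceeds B/2 (ties broken by the next bit here).
fn signed_digits(mut a: u64, w: usize, count: usize) -> Vec<i64> {
    let base = 1u64 << w;
    let mask = base - 1;
    let mut out = Vec::with_capacity(count);
    for _ in 0..count {
        let d = a & mask;
        a >>= w;
        if d > base / 2 || (d == base / 2 && (a & 1) == 1) {
            out.push(d as i64 - base as i64); // negative digit
            a += 1; // carry into the next digit
        } else {
            out.push(d as i64);
        }
    }
    out
}

fn main() {
    // Round trip: sum_i digits[i] * 2^(i*w) reconstructs the original value.
    let a: u64 = 511;
    let digits = signed_digits(a, 8, 3); // [-1, 2, 0], i.e. 511 = -1 + 2 * 256
    let back: i64 = digits.iter().enumerate().map(|(i, d)| *d << (i * 8)).sum();
    assert_eq!(back as u64, a);
}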

View File

@@ -349,4 +349,69 @@ impl Ring<u64> {
         self.modulus
             .vb_sub_va_add_sc_mul_sd_into_va::<CHUNK, BRANGE, REDUCE>(&b.0, c, d, &mut a.0);
     }
+    pub fn a_rsh_scalar_b_mask_scalar_c_into_d(
+        &self,
+        a: &Poly<u64>,
+        b: &usize,
+        c: &u64,
+        d: &mut Poly<u64>,
+    ) {
+        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
+        debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n());
+        self.modulus
+            .va_rsh_sb_mask_sc_into_vd::<CHUNK>(&a.0, b, c, &mut d.0);
+    }
+    pub fn a_rsh_scalar_b_mask_scalar_c_add_d_into_d(
+        &self,
+        a: &Poly<u64>,
+        b: &usize,
+        c: &u64,
+        d: &mut Poly<u64>,
+    ) {
+        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
+        debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n());
+        self.modulus
+            .va_rsh_sb_mask_sc_add_vd_into_vd::<CHUNK>(&a.0, b, c, &mut d.0);
+    }
+    pub fn a_ith_digit_unsigned_base_scalar_b_into_c(
+        &self,
+        i: usize,
+        a: &Poly<u64>,
+        b: &usize,
+        c: &mut Poly<u64>,
+    ) {
+        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
+        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
+        self.modulus
+            .va_ith_digit_unsigned_base_sb_into_vc::<CHUNK>(i, &a.0, b, &mut c.0);
+    }
+    pub fn a_ith_digit_signed_base_scalar_b_into_c<const BALANCED: bool>(
+        &self,
+        i: usize,
+        a: &Poly<u64>,
+        b: &usize,
+        carry: &mut Poly<u64>,
+        c: &mut Poly<u64>,
+    ) {
+        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
+        debug_assert!(
+            carry.n() == self.n(),
+            "carry.n()={} != n={}",
+            carry.n(),
+            self.n()
+        );
+        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
+        self.modulus
+            .va_ith_digit_signed_base_sb_into_vc::<CHUNK, BALANCED>(
+                i,
+                &a.0,
+                b,
+                &mut carry.0,
+                &mut c.0,
+            );
+    }
 }

View File

@@ -0,0 +1,75 @@
+use itertools::izip;
+use math::modulus::{WordOps, ONCE};
+use math::poly::Poly;
+use math::ring::Ring;
+use sampling::source::Source;
+
+#[test]
+fn digit_decomposition() {
+    let n: usize = 1 << 4;
+    let q_base: u64 = 65537u64;
+    let q_power: usize = 1usize;
+    let ring: Ring<u64> = Ring::new(n, q_base, q_power);
+    sub_test("test_unsigned_digit_decomposition", || {
+        test_unsigned_digit_decomposition(&ring)
+    });
+    sub_test("test_signed_digit_decomposition::<BALANCED=false>", || {
+        test_signed_digit_decomposition::<false>(&ring)
+    });
+    sub_test("test_signed_digit_decomposition::<BALANCED=true>", || {
+        test_signed_digit_decomposition::<true>(&ring)
+    });
+}
+
+fn sub_test<F: FnOnce()>(name: &str, f: F) {
+    println!("Running {}", name);
+    f();
+}
+
+fn test_unsigned_digit_decomposition(ring: &Ring<u64>) {
+    let mut a: Poly<u64> = ring.new_poly();
+    let mut b: Poly<u64> = ring.new_poly();
+    let mut c: Poly<u64> = ring.new_poly();
+    let seed: [u8; 32] = [0; 32];
+    let mut source: Source = Source::new(seed);
+    ring.fill_uniform(&mut source, &mut a);
+    let base: usize = 8;
+    let log_q: usize = ring.modulus.q.log2();
+    let d: usize = ((log_q + base - 1) / base) as _;
+    (0..d).for_each(|i| {
+        ring.a_ith_digit_unsigned_base_scalar_b_into_c(i, &a, &base, &mut b);
+        ring.a_mul_b_scalar_into_a::<ONCE>(&(1 << (i * base)), &mut b);
+        ring.a_add_b_into_b::<ONCE>(&b, &mut c);
+    });
+    izip!(a.0, c.0).for_each(|(a, c)| assert_eq!(a, c));
+}
+
+fn test_signed_digit_decomposition<const BALANCED: bool>(ring: &Ring<u64>) {
+    let mut a: Poly<u64> = ring.new_poly();
+    let mut b: Poly<u64> = ring.new_poly();
+    let mut carry: Poly<u64> = ring.new_poly();
+    let mut c: Poly<u64> = ring.new_poly();
+    let seed: [u8; 32] = [0; 32];
+    let mut source: Source = Source::new(seed);
+    ring.fill_uniform(&mut source, &mut a);
+    let base: usize = 8;
+    let log_q: usize = ring.modulus.q.log2();
+    let d: usize = ((log_q + base - 1) / base) as _;
+    (0..d).for_each(|i| {
+        ring.a_ith_digit_signed_base_scalar_b_into_c::<BALANCED>(i, &a, &base, &mut carry, &mut b);
+        ring.a_mul_b_scalar_into_a::<ONCE>(&(1 << (i * base)), &mut b);
+        ring.a_add_b_into_b::<ONCE>(&b, &mut c);
+    });
+    izip!(a.0, c.0).for_each(|(a, c)| assert_eq!(a, c));
+}