From 2888b9128d7ce0b59adacb24ed5ad5a2873ded7a Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 21 Jan 2025 00:21:57 +0100 Subject: [PATCH] added digit decomposition --- math/src/lib.rs | 70 +++++++++++-- math/src/modulus.rs | 60 +++++++++++ math/src/modulus/impl_u64/operations.rs | 133 +++++++++++++++++++++++- math/src/ring/impl_u64/sampling.rs | 2 +- 4 files changed, 254 insertions(+), 11 deletions(-) diff --git a/math/src/lib.rs b/math/src/lib.rs index da81012..8d1499e 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -146,7 +146,6 @@ pub mod macros { macro_rules! apply_sv { ($self:expr, $f:expr, $a:expr, $b:expr, $CHUNK:expr) => { let n: usize = $b.len(); - debug_assert!( CHUNK & (CHUNK - 1) == 0, "invalid CHUNK const: not a power of two" @@ -426,13 +425,13 @@ pub mod macros { ) .for_each(|(a, b, e)| { $f(&$self, &a[0], &b[0], $c, $d, &mut e[0]); - $f(&$self, &a[1], &b[1], $c, $d, &mut e[1]); - $f(&$self, &a[2], &b[2], $c, $d, &mut e[2]); - $f(&$self, &a[3], &b[3], $c, $d, &mut e[3]); - $f(&$self, &a[4], &b[4], $c, $d, &mut e[4]); - $f(&$self, &a[5], &b[5], $c, $d, &mut e[5]); - $f(&$self, &a[6], &b[6], $c, $d, &mut e[6]); - $f(&$self, &a[7], &b[7], $c, $d, &mut e[7]); + $f(&$self, &a[1], &b[1], $c, $d, &mut e[1]); + $f(&$self, &a[2], &b[2], $c, $d, &mut e[2]); + $f(&$self, &a[3], &b[3], $c, $d, &mut e[3]); + $f(&$self, &a[4], &b[4], $c, $d, &mut e[4]); + $f(&$self, &a[5], &b[5], $c, $d, &mut e[5]); + $f(&$self, &a[6], &b[6], $c, $d, &mut e[6]); + $f(&$self, &a[7], &b[7], $c, $d, &mut e[7]); }); let m = n - (n & 7); @@ -450,4 +449,59 @@ pub mod macros { } }; } + + #[macro_export] + macro_rules! apply_vsssvv { + ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $g:expr, $CHUNK:expr) => { + let n: usize = $a.len(); + debug_assert!( + $e.len() == n, + "invalid argument b: e.len() = {} != a.len() = {}", + $e.len(), + n + ); + debug_assert!( + $g.len() == n, + "invalid argument g: g.len() = {} != a.len() = {}", + $g.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); + + match CHUNK { + 8 => { + izip!( + $a.chunks_exact(8), + $e.chunks_exact_mut(8), + $g.chunks_exact_mut(8) + ) + .for_each(|(a, e, g)| { + $f(&$self, &a[0], $b, $c, $d, &mut e[0], &mut g[0]); + $f(&$self, &a[1], $b, $c, $d, &mut e[1], &mut g[1]); + $f(&$self, &a[2], $b, $c, $d, &mut e[2], &mut g[2]); + $f(&$self, &a[3], $b, $c, $d, &mut e[3], &mut g[3]); + $f(&$self, &a[4], $b, $c, $d, &mut e[4], &mut g[4]); + $f(&$self, &a[5], $b, $c, $d, &mut e[5], &mut g[5]); + $f(&$self, &a[6], $b, $c, $d, &mut e[6], &mut g[6]); + $f(&$self, &a[7], $b, $c, $d, &mut e[7], &mut g[7]); + }); + + let m = n - (n & 7); + izip!($a[m..].iter(), $e[m..].iter_mut(), $g[m..].iter_mut()).for_each( + |(a, e, g)| { + $f(&$self, a, $b, $c, $d, e, g); + }, + ); + } + _ => { + izip!($a.iter(), $e.iter_mut(), $g.iter_mut()).for_each(|(a, e, g)| { + $f(&$self, a, $b, $c, $d, e, g); + }); + } + } + }; + } } diff --git a/math/src/modulus.rs b/math/src/modulus.rs index 7accd6e..dcb492f 100644 --- a/math/src/modulus.rs +++ b/math/src/modulus.rs @@ -187,6 +187,22 @@ pub trait ScalarOperations { d: &barrett::Barrett, a: &mut u64, ); + + fn sa_rsh_sb_mask_sc_into_sa(&self, c: &u64, b: &u64, a: &mut u64); + + fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64); + + fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64); + + fn sa_signed_digit_into_sb( + &self, + a: &u64, + base: &u64, + shift: &u64, + mask: &u64, + carry: &mut u64, + b: &mut u64, + ); } pub trait VectorOperations { @@ -354,4 +370,48 @@ pub trait VectorOperations { sd: &barrett::Barrett, va: &mut [u64], ); + + // vec(a) <- (vec(a)>>scalar(b)) & scalar(c). + fn va_rsh_sb_mask_sd_into_va(&self, sb: &u64, sc: &u64, va: &mut [u64]); + + // vec(d) <- (vec(a)>>scalar(b)) & scalar(c). + fn va_rsh_sb_mask_sc_into_vd( + &self, + va: &[u64], + sb: &u64, + sc: &u64, + vd: &mut [u64], + ); + + // vec(d) <- vec(d) + (vec(a)>>scalar(b)) & scalar(c). + fn va_rsh_sb_mask_sc_add_vd_into_vd( + &self, + va: &[u64], + sb: &u64, + sc: &u64, + vd: &mut [u64], + ); + + // vec(c) <- i-th unsigned digit base 2^{sb} of vec(a). + // vec(c) is ensured to be in the range [0, 2^{sb}-1[ with E[vec(c)] = 2^{sb}-1. + fn va_ith_digit_unsigned_base_sb_into_vc( + &self, + i: usize, + va: &[u64], + sb: &u64, + vc: &mut [u64], + ); + + // vec(c) <- i-th signed digit base 2^{w} of vec(a). + // Reads the carry of the i-1-th iteration and write the carry on the i-th iteration on carry. + // if i > 0, carry of the i-1th iteration must be provided. + // if BALANCED: vec(c) is ensured to be [-2^{sb-1}, 2^{sb-1}[ with E[vec(c)] = 0, else E[vec(c)] = -0.5 + fn va_ith_digit_signed_base_sb_into_vc( + &self, + i: usize, + va: &[u64], + sb: &u64, + carry: &mut [u64], + vc: &mut [u64], + ); } diff --git a/math/src/modulus/impl_u64/operations.rs b/math/src/modulus/impl_u64/operations.rs index cff790f..0c412b4 100644 --- a/math/src/modulus/impl_u64/operations.rs +++ b/math/src/modulus/impl_u64/operations.rs @@ -4,8 +4,8 @@ use crate::modulus::prime::Prime; use crate::modulus::{ScalarOperations, VectorOperations}; use crate::modulus::{NONE, REDUCEMOD}; use crate::{ - apply_ssv, apply_sv, apply_svv, apply_v, apply_vssv, apply_vsv, apply_vv, apply_vvssv, - apply_vvsv, apply_vvv, + apply_ssv, apply_sv, apply_svv, apply_v, apply_vsssvv, apply_vssv, apply_vsv, apply_vv, + apply_vvssv, apply_vvsv, apply_vvv, }; use itertools::izip; @@ -211,6 +211,47 @@ impl ScalarOperations for Prime { self.sa_sub_sb_into_sb::(&(b + c), a); self.barrett.mul_external_assign::(*d, a); } + + #[inline(always)] + fn sa_rsh_sb_mask_sc_into_sa(&self, b: &u64, c: &u64, a: &mut u64) { + *a = (*a >> b) & c + } + + #[inline(always)] + fn sa_rsh_sb_mask_sc_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64) { + *d = (*a >> b) & c + } + + #[inline(always)] + fn sa_rsh_sb_mask_sc_add_sd_into_sd(&self, a: &u64, b: &u64, c: &u64, d: &mut u64) { + *d += (*a >> b) & c + } + + #[inline(always)] + fn sa_signed_digit_into_sb( + &self, + a: &u64, + base: &u64, + shift: &u64, + mask: &u64, + carry: &mut u64, + b: &mut u64, + ) { + if CARRYOVERWRITE { + self.sa_rsh_sb_mask_sc_into_sd(a, shift, mask, carry); + } else { + self.sa_rsh_sb_mask_sc_add_sd_into_sd(a, shift, mask, carry); + } + + let c: u64 = if BALANCED && *carry == base >> 1 { + a & 1 + } else { + ((*carry | (*carry << 1)) >> base) & 1 + }; + + *b = *carry + (self.q - base) * c; + *carry = c; + } } impl VectorOperations for Prime { @@ -518,4 +559,92 @@ impl VectorOperations for Prime { CHUNK ); } + + // vec(a) <- (vec(a)>>scalar(b)) & scalar(c). + fn va_rsh_sb_mask_sd_into_va(&self, sb: &u64, sc: &u64, va: &mut [u64]) { + apply_ssv!(self, Self::sa_rsh_sb_mask_sc_into_sa, sb, sc, va, CHUNK); + } + + // vec(d) <- (vec(a)>>scalar(b)) & scalar(c). + fn va_rsh_sb_mask_sc_into_vd( + &self, + va: &[u64], + sb: &u64, + sc: &u64, + vd: &mut [u64], + ) { + apply_vssv!(self, Self::sa_rsh_sb_mask_sc_into_sd, va, sb, sc, vd, CHUNK); + } + + // vec(d) <- vec(d) + (vec(a)>>scalar(b)) & scalar(c). + fn va_rsh_sb_mask_sc_add_vd_into_vd( + &self, + va: &[u64], + sb: &u64, + sc: &u64, + vd: &mut [u64], + ) { + apply_vssv!( + self, + Self::sa_rsh_sb_mask_sc_add_sd_into_sd, + va, + sb, + sc, + vd, + CHUNK + ); + } + + // vec(c) <- i-th unsigned digit base 2^{sb} of vec(a). + // vec(c) is ensured to be in the range [0, 2^{sb}-1[ with E[vec(c)] = 2^{sb}-1. + fn va_ith_digit_unsigned_base_sb_into_vc( + &self, + i: usize, + va: &[u64], + sb: &u64, + vc: &mut [u64], + ) { + self.va_rsh_sb_mask_sc_into_vd::(va, &((i as u64) * sb), &((1 << sb) - 1), vc); + } + + // vec(c) <- i-th signed digit base 2^{w} of vec(a). + // Reads the carry of the i-1-th iteration and write the carry on the i-th iteration on carry. + // if i > 0, carry of the i-1th iteration must be provided. + // if BALANCED: vec(c) is ensured to be [-2^{sb-1}, 2^{sb-1}[ with E[vec(c)] = 0, else E[vec(c)] = -0.5 + fn va_ith_digit_signed_base_sb_into_vc( + &self, + i: usize, + va: &[u64], + sb: &u64, + carry: &mut [u64], + vc: &mut [u64], + ) { + let base: u64 = 1 << sb; + let mask: u64 = base - 1; + if i == 0 { + apply_vsssvv!( + self, + Self::sa_signed_digit_into_sb::, + va, + &base, + &(i as u64 * sb), + &mask, + carry, + vc, + CHUNK + ); + } else { + apply_vsssvv!( + self, + Self::sa_signed_digit_into_sb::, + va, + &base, + &(i as u64 * sb), + &mask, + carry, + vc, + CHUNK + ); + } + } } diff --git a/math/src/ring/impl_u64/sampling.rs b/math/src/ring/impl_u64/sampling.rs index a6e44e5..c6905e3 100644 --- a/math/src/ring/impl_u64/sampling.rs +++ b/math/src/ring/impl_u64/sampling.rs @@ -35,7 +35,7 @@ impl Ring { }); } - pub fn fill_normal(&self, source: &mut Source, sigma: f64, bound: f64, a: &mut Poly){ + pub fn fill_normal(&self, source: &mut Source, sigma: f64, bound: f64, a: &mut Poly) { self.fill_dist_f64(source, Normal::new(0.0, sigma).unwrap(), bound, a); } }