From 7e4ca491c7508a2b840f1ef63b55145aa30950bc Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 6 Jan 2025 18:05:32 +0100 Subject: [PATCH] wip --- math/src/lib.rs | 84 ++++++++++++ math/src/modulus.rs | 107 +++++++++------ math/src/modulus/impl_u64/operations.rs | 156 ++++++++++++++++----- math/src/ring/impl_u64/rescaling_rns.rs | 173 ++++++++++++++++++++---- math/src/ring/impl_u64/ring.rs | 50 ++++--- math/src/ring/impl_u64/ring_rns.rs | 16 +-- math/tests/rescaling_rns.rs | 164 +++++++++++++++++++++- 7 files changed, 617 insertions(+), 133 deletions(-) diff --git a/math/src/lib.rs b/math/src/lib.rs index f8e16be..c471447 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -314,4 +314,88 @@ pub mod macros { } }; } + + #[macro_export] + macro_rules! apply_vssv { + ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $d:expr, $CHUNK:expr) => { + let n: usize = $a.len(); + debug_assert!( + $d.len() == n, + "invalid argument d: d.len() = {} != a.len() = {}", + $d.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); + + match CHUNK { + 8 => { + izip!( + $a.chunks_exact(8), + $d.chunks_exact_mut(8) + ) + .for_each(|(a, d)| { + $f(&$self, &a[0], $b, $c, &mut d[0]); + $f(&$self, &a[1], $b, $c, &mut d[1]); + $f(&$self, &a[2], $b, $c, &mut d[2]); + $f(&$self, &a[3], $b, $c, &mut d[3]); + $f(&$self, &a[4], $b, $c, &mut d[4]); + $f(&$self, &a[5], $b, $c, &mut d[5]); + $f(&$self, &a[6], $b, $c, &mut d[6]); + $f(&$self, &a[7], $b, $c, &mut d[7]); + }); + + let m = n - (n & 7); + izip!($a[m..].iter(), $d[m..].iter_mut()).for_each( + |(a, d)| { + $f(&$self, a, $b, $c, d); + }, + ); + } + _ => { + izip!($a.iter(), $d.iter_mut()).for_each(|(a, d)| { + $f(&$self, a, $b, $c, d); + }); + } + } + }; + } + + #[macro_export] + macro_rules! apply_ssv { + ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $CHUNK:expr) => { + let n: usize = $c.len(); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); + + match CHUNK { + 8 => { + izip!($c.chunks_exact_mut(8)).for_each(|c| { + $f(&$self, $a, $b, &mut c[0]); + $f(&$self, $a, $b, &mut c[1]); + $f(&$self, $a, $b, &mut c[2]); + $f(&$self, $a, $b, &mut c[3]); + $f(&$self, $a, $b, &mut c[4]); + $f(&$self, $a, $b, &mut c[5]); + $f(&$self, $a, $b, &mut c[6]); + $f(&$self, $a, $b, &mut c[7]); + }); + + let m = n - (n & 7); + izip!($c[m..].iter_mut()).for_each(|c| { + $f(&$self, $a, $b, c); + }); + } + _ => { + izip!($c.iter_mut()).for_each(|c| { + $f(&$self, $a, $b, c); + }); + } + } + }; + } } diff --git a/math/src/modulus.rs b/math/src/modulus.rs index 305eee4..dc97687 100644 --- a/math/src/modulus.rs +++ b/math/src/modulus.rs @@ -74,16 +74,16 @@ pub trait ScalarOperations { fn sa_add_sb_into_sb(&self, a: &O, b: &mut O); // Assigns a - b to c. - fn sa_sub_sb_into_sc(&self, a: &O, b: &O, c: &mut O); + fn sa_sub_sb_into_sc(&self, a: &O, b: &O, c: &mut O); - // Assigns b - a to b. - fn sa_sub_sb_into_sb(&self, a: &O, b: &mut O); + // Assigns a - b to b. + fn sa_sub_sb_into_sb(&self, a: &O, b: &mut O); // Assigns -a to a. - fn sa_neg_into_sa(&self, a: &mut O); + fn sa_neg_into_sa(&self, a: &mut O); // Assigns -a to b. - fn sa_neg_into_sb(&self, a: &O, b: &mut O); + fn sa_neg_into_sb(&self, a: &O, b: &mut O); // Assigns a * 2^64 to b. fn sa_prep_mont_into_sb( @@ -122,8 +122,8 @@ pub trait ScalarOperations { b: &mut O, ); - // Assigns (a + 2q - b) * c to d. - fn sa_sub_sb_mul_sc_into_sd( + // Assigns (a + q - b) * c to d. + fn sa_sub_sb_mul_sc_into_sd( &self, a: &O, b: &O, @@ -131,13 +131,30 @@ pub trait ScalarOperations { d: &mut O, ); - // Assigns (a + 2q - b) * c to b. - fn sa_sub_sb_mul_sc_into_sb( + // Assigns (a + q - b) * c to b. + fn sa_sub_sb_mul_sc_into_sb( &self, a: &u64, c: &barrett::Barrett, b: &mut u64, ); + + // Assigns (a + b) * c to a. + fn sa_add_sb_mul_sc_into_sa( + &self, + b: &u64, + c: &barrett::Barrett, + a: &mut u64 + ); + + // Assigns (a + b) * c to d. + fn sa_add_sb_mul_sc_into_sd( + &self, + a: &u64, + b: &u64, + c: &barrett::Barrett, + d: &mut u64 + ); } pub trait VectorOperations { @@ -145,18 +162,18 @@ pub trait VectorOperations { fn va_reduce_into_va(&self, x: &mut [O]); // ADD - // Assigns a[i] + b[i] to c[i] + // vec(c) <- vec(a) + vec(b). fn va_add_vb_into_vc( &self, - a: &[O], - b: &[O], - c: &mut [O], + va: &[O], + vb: &[O], + vc: &mut [O], ); - // Assigns a[i] + b[i] to b[i] + // vec(b) <- vec(a) + vec(b). fn va_add_vb_into_vb(&self, a: &[O], b: &mut [O]); - // Assigns a[i] + b to c[i] + // vec(c) <- vec(a) + scalar(b). fn va_add_sb_into_vc( &self, a: &[O], @@ -164,37 +181,34 @@ pub trait VectorOperations { c: &mut [O], ); - // Assigns b[i] + a to b[i] - fn sa_add_vb_into_vb(&self, a: &O, b: &mut [O]); + // vec(b) <- vec(b) + scalar(a). + fn va_add_sb_into_va(&self, a: &O, b: &mut [O]); - // SUB - // Assigns a[i] - b[i] to b[i] - fn va_sub_vb_into_vb(&self, a: &[O], b: &mut [O]); + // vec(b) <- vec(a) - vec(b). + fn va_sub_vb_into_vb(&self, a: &[O], b: &mut [O]); - // Assigns a[i] - b[i] to c[i] - fn va_sub_vb_into_vc( + // vec(c) <- vec(a) - vec(b). + fn va_sub_vb_into_vc( &self, a: &[O], b: &[O], c: &mut [O], ); - // NEG - // Assigns -a[i] to a[i]. - fn va_neg_into_va(&self, a: &mut [O]); + // vec(a) <- -vec(a). + fn va_neg_into_va(&self, a: &mut [O]); - // Assigns -a[i] to a[i]. - fn va_neg_into_vb(&self, a: &[O], b: &mut [O]); + // vec(b) <- -vec(a). + fn va_neg_into_vb(&self, a: &[O], b: &mut [O]); - // MUL MONTGOMERY - // Assigns a * 2^64 to b. + // vec(b) <- vec(a) fn va_prep_mont_into_vb( &self, a: &[O], b: &mut [montgomery::Montgomery], ); - // Assigns a[i] * b[i] to c[i]. + // vec(c) <- vec(a) * vec(b). fn va_mont_mul_vb_into_vc( &self, a: &[montgomery::Montgomery], @@ -202,22 +216,21 @@ pub trait VectorOperations { c: &mut [O], ); - // Assigns a[i] * b[i] to b[i]. + // vec(b) <- vec(a) * vec(b). fn va_mont_mul_vb_into_vb( &self, a: &[montgomery::Montgomery], b: &mut [O], ); - // MUL BARRETT - // Assigns a * b[i] to b[i]. + // vec(b) <- vec(b) * scalar(a). fn sa_barrett_mul_vb_into_vb( &self, a: &barrett::Barrett, b: &mut [u64], ); - // Assigns a * b[i] to c[i]. + // vec(c) <- vec(b) * scalar(a). fn sa_barrett_mul_vb_into_vc( &self, a: &barrett::Barrett, @@ -225,9 +238,8 @@ pub trait VectorOperations { c: &mut [u64], ); - // OTHERS - // Assigns (a[i] + 2q - b[i]) * c to d[i]. - fn va_sub_vb_mul_sc_into_vd( + // vec(d) <- (vec(a) + VBRANGE * q - vec(b)) * scalar(c). + fn va_sub_vb_mul_sc_into_vd( &self, a: &[u64], b: &[u64], @@ -235,11 +247,28 @@ pub trait VectorOperations { d: &mut [u64], ); - // Assigns (a[i] + 2q - b[i]) * c to b[i]. - fn va_sub_vb_mul_sc_into_vb( + // vec(b) <- (vec(a) + VBRANGE * q - vec(b)) * scalar(c). + fn va_sub_vb_mul_sc_into_vb( &self, a: &[u64], c: &barrett::Barrett, b: &mut [u64], ); + + // vec(c) <- (vec(a) + scalar(b)) * scalar(c). + fn va_add_sb_mul_sc_into_vd( + &self, + va: &[u64], + sb: &u64, + sc: &barrett::Barrett, + vd: &mut [u64], + ); + + // vec(a) <- (vec(a) + scalar(b)) * scalar(c). + fn va_add_sb_mul_sc_into_va( + &self, + sb: &u64, + sc: &barrett::Barrett, + va: &mut [u64], + ); } diff --git a/math/src/modulus/impl_u64/operations.rs b/math/src/modulus/impl_u64/operations.rs index c9804fb..ae2a3cd 100644 --- a/math/src/modulus/impl_u64/operations.rs +++ b/math/src/modulus/impl_u64/operations.rs @@ -1,10 +1,9 @@ use crate::modulus::barrett::Barrett; use crate::modulus::montgomery::Montgomery; use crate::modulus::prime::Prime; -use crate::modulus::ReduceOnce; -use crate::modulus::REDUCEMOD; +use crate::modulus::{REDUCEMOD, NONE}; use crate::modulus::{ScalarOperations, VectorOperations}; -use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv}; +use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv, apply_ssv, apply_vssv}; use itertools::izip; impl ScalarOperations for Prime { @@ -33,24 +32,46 @@ impl ScalarOperations for Prime { } #[inline(always)] - fn sa_sub_sb_into_sc(&self, a: &u64, b: &u64, c: &mut u64) { - *c = a.wrapping_add(self.q.wrapping_sub(*b)).reduce_once(self.q); + fn sa_sub_sb_into_sc(&self, a: &u64, b: &u64, c: &mut u64) { + match SBRANGE{ + 1 =>{*c = *a + self.q - *b} + 2 =>{*c = *a + self.two_q - *b} + 4 =>{*c = *a + self.four_q - *b} + _ => unreachable!("invalid SBRANGE argument"), + } + self.sa_reduce_into_sa::(c) } #[inline(always)] - fn sa_sub_sb_into_sb(&self, a: &u64, b: &mut u64) { - *b = a.wrapping_add(self.q.wrapping_sub(*b)).reduce_once(self.q); + fn sa_sub_sb_into_sb(&self, a: &u64, b: &mut u64) { + match SBRANGE{ + 1 =>{*b = *a + self.q - *b} + 2 =>{*b = *a + self.two_q - *b} + 4 =>{*b = *a + self.four_q - *b} + _ => unreachable!("invalid SBRANGE argument"), + } + self.sa_reduce_into_sa::(b) } #[inline(always)] - fn sa_neg_into_sa(&self, a: &mut u64) { - *a = self.q.wrapping_sub(*a); + fn sa_neg_into_sa(&self, a: &mut u64) { + match SBRANGE{ + 1 =>{*a = self.q - *a} + 2 =>{*a = self.two_q - *a} + 4 =>{*a = self.four_q - *a} + _ => unreachable!("invalid SBRANGE argument"), + } self.sa_reduce_into_sa::(a) } #[inline(always)] - fn sa_neg_into_sb(&self, a: &u64, b: &mut u64) { - *b = self.q.wrapping_sub(*a); + fn sa_neg_into_sb(&self, a: &u64, b: &mut u64) { + match SBRANGE{ + 1 =>{*b = self.q - *a} + 2 =>{*b = self.two_q - *a} + 4 =>{*b = self.four_q - *a} + _ => unreachable!("invalid SBRANGE argument"), + } self.sa_reduce_into_sa::(b) } @@ -90,27 +111,54 @@ impl ScalarOperations for Prime { } #[inline(always)] - fn sa_sub_sb_mul_sc_into_sd( + fn sa_sub_sb_mul_sc_into_sd( &self, a: &u64, b: &u64, c: &Barrett, d: &mut u64, ) { - *d = self.two_q.wrapping_sub(*b).wrapping_add(*a); + match VBRANGE{ + 1 =>{*d = a + self.q - b} + 2 =>{*d = a + self.two_q - b} + 4 =>{*d = a + self.four_q - b} + _ => unreachable!("invalid SBRANGE argument"), + } self.barrett.mul_external_assign::(*c, d); } #[inline(always)] - fn sa_sub_sb_mul_sc_into_sb( + fn sa_sub_sb_mul_sc_into_sb( &self, a: &u64, c: &Barrett, b: &mut u64, ) { - *b = self.two_q.wrapping_sub(*b).wrapping_add(*a); + self.sa_sub_sb_into_sb::(a, b); self.barrett.mul_external_assign::(*c, b); } + + #[inline(always)] + fn sa_add_sb_mul_sc_into_sd( + &self, + a: &u64, + b: &u64, + c: &Barrett, + d: &mut u64 + ) { + *d = self.barrett.mul_external::(*c, *a + *b); + } + + #[inline(always)] + fn sa_add_sb_mul_sc_into_sa( + &self, + b: &u64, + c: &Barrett, + a: &mut u64 + ) { + *a = self.barrett.mul_external::(*c, *a + *b); + } + } impl VectorOperations for Prime { @@ -145,6 +193,15 @@ impl VectorOperations for Prime { apply_vv!(self, Self::sa_add_sb_into_sb::, a, b, CHUNK); } + #[inline(always)] + fn va_add_sb_into_va( + &self, + b: &u64, + a: &mut [u64], + ) { + apply_sv!(self, Self::sa_add_sb_into_sb::, b, a, CHUNK); + } + #[inline(always)] fn va_add_sb_into_vc( &self, @@ -156,45 +213,36 @@ impl VectorOperations for Prime { } #[inline(always)] - fn sa_add_vb_into_vb( - &self, - a: &u64, - b: &mut [u64], - ) { - apply_sv!(self, Self::sa_add_sb_into_sb::, a, b, CHUNK); - } - - #[inline(always)] - fn va_sub_vb_into_vc( + fn va_sub_vb_into_vc( &self, a: &[u64], b: &[u64], c: &mut [u64], ) { - apply_vvv!(self, Self::sa_sub_sb_into_sc::, a, b, c, CHUNK); + apply_vvv!(self, Self::sa_sub_sb_into_sc::, a, b, c, CHUNK); } #[inline(always)] - fn va_sub_vb_into_vb( + fn va_sub_vb_into_vb( &self, a: &[u64], b: &mut [u64], ) { - apply_vv!(self, Self::sa_sub_sb_into_sb::, a, b, CHUNK); + apply_vv!(self, Self::sa_sub_sb_into_sb::, a, b, CHUNK); } #[inline(always)] - fn va_neg_into_va(&self, a: &mut [u64]) { - apply_v!(self, Self::sa_neg_into_sa::, a, CHUNK); + fn va_neg_into_va(&self, a: &mut [u64]) { + apply_v!(self, Self::sa_neg_into_sa::, a, CHUNK); } #[inline(always)] - fn va_neg_into_vb( + fn va_neg_into_vb( &self, a: &[u64], b: &mut [u64], ) { - apply_vv!(self, Self::sa_neg_into_sb::, a, b, CHUNK); + apply_vv!(self, Self::sa_neg_into_sb::, a, b, CHUNK); } #[inline(always)] @@ -251,7 +299,7 @@ impl VectorOperations for Prime { apply_sv!(self, Self::sa_barrett_mul_sb_into_sb::, a, b, CHUNK); } - fn va_sub_vb_mul_sc_into_vd( + fn va_sub_vb_mul_sc_into_vd( &self, a: &[u64], b: &[u64], @@ -260,7 +308,7 @@ impl VectorOperations for Prime { ) { apply_vvsv!( self, - Self::sa_sub_sb_mul_sc_into_sd::, + Self::sa_sub_sb_mul_sc_into_sd::, a, b, c, @@ -269,7 +317,7 @@ impl VectorOperations for Prime { ); } - fn va_sub_vb_mul_sc_into_vb( + fn va_sub_vb_mul_sc_into_vb( &self, a: &[u64], b: &Barrett, @@ -277,11 +325,47 @@ impl VectorOperations for Prime { ) { apply_vsv!( self, - Self::sa_sub_sb_mul_sc_into_sb::, + Self::sa_sub_sb_mul_sc_into_sb::, a, b, c, CHUNK ); } + + // vec(a) <- (vec(a) + scalar(b)) * scalar(c); + fn va_add_sb_mul_sc_into_va( + &self, + b: &u64, + c: &Barrett, + a: &mut [u64], + ) { + apply_ssv!( + self, + Self::sa_add_sb_mul_sc_into_sa::, + b, + c, + a, + CHUNK + ); + } + + // vec(a) <- (vec(a) + scalar(b)) * scalar(c); + fn va_add_sb_mul_sc_into_vd( + &self, + a: &[u64], + b: &u64, + c: &Barrett, + d: &mut [u64], + ) { + apply_vssv!( + self, + Self::sa_add_sb_mul_sc_into_sd::, + a, + b, + c, + d, + CHUNK + ); + } } diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index 3296133..e4884ef 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -1,5 +1,5 @@ use crate::modulus::barrett::Barrett; -use crate::modulus::ONCE; +use crate::modulus::{NONE, ONCE, BARRETT}; use crate::poly::PolyRNS; use crate::ring::Ring; use crate::ring::RingRNS; @@ -34,8 +34,8 @@ impl RingRNS { let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]); for (i, r) in self.0[0..level].iter().enumerate() { - r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett::( + r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); + r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>( &buf_ntt_qi_scaling[0], a.at(i), &rescaling_constants.0[i], @@ -44,7 +44,7 @@ impl RingRNS { } } else { for (i, r) in self.0[0..level].iter().enumerate() { - r.sum_aqqmb_prod_c_scalar_barrett::( + r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>( a.at(level), a.at(i), &rescaling_constants.0[i], @@ -73,19 +73,19 @@ impl RingRNS { if NTT { let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); - self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]); + self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]); for (i, r) in self.0[0..level].iter().enumerate() { r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett_inplace::( + r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>( &buf_ntt_qi_scaling[0], &rescaling_constants.0[i], a.at_mut(i), ); } } else { - let (a_i, a_level) = buf.0.split_at_mut(level); + let (a_i, a_level) = a.0.split_at_mut(level); for (i, r) in self.0[0..level].iter().enumerate() { - r.sum_aqqmb_prod_c_scalar_barrett_inplace::( + r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>( &a_level[0], &rescaling_constants.0[i], &mut a_i[i], @@ -102,6 +102,9 @@ impl RingRNS { buf: &mut PolyRNS, c: &mut PolyRNS, ) { + + println!("{:?}", buf); + debug_assert!( self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", @@ -205,26 +208,40 @@ impl RingRNS { let r_last: &Ring = &self.0[level]; let q_level_half: u64 = r_last.modulus.q >> 1; let rescaling_constants: ScalarRNS> = self.rescaling_constant(); - let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); + let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1); if NTT { - r_last.intt::(a.at(level), &mut buf_ntt_q_scaling[0]); - r_last.add_scalar_inplace::(&q_level_half, &mut buf_ntt_q_scaling[0]); + r_last.intt::(a.at(level), &mut buf_q_scaling[0]); + r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]); for (i, r) in self.0[0..level].iter().enumerate() { - r_last.add_scalar::( - &buf_ntt_q_scaling[0], - &q_level_half, - &mut buf_ntt_qi_scaling[0], + r_last.add_scalar::( + &buf_q_scaling[0], + &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)), + &mut buf_qi_scaling[0], ); - r.ntt_inplace::(&mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett::( - &buf_ntt_qi_scaling[0], + r.ntt_inplace::(&mut buf_qi_scaling[0]); + r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>( + &buf_qi_scaling[0], a.at(i), &rescaling_constants.0[i], b.at_mut(i), ); } } else { + r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]); + for (i, r) in self.0[0..level].iter().enumerate() { + r_last.add_scalar::( + &buf_q_scaling[0], + &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)), + &mut buf_qi_scaling[0], + ); + r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>( + &buf_qi_scaling[0], + a.at(i), + &rescaling_constants.0[i], + b.at_mut(i), + ); + } } } @@ -246,24 +263,124 @@ impl RingRNS { let r_last: &Ring = &self.0[level]; let q_level_half: u64 = r_last.modulus.q >> 1; let rescaling_constants: ScalarRNS> = self.rescaling_constant(); - let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); + let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1); if NTT { - r_last.intt::(a.at(level), &mut buf_ntt_q_scaling[0]); - r_last.add_scalar_inplace::(&q_level_half, &mut buf_ntt_q_scaling[0]); + r_last.intt::(a.at(level), &mut buf_q_scaling[0]); + r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]); for (i, r) in self.0[0..level].iter().enumerate() { - r_last.add_scalar::( - &buf_ntt_q_scaling[0], - &q_level_half, - &mut buf_ntt_qi_scaling[0], + r_last.add_scalar::( + &buf_q_scaling[0], + &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)), + &mut buf_qi_scaling[0], ); - r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett_inplace::( - &buf_ntt_qi_scaling[0], + r.ntt_inplace::(&mut buf_qi_scaling[0]); + r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>( + &buf_qi_scaling[0], + &rescaling_constants.0[i], + a.at_mut(i), + ); + } + } else { + r_last.add_scalar_inplace::(&q_level_half, &mut buf_q_scaling[0]); + for (i, r) in self.0[0..level].iter().enumerate() { + r_last.add_scalar::( + &buf_q_scaling[0], + &(r.modulus.q - r_last.modulus.barrett.reduce::(&q_level_half)), + &mut buf_qi_scaling[0], + ); + r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>( + &buf_qi_scaling[0], &rescaling_constants.0[i], a.at_mut(i), ); } } } + + /// Updates b to round(a / prod_{level - nb_moduli}^{level} q[i]) + pub fn div_round_by_last_moduli( + &self, + nb_moduli: usize, + a: &PolyRNS, + buf: &mut PolyRNS, + c: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); + debug_assert!( + c.level() >= a.level() - 1, + "invalid input b: b.level()={} < a.level()-1={}", + c.level(), + a.level() - 1 + ); + debug_assert!( + nb_moduli <= a.level(), + "invalid input nb_moduli: nb_moduli={} > a.level()={}", + nb_moduli, + a.level() + ); + + if nb_moduli == 0 { + if a != c { + c.copy(a); + } + } else { + if NTT { + self.intt::(a, buf); + (0..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_round_by_last_modulus_inplace::( + &mut PolyRNS::::default(), + buf, + ) + }); + self.at_level(self.level() - nb_moduli).ntt::(buf, c); + } else { + self.div_round_by_last_modulus::(a, buf, c); + (1..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_round_by_last_modulus_inplace::(buf, c) + }); + } + } + } + + /// Updates a to round(a / prod_{level - nb_moduli}^{level} q[i]) + pub fn div_round_by_last_moduli_inplace( + &self, + nb_moduli: usize, + buf: &mut PolyRNS, + a: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); + debug_assert!( + nb_moduli <= a.level(), + "invalid input nb_moduli: nb_moduli={} > a.level()={}", + nb_moduli, + a.level() + ); + if NTT { + self.intt::(a, buf); + (0..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_round_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf) + }); + self.at_level(self.level() - nb_moduli).ntt::(buf, a); + } else { + (0..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_round_by_last_modulus_inplace::(buf, a) + }); + } + } } diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index 2dbfc15..644fa90 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -92,9 +92,9 @@ impl Ring { } #[inline(always)] - pub fn add_scalar_inplace(&self, a: &u64, b: &mut Poly) { - debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.sa_add_vb_into_vb::(a, &mut b.0); + pub fn add_scalar_inplace(&self, b: &u64, a: &mut Poly) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + self.modulus.va_add_sb_into_va::(b, &mut a.0); } #[inline(always)] @@ -106,33 +106,47 @@ impl Ring { } #[inline(always)] - pub fn sub_inplace(&self, a: &Poly, b: &mut Poly) { - debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); - debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus - .va_sub_vb_into_vb::(&a.0, &mut b.0); + pub fn add_scalar_then_mul_scalar_barrett_inplace(&self, b: &u64, c: &Barrett, a: &mut Poly) { + debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n()); + self.modulus.va_add_sb_mul_sc_into_va::(b, c, &mut a.0); } #[inline(always)] - pub fn sub(&self, a: &Poly, b: &Poly, c: &mut Poly) { + pub fn add_scalar_then_mul_scalar_barrett(&self, a: &Poly, b: &u64, c: &Barrett, d: &mut Poly) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(d.n() == self.n(), "c.n()={} != n={}", d.n(), self.n()); + self.modulus + .va_add_sb_mul_sc_into_vd::(&a.0, b, c, &mut d.0); + } + + #[inline(always)] + pub fn sub_inplace(&self, a: &Poly, b: &mut Poly) { + debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); + debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + self.modulus + .va_sub_vb_into_vb::(&a.0, &mut b.0); + } + + #[inline(always)] + pub fn sub(&self, a: &Poly, b: &Poly, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus - .va_sub_vb_into_vc::(&a.0, &b.0, &mut c.0); + .va_sub_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] - pub fn neg(&self, a: &Poly, b: &mut Poly) { + pub fn neg(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.va_neg_into_vb::(&a.0, &mut b.0); + self.modulus.va_neg_into_vb::(&a.0, &mut b.0); } #[inline(always)] - pub fn neg_inplace(&self, a: &mut Poly) { + pub fn neg_inplace(&self, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); - self.modulus.va_neg_into_va::(&mut a.0); + self.modulus.va_neg_into_va::(&mut a.0); } #[inline(always)] @@ -208,7 +222,7 @@ impl Ring { } #[inline(always)] - pub fn sum_aqqmb_prod_c_scalar_barrett( + pub fn a_sub_b_mul_c_scalar_barrett( &self, a: &Poly, b: &Poly, @@ -219,11 +233,11 @@ impl Ring { debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n()); self.modulus - .va_sub_vb_mul_sc_into_vd::(&a.0, &b.0, c, &mut d.0); + .va_sub_vb_mul_sc_into_vd::(&a.0, &b.0, c, &mut d.0); } #[inline(always)] - pub fn sum_aqqmb_prod_c_scalar_barrett_inplace( + pub fn a_sub_b_mul_c_scalar_barrett_inplace( &self, a: &Poly, c: &Barrett, @@ -232,6 +246,6 @@ impl Ring { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus - .va_sub_vb_mul_sc_into_vb::(&a.0, c, &mut b.0); + .va_sub_vb_mul_sc_into_vb::(&a.0, c, &mut b.0); } } diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs index 238b504..4549d8b 100644 --- a/math/src/ring/impl_u64/ring_rns.rs +++ b/math/src/ring/impl_u64/ring_rns.rs @@ -172,7 +172,7 @@ impl RingRNS { } #[inline(always)] - pub fn sub( + pub fn sub( &self, a: &PolyRNS, b: &PolyRNS, @@ -199,11 +199,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); + .for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] - pub fn sub_inplace(&self, a: &PolyRNS, b: &mut PolyRNS) { + pub fn sub_inplace(&self, a: &PolyRNS, b: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -219,11 +219,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); + .for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn neg(&self, a: &PolyRNS, b: &mut PolyRNS) { + pub fn neg(&self, a: &PolyRNS, b: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -239,11 +239,11 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); + .for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn neg_inplace(&self, a: &mut PolyRNS) { + pub fn neg_inplace(&self, a: &mut PolyRNS) { debug_assert!( a.level() >= self.level(), "a.level()={} < self.level()={}", @@ -253,7 +253,7 @@ impl RingRNS { self.0 .iter() .enumerate() - .for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); + .for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); } #[inline(always)] diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs index fea7744..5dae27b 100644 --- a/math/tests/rescaling_rns.rs +++ b/math/tests/rescaling_rns.rs @@ -7,11 +7,17 @@ use sampling::source::Source; #[test] fn rescaling_rns_u64() { let n = 1 << 10; - let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64]; + let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64, 0x1fffffffffb40001, 0x1fffffffff500001]; let ring_rns: RingRNS = RingRNS::new(n, moduli); - test_div_floor_by_last_modulus::(&ring_rns); - test_div_floor_by_last_modulus::(&ring_rns); + //test_div_floor_by_last_modulus::(&ring_rns); + //test_div_floor_by_last_modulus::(&ring_rns); + test_div_floor_by_last_modulus_inplace::(&ring_rns); + //test_div_floor_by_last_modulus_inplace::(&ring_rns); + //test_div_floor_by_last_moduli::(&ring_rns); + //test_div_floor_by_last_moduli::(&ring_rns); + //test_div_floor_by_last_moduli_inplace::(&ring_rns); + //test_div_floor_by_last_moduli_inplace::(&ring_rns); } fn test_div_floor_by_last_modulus(ring_rns: &RingRNS) { @@ -58,5 +64,155 @@ fn test_div_floor_by_last_modulus(ring_rns: &RingRNS) { } }); - assert!(coeffs_a == coeffs_c); + assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_modulus"); } + +fn test_div_floor_by_last_modulus_inplace(ring_rns: &RingRNS) { + let seed: [u8; 32] = [0; 32]; + let mut source: Source = Source::new(seed); + + let mut a: PolyRNS = ring_rns.new_polyrns(); + let mut b: PolyRNS = ring_rns.new_polyrns(); + + // Allocates a random PolyRNS + ring_rns.fill_uniform(&mut source, &mut a); + + // Maps PolyRNS to [BigInt] + let mut coeffs_a: Vec = (0..a.n()).map(|i| BigInt::from(i)).collect(); + ring_rns + .at_level(a.level()) + .to_bigint_inplace(&a, 1, &mut coeffs_a); + + println!("{:?}", &coeffs_a[..8]); + + // Performs c = intt(ntt(a) / q_level) + if NTT { + ring_rns.ntt_inplace::(&mut a); + } + + ring_rns.div_floor_by_last_modulus_inplace::(&mut b, &mut a); + + if NTT { + ring_rns.at_level(a.level()-1).intt_inplace::(&mut a); + } + + // Exports c to coeffs_c + let mut coeffs_c = vec![BigInt::from(0); a.n()]; + ring_rns + .at_level(a.level()-1) + .to_bigint_inplace(&a, 1, &mut coeffs_c); + + // Performs floor division on a + let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q); + coeffs_a.iter_mut().for_each(|a| { + // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1] + *a /= &scalar_big; + if a.sign() == Sign::Minus { + *a -= 1; + } + }); + + println!("{:?}", &coeffs_a[..8]); + println!("{:?}", &coeffs_c[..8]); + + assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_modulus_inplace"); +} + +fn test_div_floor_by_last_moduli(ring_rns: &RingRNS) { + let seed: [u8; 32] = [0; 32]; + let mut source: Source = Source::new(seed); + + let nb_moduli: usize = ring_rns.level()-1; + + let mut a: PolyRNS = ring_rns.new_polyrns(); + let mut b: PolyRNS = ring_rns.new_polyrns(); + let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - 1).new_polyrns(); + + // Allocates a random PolyRNS + ring_rns.fill_uniform(&mut source, &mut a); + + // Maps PolyRNS to [BigInt] + let mut coeffs_a: Vec = (0..a.n()).map(|i| BigInt::from(i)).collect(); + ring_rns + .at_level(a.level()) + .to_bigint_inplace(&a, 1, &mut coeffs_a); + + // Performs c = intt(ntt(a) / q_level) + if NTT { + ring_rns.ntt_inplace::(&mut a); + } + + ring_rns.div_floor_by_last_moduli::(nb_moduli, &a, &mut b, &mut c); + + if NTT { + ring_rns.at_level(c.level()).intt_inplace::(&mut c); + } + + // Exports c to coeffs_c + let mut coeffs_c = vec![BigInt::from(0); c.n()]; + ring_rns + .at_level(c.level()) + .to_bigint_inplace(&c, 1, &mut coeffs_c); + + // Performs floor division on a + let mut scalar_big = BigInt::from(1); + (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()].modulus.q)}); + coeffs_a.iter_mut().for_each(|a| { + // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1] + *a /= &scalar_big; + if a.sign() == Sign::Minus { + *a -= 1; + } + }); + + assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_moduli"); +} + +fn test_div_floor_by_last_moduli_inplace(ring_rns: &RingRNS) { + let seed: [u8; 32] = [0; 32]; + let mut source: Source = Source::new(seed); + + let nb_moduli: usize = ring_rns.level()-1; + + let mut a: PolyRNS = ring_rns.new_polyrns(); + let mut b: PolyRNS = ring_rns.new_polyrns(); + + // Allocates a random PolyRNS + ring_rns.fill_uniform(&mut source, &mut a); + + // Maps PolyRNS to [BigInt] + let mut coeffs_a: Vec = (0..a.n()).map(|i| BigInt::from(i)).collect(); + ring_rns + .at_level(a.level()) + .to_bigint_inplace(&a, 1, &mut coeffs_a); + + // Performs c = intt(ntt(a) / q_level) + if NTT { + ring_rns.ntt_inplace::(&mut a); + } + + ring_rns.div_floor_by_last_moduli_inplace::(nb_moduli, &mut b, &mut a); + + if NTT { + ring_rns.at_level(a.level()-nb_moduli).intt_inplace::(&mut a); + } + + // Exports c to coeffs_c + let mut coeffs_c = vec![BigInt::from(0); a.n()]; + ring_rns + .at_level(a.level()-nb_moduli) + .to_bigint_inplace(&a, 1, &mut coeffs_c); + + // Performs floor division on a + let mut scalar_big = BigInt::from(1); + (0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()].modulus.q)}); + coeffs_a.iter_mut().for_each(|a| { + // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1] + *a /= &scalar_big; + if a.sign() == Sign::Minus { + *a -= 1; + } + }); + + assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_moduli_inplace"); +} \ No newline at end of file