From 78cc0514ec3d68f6cdbafcf611eaa3ad97011776 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 4 Jan 2025 23:35:35 +0100 Subject: [PATCH] wip --- math/src/dft/ntt.rs | 11 +-- math/src/modulus/barrett.rs | 2 + math/src/modulus/impl_u64/barrett.rs | 11 ++- math/src/poly.rs | 8 +- math/src/ring.rs | 3 +- math/src/ring/impl_u64/rescaling_rns.rs | 51 +++++++---- math/src/ring/impl_u64/ring.rs | 23 +++-- math/src/ring/impl_u64/ring_rns.rs | 111 ++++++++++++++---------- 8 files changed, 137 insertions(+), 83 deletions(-) diff --git a/math/src/dft/ntt.rs b/math/src/dft/ntt.rs index 8a2a49f..7bd2702 100644 --- a/math/src/dft/ntt.rs +++ b/math/src/dft/ntt.rs @@ -150,8 +150,9 @@ impl Table{ debug_assert!(*b < self.four_q, "b:{} q:{}", b, self.four_q); a.reduce_once_assign(self.two_q); let bt: u64 = self.prime.barrett.mul_external::(t, *b); - *b = a.wrapping_add(self.two_q-bt); - *a = a.wrapping_add(bt); + debug_assert!(bt < self.two_q, "bt:{} two_q:{}", bt, self.two_q); + *b = *a + self.two_q-bt; + *a += bt; if !LAZY { a.reduce_once_assign(self.two_q); b.reduce_once_assign(self.two_q); @@ -223,10 +224,10 @@ impl Table{ #[inline(always)] fn dif_inplace(&self, a: &mut u64, b: &mut u64, t: Barrett) { - debug_assert!(*a < self.two_q, "a:{} q:{}", a, self.four_q); - debug_assert!(*b < self.two_q, "b:{} q:{}", b, self.four_q); + debug_assert!(*a < self.two_q, "a:{} q:{}", a, self.two_q); + debug_assert!(*b < self.two_q, "b:{} q:{}", b, self.two_q); let d: u64 = self.prime.barrett.mul_external::(t, *a + self.two_q - *b); - *a = a.wrapping_add(*b); + *a = *a + *b; a.reduce_once_assign(self.two_q); *b = d; if !LAZY { diff --git a/math/src/modulus/barrett.rs b/math/src/modulus/barrett.rs index efaee2e..fea91f1 100644 --- a/math/src/modulus/barrett.rs +++ b/math/src/modulus/barrett.rs @@ -14,6 +14,8 @@ impl Barrett { } } +pub struct BarrettRNS(pub Vec>); + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct BarrettPrecomp{ pub q: O, diff --git a/math/src/modulus/impl_u64/barrett.rs b/math/src/modulus/impl_u64/barrett.rs index 7a0141f..0cec9f6 100644 --- a/math/src/modulus/impl_u64/barrett.rs +++ b/math/src/modulus/impl_u64/barrett.rs @@ -21,10 +21,6 @@ impl BarrettPrecomp{ self.one } - /// Applies a modular reduction on x based on REDUCE: - /// - LAZY: no modular reduction. - /// - ONCE: subtracts q if x >= q. - /// - FULL: maps x to x mod q using Barrett reduction. #[inline(always)] pub fn reduce_assign(&self, x: &mut u64){ match REDUCE { @@ -45,6 +41,13 @@ impl BarrettPrecomp{ } } + #[inline(always)] + pub fn reduce(&self, x: &u64) -> u64{ + let mut r = *x; + self.reduce_assign::(&mut r); + r + } + #[inline(always)] pub fn prepare(&self, v: u64) -> Barrett { debug_assert!(v < self.q); diff --git a/math/src/poly.rs b/math/src/poly.rs index 2d63277..286557c 100644 --- a/math/src/poly.rs +++ b/math/src/poly.rs @@ -20,12 +20,12 @@ impl Polywhere self.0 = Vec::from(&buf[..n]); } - pub fn n(&self) -> usize{ - (usize::BITS - self.0.len().leading_zeros()) as usize + pub fn log_n(&self) -> usize{ + (usize::BITS - (self.n()-1).leading_zeros()) as usize } - pub fn log_n(&self) -> usize{ - self.0.len()-1 + pub fn n(&self) -> usize{ + self.0.len() } pub fn resize(&mut self, n:usize){ diff --git a/math/src/ring.rs b/math/src/ring.rs index b62f9a9..d9aaffe 100644 --- a/math/src/ring.rs +++ b/math/src/ring.rs @@ -2,7 +2,6 @@ pub mod impl_u64; use crate::modulus::prime::Prime; use crate::poly::{Poly, PolyRNS}; -use num_bigint::BigInt; use crate::dft::DFT; @@ -37,7 +36,7 @@ impl RingRNS<'_, O> { pub fn max_level(&self) -> usize{ self.0.len()-1 } - + pub fn level(&self) -> usize{ self.0.len()-1 } diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index 3634e00..0f9387b 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -1,6 +1,6 @@ use crate::ring::RingRNS; use crate::poly::{Poly, PolyRNS}; -use crate::modulus::barrett::Barrett; +use crate::modulus::barrett::BarrettRNS; use crate::modulus::ONCE; extern crate test; @@ -11,12 +11,12 @@ impl RingRNS<'_, u64>{ pub fn div_floor_by_last_modulus_ntt(&self, a: &PolyRNS, buf: &mut PolyRNS, b: &mut PolyRNS){ assert!(b.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", b.level(), a.level()-1); let level = self.level(); - self.0[level].intt::(a.at(level), buf.at_mut(0)); - let rescaling_constants: Vec> = self.rescaling_constant(); + self.0[level].intt::(a.at(level), buf.at_mut(0)); + let rescaling_constants: BarrettRNS = self.rescaling_constant(); let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); for (i, r) in self.0[0..level].iter().enumerate(){ - r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett::(&buf_ntt_qi_scaling[0], a.at(i), &rescaling_constants[i], b.at_mut(i)); + r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); + r.sum_aqqmb_prod_c_scalar_barrett::(&buf_ntt_qi_scaling[0], a.at(i), &rescaling_constants.0[i], b.at_mut(i)); } } @@ -25,11 +25,11 @@ impl RingRNS<'_, u64>{ pub fn div_floor_by_last_modulus_ntt_inplace(&self, buf: &mut PolyRNS, b: &mut PolyRNS){ let level = self.level(); self.0[level].intt::(b.at(level), buf.at_mut(0)); - let rescaling_constants: Vec> = self.rescaling_constant(); + let rescaling_constants: BarrettRNS = self.rescaling_constant(); let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); for (i, r) in self.0[0..level].iter().enumerate(){ r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&buf_ntt_qi_scaling[0], &rescaling_constants[i], b.at_mut(i)); + r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&buf_ntt_qi_scaling[0], &rescaling_constants.0[i], b.at_mut(i)); } } @@ -37,19 +37,19 @@ impl RingRNS<'_, u64>{ pub fn div_floor_by_last_modulus(&self, a: &PolyRNS, b: &mut PolyRNS){ assert!(b.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", b.level(), a.level()-1); let level = self.level(); - let rescaling_constants: Vec> = self.rescaling_constant(); + let rescaling_constants:crate::modulus::barrett::BarrettRNS = self.rescaling_constant(); for (i, r) in self.0[0..level].iter().enumerate(){ - r.sum_aqqmb_prod_c_scalar_barrett::(a.at(level), a.at(i), &rescaling_constants[i], b.at_mut(i)); + r.sum_aqqmb_prod_c_scalar_barrett::(a.at(level), a.at(i), &rescaling_constants.0[i], b.at_mut(i)); } } /// Updates a to floor(b / q[b.level()]). pub fn div_floor_by_last_modulus_inplace(&self, a: &mut PolyRNS){ let level = self.level(); - let rescaling_constants: Vec> = self.rescaling_constant(); + let rescaling_constants: BarrettRNS = self.rescaling_constant(); let (a_i, a_level) = a.split_at_mut(level); for (i, r) in self.0[0..level].iter().enumerate(){ - r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&a_level[0], &rescaling_constants[i], &mut a_i[i]); + r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&a_level[0], &rescaling_constants.0[i], &mut a_i[i]); } } @@ -76,6 +76,7 @@ impl RingRNS<'_, u64>{ #[cfg(test)] mod tests { + use num_bigint::BigInt; use crate::ring::Ring; use crate::ring::impl_u64::ring_rns::new_rings; use super::*; @@ -83,16 +84,36 @@ mod tests { #[test] fn test_div_floor_by_last_modulus_ntt() { let n = 1<<10; - let moduli: Vec = vec![0x1fffffffffe00001u64, 0x1fffffffffc80001u64]; + let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64]; let rings: Vec> = new_rings(n, moduli); let ring_rns = RingRNS::new(&rings); - let a: PolyRNS = ring_rns.new_polyrns(); + let mut a: PolyRNS = ring_rns.new_polyrns(); let mut b: PolyRNS = ring_rns.new_polyrns(); - let mut c: PolyRNS = ring_rns.new_polyrns(); + let mut c: PolyRNS = ring_rns.at_level(ring_rns.level()-1).new_polyrns(); + // Allocates an rns poly with values [0..n] + let mut coeffs_a: Vec = (0..n).map(|i|{BigInt::from(i)}).collect(); + ring_rns.from_bigint_inplace(&coeffs_a, 1, &mut a); + + // Scales by q_level both a and coeffs_a + let scalar: u64 = ring_rns.0[ring_rns.level()].modulus.q; + ring_rns.mul_scalar_inplace::(&scalar, &mut a); + let scalar_big = BigInt::from(scalar); + coeffs_a.iter_mut().for_each(|a|{*a *= &scalar_big}); + + // Performs c = intt(ntt(a) / q_level) + ring_rns.ntt_inplace::(&mut a); ring_rns.div_floor_by_last_modulus_ntt(&a, &mut b, &mut c); + ring_rns.at_level(c.level()).intt_inplace::(&mut c); - //assert!(m_precomp.mul_external::(y_mont, x) == (x as u128 * y as u128 % q as u128) as u64); + // Exports c to coeffs_c + let mut coeffs_c = vec![BigInt::from(0);c.n()]; + ring_rns.at_level(c.level()).to_bigint_inplace(&c, 1, &mut coeffs_c); + + // Performs floor division on a + coeffs_a.iter_mut().for_each(|a|{*a /= &scalar_big}); + + assert!(coeffs_a == coeffs_c); } } \ No newline at end of file diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index b04d901..0b8cfa0 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -4,7 +4,7 @@ use crate::modulus::prime::Prime; use crate::modulus::montgomery::Montgomery; use crate::modulus::barrett::Barrett; use crate::poly::Poly; -use crate::modulus::REDUCEMOD; +use crate::modulus::{REDUCEMOD, BARRETT}; use crate::modulus::VecOperations; use num_bigint::BigInt; use num_traits::ToPrimitive; @@ -38,16 +38,16 @@ impl Ring{ pub fn intt_inplace(&self, poly: &mut Poly){ match LAZY{ - true => self.dft.forward_inplace_lazy(&mut poly.0), - false => self.dft.forward_inplace(&mut poly.0) + true => self.dft.backward_inplace_lazy(&mut poly.0), + false => self.dft.backward_inplace(&mut poly.0) } } pub fn ntt(&self, poly_in: &Poly, poly_out: &mut Poly){ poly_out.0.copy_from_slice(&poly_in.0); match LAZY{ - true => self.dft.backward_inplace_lazy(&mut poly_out.0), - false => self.dft.backward_inplace(&mut poly_out.0) + true => self.dft.forward_inplace_lazy(&mut poly_out.0), + false => self.dft.forward_inplace(&mut poly_out.0) } } @@ -120,6 +120,19 @@ impl Ring{ self.modulus.vec_mul_montgomery_external_unary_assign::(&a.0, &mut b.0); } + #[inline(always)] + pub fn mul_scalar(&self, a:&Poly, b: &u64, c:&mut Poly){ + debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n()); + debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); + self.modulus.vec_mul_scalar_barrett_external_binary_assign::(&self.modulus.barrett.prepare(*b), &a.0, &mut c.0); + } + + #[inline(always)] + pub fn mul_scalar_inplace(&self, a:&u64, b:&mut Poly){ + debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); + self.modulus.vec_mul_scalar_barrett_external_unary_assign::(&self.modulus.barrett.prepare(self.modulus.barrett.reduce::(a)), &mut b.0); + } + #[inline(always)] pub fn mul_scalar_barrett_inplace(&self, a:&Barrett, b:&mut Poly){ debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs index dc82615..3ece9b1 100644 --- a/math/src/ring/impl_u64/ring_rns.rs +++ b/math/src/ring/impl_u64/ring_rns.rs @@ -1,7 +1,7 @@ use crate::ring::{Ring, RingRNS}; use crate::poly::PolyRNS; use crate::modulus::montgomery::Montgomery; -use crate::modulus::barrett::Barrett; +use crate::modulus::barrett::BarrettRNS; use crate::modulus::REDUCEMOD; use num_bigint::BigInt; @@ -25,19 +25,19 @@ impl<'a> RingRNS<'a, u64>{ modulus } - pub fn rescaling_constant(&self) -> Vec> { + pub fn rescaling_constant(&self) -> BarrettRNS { let level = self.level(); let q_scale: u64 = self.0[level].modulus.q; - (0..level).map(|i| {self.0[i].modulus.barrett.prepare(self.0[i].modulus.q - self.0[i].modulus.inv(q_scale))}).collect() + BarrettRNS((0..level).map(|i| {self.0[i].modulus.barrett.prepare(self.0[i].modulus.q - self.0[i].modulus.inv(q_scale))}).collect()) } - pub fn set_poly_from_bigint(&self, coeffs: &[BigInt], step:usize, a: &mut PolyRNS){ + pub fn from_bigint_inplace(&self, coeffs: &[BigInt], step:usize, a: &mut PolyRNS){ let level = self.level(); assert!(level <= a.level(), "invalid level: level={} > a.level()={}", level, a.level()); (0..level).for_each(|i|{self.0[i].from_bigint(coeffs, step, a.at_mut(i))}); } - pub fn set_bigint_from_poly(&self, a: &PolyRNS, step: usize, coeffs: &mut [BigInt]){ + pub fn to_bigint_inplace(&self, a: &PolyRNS, step: usize, coeffs: &mut [BigInt]){ assert!(step <= a.n(), "invalid step: step={} > a.n()={}", step, a.n()); assert!(coeffs.len() <= a.n() / step, "invalid coeffs: coeffs.len()={} > a.n()/step={}", coeffs.len(), a.n()/step); @@ -47,7 +47,7 @@ impl<'a> RingRNS<'a, u64>{ inv_crt.iter_mut().enumerate().for_each(|(i, a)|{ let qi_big = BigInt::from(self.0[i].modulus.q); - *a = (&q_big / &qi_big); + *a = &q_big / &qi_big; *a *= a.modinv(&qi_big).unwrap(); }); @@ -64,79 +64,94 @@ impl<'a> RingRNS<'a, u64>{ } } +impl RingRNS<'_, u64>{ + pub fn ntt_inplace(&self, a: &mut PolyRNS){ + self.0.iter().enumerate().for_each(|(i, ring)| ring.ntt_inplace::(&mut a.0[i])); + } + + pub fn intt_inplace(&self, a: &mut PolyRNS){ + self.0.iter().enumerate().for_each(|(i, ring)| ring.intt_inplace::(&mut a.0[i])); + } + + pub fn ntt(&self, a: &PolyRNS, b: &mut PolyRNS){ + self.0.iter().enumerate().for_each(|(i, ring)| ring.ntt::(&a.0[i], &mut b.0[i])); + } + + pub fn intt(&self, a: &PolyRNS, b: &mut PolyRNS){ + self.0.iter().enumerate().for_each(|(i, ring)| ring.intt::(&a.0[i], &mut b.0[i])); + } +} + impl RingRNS<'_, u64>{ #[inline(always)] pub fn add(&self, a: &PolyRNS, b: &PolyRNS, c: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - debug_assert!(c.level() >= level, "c.level()={} < level={}", c.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.add::(&a.0[i], &b.0[i], &mut c.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + debug_assert!(c.level() >= self.level(), "c.level()={} < self.level()={}", c.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.add::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] pub fn add_inplace(&self, a: &PolyRNS, b: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.add_inplace::(&a.0[i], &mut b.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.add_inplace::(&a.0[i], &mut b.0[i])); } #[inline(always)] pub fn sub(&self, a: &PolyRNS, b: &PolyRNS, c: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - debug_assert!(c.level() >= level, "c.level()={} < level={}", c.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + debug_assert!(c.level() >= self.level(), "c.level()={} < self.level()={}", c.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] pub fn sub_inplace(&self, a: &PolyRNS, b: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); } #[inline(always)] pub fn neg(&self, a: &PolyRNS, b: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); } #[inline(always)] pub fn neg_inplace(&self, a: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); } #[inline(always)] pub fn mul_montgomery_external(&self, a:&PolyRNS>, b:&PolyRNS, c: &mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - debug_assert!(c.level() >= level, "c.level()={} < level={}", c.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.mul_montgomery_external::(&a.0[i], &b.0[i], &mut c.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + debug_assert!(c.level() >= self.level(), "c.level()={} < self.level()={}", c.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_montgomery_external::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] pub fn mul_montgomery_external_inplace(&self, a:&PolyRNS>, b:&mut PolyRNS){ - let level: usize = self.level(); - debug_assert!(self.max_level() <= level, "max_level={} < level={}", self.max_level(), level); - debug_assert!(a.level() >= level, "a.level()={} < level={}", a.level(), level); - debug_assert!(b.level() >= level, "b.level()={} < level={}", b.level(), level); - self.0.iter().take(level + 1).enumerate().for_each(|(i, ring)| ring.mul_montgomery_external_inplace::(&a.0[i], &mut b.0[i])); + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_montgomery_external_inplace::(&a.0[i], &mut b.0[i])); + } + + #[inline(always)] + pub fn mul_scalar(&self, a: &PolyRNS, b: &u64, c: &mut PolyRNS){ + debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); + debug_assert!(c.level() >= self.level(), "b.level()={} < self.level()={}", c.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_scalar::(&a.0[i], b, &mut c.0[i])); + } + + #[inline(always)] + pub fn mul_scalar_inplace(&self, a: &u64, b: &mut PolyRNS){ + debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); + self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_scalar_inplace::(a, &mut b.0[i])); } } \ No newline at end of file