From c69bd6985ac7280cea7e4dac12355044fc1ea4a1 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 6 Jan 2025 14:40:03 +0100 Subject: [PATCH] refactored RingRNS --- math/benches/ring_rns.rs | 9 ++++----- math/benches/sampling.rs | 9 ++++----- math/src/dft/ntt.rs | 2 +- math/src/ring.rs | 7 ++++--- math/src/ring/impl_u64/rescaling_rns.rs | 2 +- math/src/ring/impl_u64/ring_rns.rs | 25 +++++++++++-------------- math/src/ring/impl_u64/sampling.rs | 2 +- math/tests/rescaling_rns.rs | 6 ++---- 8 files changed, 28 insertions(+), 34 deletions(-) diff --git a/math/benches/ring_rns.rs b/math/benches/ring_rns.rs index 52279ef..aee46e1 100644 --- a/math/benches/ring_rns.rs +++ b/math/benches/ring_rns.rs @@ -1,10 +1,9 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use math::poly::PolyRNS; -use math::ring::impl_u64::ring_rns::new_rings; -use math::ring::{Ring, RingRNS}; +use math::ring::RingRNS; fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { - fn runner(r: RingRNS) -> Box { + fn runner(r: RingRNS) -> Box { let a: PolyRNS = r.new_polyrns(); let mut b: PolyRNS = r.new_polyrns(); let mut c: PolyRNS = r.new_polyrns(); @@ -22,8 +21,8 @@ fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { 0x1fffffffffb40001, 0x1fffffffff500001, ]; - let rings: Vec> = new_rings(n, moduli); - let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); + + let ring_rns: RingRNS = RingRNS::new(n, moduli); let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), { runner(ring_rns) diff --git a/math/benches/sampling.rs b/math/benches/sampling.rs index a695c0b..6fd8e6d 100644 --- a/math/benches/sampling.rs +++ b/math/benches/sampling.rs @@ -1,11 +1,10 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use math::poly::PolyRNS; -use math::ring::impl_u64::ring_rns::new_rings; -use math::ring::{Ring, RingRNS}; +use math::ring::RingRNS; use sampling::source::Source; fn fill_uniform(c: &mut Criterion) { - fn runner(r: RingRNS) -> Box { + fn runner(r: RingRNS) -> Box { let mut a: PolyRNS = r.new_polyrns(); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -25,8 +24,8 @@ fn fill_uniform(c: &mut Criterion) { 0x1fffffffffb40001, 0x1fffffffff500001, ]; - let rings: Vec> = new_rings(n, moduli); - let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); + + let ring_rns: RingRNS = RingRNS::new(n, moduli); let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), { runner(ring_rns) diff --git a/math/src/dft/ntt.rs b/math/src/dft/ntt.rs index db0dc67..5612c1e 100644 --- a/math/src/dft/ntt.rs +++ b/math/src/dft/ntt.rs @@ -19,7 +19,7 @@ pub struct Table { } impl Table { - pub fn new(prime: Prime, nth_root: u64) -> Self { + pub fn new(prime: Prime, nth_root: u64) -> Table { assert!( nth_root & (nth_root - 1) == 0, "invalid argument: nth_root = {} is not a power of two", diff --git a/math/src/ring.rs b/math/src/ring.rs index 2771e19..a335421 100644 --- a/math/src/ring.rs +++ b/math/src/ring.rs @@ -4,6 +4,7 @@ use crate::dft::DFT; use crate::modulus::prime::Prime; use crate::poly::{Poly, PolyRNS}; use num::traits::Unsigned; +use std::sync::Arc; pub struct Ring { pub n: usize, @@ -21,9 +22,9 @@ impl Ring { } } -pub struct RingRNS<'a, O: Unsigned>(pub &'a [Ring]); +pub struct RingRNS(pub Vec>>); -impl RingRNS<'_, O> { +impl RingRNS { pub fn n(&self) -> usize { self.0[0].n() } @@ -42,6 +43,6 @@ impl RingRNS<'_, O> { pub fn at_level(&self, level: usize) -> RingRNS { assert!(level <= self.0.len()); - RingRNS(&self.0[..level + 1]) + RingRNS(self.0[..level + 1].to_vec()) } } diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index 6bc1d5f..3296133 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -6,7 +6,7 @@ use crate::ring::RingRNS; use crate::scalar::ScalarRNS; extern crate test; -impl RingRNS<'_, u64> { +impl RingRNS { /// Updates b to floor(a / q[b.level()]). pub fn div_floor_by_last_modulus( &self, diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs index e8cbd69..238b504 100644 --- a/math/src/ring/impl_u64/ring_rns.rs +++ b/math/src/ring/impl_u64/ring_rns.rs @@ -5,19 +5,16 @@ use crate::poly::PolyRNS; use crate::ring::{Ring, RingRNS}; use crate::scalar::ScalarRNS; use num_bigint::BigInt; +use std::sync::Arc; -pub fn new_rings(n: usize, moduli: Vec) -> Vec> { - assert!(!moduli.is_empty(), "moduli cannot be empty"); - let rings: Vec> = moduli - .into_iter() - .map(|prime| Ring::new(n, prime, 1)) - .collect(); - return rings; -} - -impl<'a> RingRNS<'a, u64> { - pub fn new(rings: &'a [Ring]) -> Self { - RingRNS(rings) +impl RingRNS { + pub fn new(n: usize, moduli: Vec) -> Self { + assert!(!moduli.is_empty(), "moduli cannot be empty"); + let rings: Vec>> = moduli + .into_iter() + .map(|prime| Arc::new(Ring::new(n, prime, 1))) + .collect(); + return RingRNS(rings); } pub fn modulus(&self) -> BigInt { @@ -92,7 +89,7 @@ impl<'a> RingRNS<'a, u64> { } } -impl RingRNS<'_, u64> { +impl RingRNS { pub fn ntt_inplace(&self, a: &mut PolyRNS) { self.0 .iter() @@ -122,7 +119,7 @@ impl RingRNS<'_, u64> { } } -impl RingRNS<'_, u64> { +impl RingRNS { #[inline(always)] pub fn add( &self, diff --git a/math/src/ring/impl_u64/sampling.rs b/math/src/ring/impl_u64/sampling.rs index 8a2ab27..5500e0a 100644 --- a/math/src/ring/impl_u64/sampling.rs +++ b/math/src/ring/impl_u64/sampling.rs @@ -12,7 +12,7 @@ impl Ring { } } -impl RingRNS<'_, u64> { +impl RingRNS { pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS) { self.0 .iter() diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs index b5c7dba..fea7744 100644 --- a/math/tests/rescaling_rns.rs +++ b/math/tests/rescaling_rns.rs @@ -1,6 +1,5 @@ use math::poly::PolyRNS; -use math::ring::impl_u64::ring_rns::new_rings; -use math::ring::{Ring, RingRNS}; +use math::ring::RingRNS; use num_bigint::BigInt; use num_bigint::Sign; use sampling::source::Source; @@ -9,8 +8,7 @@ use sampling::source::Source; fn rescaling_rns_u64() { let n = 1 << 10; let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64]; - let rings: Vec> = new_rings(n, moduli); - let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); + let ring_rns: RingRNS = RingRNS::new(n, moduli); test_div_floor_by_last_modulus::(&ring_rns); test_div_floor_by_last_modulus::(&ring_rns);