refactored RingRNS

This commit is contained in:
Jean-Philippe Bossuat
2025-01-06 14:40:03 +01:00
parent a074886b3e
commit c69bd6985a
8 changed files with 28 additions and 34 deletions

View File

@@ -1,10 +1,9 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::poly::PolyRNS; use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings; use math::ring::RingRNS;
use math::ring::{Ring, RingRNS};
fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut() + '_> { fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> {
let a: PolyRNS<u64> = r.new_polyrns(); let a: PolyRNS<u64> = r.new_polyrns();
let mut b: PolyRNS<u64> = r.new_polyrns(); let mut b: PolyRNS<u64> = r.new_polyrns();
let mut c: PolyRNS<u64> = r.new_polyrns(); let mut c: PolyRNS<u64> = r.new_polyrns();
@@ -22,8 +21,8 @@ fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
0x1fffffffffb40001, 0x1fffffffffb40001,
0x1fffffffff500001, 0x1fffffffff500001,
]; ];
let rings: Vec<Ring<u64>> = new_rings(n, moduli);
let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), { let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), {
runner(ring_rns) runner(ring_rns)

View File

@@ -1,11 +1,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::poly::PolyRNS; use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings; use math::ring::RingRNS;
use math::ring::{Ring, RingRNS};
use sampling::source::Source; use sampling::source::Source;
fn fill_uniform(c: &mut Criterion) { fn fill_uniform(c: &mut Criterion) {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut() + '_> { fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> {
let mut a: PolyRNS<u64> = r.new_polyrns(); let mut a: PolyRNS<u64> = r.new_polyrns();
let seed: [u8; 32] = [0; 32]; let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed); let mut source: Source = Source::new(seed);
@@ -25,8 +24,8 @@ fn fill_uniform(c: &mut Criterion) {
0x1fffffffffb40001, 0x1fffffffffb40001,
0x1fffffffff500001, 0x1fffffffff500001,
]; ];
let rings: Vec<Ring<u64>> = new_rings(n, moduli);
let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), { let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), {
runner(ring_rns) runner(ring_rns)

View File

@@ -19,7 +19,7 @@ pub struct Table<O> {
} }
impl Table<u64> { impl Table<u64> {
pub fn new(prime: Prime<u64>, nth_root: u64) -> Self { pub fn new(prime: Prime<u64>, nth_root: u64) -> Table<u64> {
assert!( assert!(
nth_root & (nth_root - 1) == 0, nth_root & (nth_root - 1) == 0,
"invalid argument: nth_root = {} is not a power of two", "invalid argument: nth_root = {} is not a power of two",

View File

@@ -4,6 +4,7 @@ use crate::dft::DFT;
use crate::modulus::prime::Prime; use crate::modulus::prime::Prime;
use crate::poly::{Poly, PolyRNS}; use crate::poly::{Poly, PolyRNS};
use num::traits::Unsigned; use num::traits::Unsigned;
use std::sync::Arc;
pub struct Ring<O: Unsigned> { pub struct Ring<O: Unsigned> {
pub n: usize, pub n: usize,
@@ -21,9 +22,9 @@ impl<O: Unsigned> Ring<O> {
} }
} }
pub struct RingRNS<'a, O: Unsigned>(pub &'a [Ring<O>]); pub struct RingRNS<O: Unsigned>(pub Vec<Arc<Ring<O>>>);
impl<O: Unsigned> RingRNS<'_, O> { impl<O: Unsigned> RingRNS<O> {
pub fn n(&self) -> usize { pub fn n(&self) -> usize {
self.0[0].n() self.0[0].n()
} }
@@ -42,6 +43,6 @@ impl<O: Unsigned> RingRNS<'_, O> {
pub fn at_level(&self, level: usize) -> RingRNS<O> { pub fn at_level(&self, level: usize) -> RingRNS<O> {
assert!(level <= self.0.len()); assert!(level <= self.0.len());
RingRNS(&self.0[..level + 1]) RingRNS(self.0[..level + 1].to_vec())
} }
} }

View File

@@ -6,7 +6,7 @@ use crate::ring::RingRNS;
use crate::scalar::ScalarRNS; use crate::scalar::ScalarRNS;
extern crate test; extern crate test;
impl RingRNS<'_, u64> { impl RingRNS<u64> {
/// Updates b to floor(a / q[b.level()]). /// Updates b to floor(a / q[b.level()]).
pub fn div_floor_by_last_modulus<const NTT: bool>( pub fn div_floor_by_last_modulus<const NTT: bool>(
&self, &self,

View File

@@ -5,19 +5,16 @@ use crate::poly::PolyRNS;
use crate::ring::{Ring, RingRNS}; use crate::ring::{Ring, RingRNS};
use crate::scalar::ScalarRNS; use crate::scalar::ScalarRNS;
use num_bigint::BigInt; use num_bigint::BigInt;
use std::sync::Arc;
pub fn new_rings(n: usize, moduli: Vec<u64>) -> Vec<Ring<u64>> { impl RingRNS<u64> {
pub fn new(n: usize, moduli: Vec<u64>) -> Self {
assert!(!moduli.is_empty(), "moduli cannot be empty"); assert!(!moduli.is_empty(), "moduli cannot be empty");
let rings: Vec<Ring<u64>> = moduli let rings: Vec<Arc<Ring<u64>>> = moduli
.into_iter() .into_iter()
.map(|prime| Ring::new(n, prime, 1)) .map(|prime| Arc::new(Ring::new(n, prime, 1)))
.collect(); .collect();
return rings; return RingRNS(rings);
}
impl<'a> RingRNS<'a, u64> {
pub fn new(rings: &'a [Ring<u64>]) -> Self {
RingRNS(rings)
} }
pub fn modulus(&self) -> BigInt { pub fn modulus(&self) -> BigInt {
@@ -92,7 +89,7 @@ impl<'a> RingRNS<'a, u64> {
} }
} }
impl RingRNS<'_, u64> { impl RingRNS<u64> {
pub fn ntt_inplace<const LAZY: bool>(&self, a: &mut PolyRNS<u64>) { pub fn ntt_inplace<const LAZY: bool>(&self, a: &mut PolyRNS<u64>) {
self.0 self.0
.iter() .iter()
@@ -122,7 +119,7 @@ impl RingRNS<'_, u64> {
} }
} }
impl RingRNS<'_, u64> { impl RingRNS<u64> {
#[inline(always)] #[inline(always)]
pub fn add<const REDUCE: REDUCEMOD>( pub fn add<const REDUCE: REDUCEMOD>(
&self, &self,

View File

@@ -12,7 +12,7 @@ impl Ring<u64> {
} }
} }
impl RingRNS<'_, u64> { impl RingRNS<u64> {
pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS<u64>) { pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS<u64>) {
self.0 self.0
.iter() .iter()

View File

@@ -1,6 +1,5 @@
use math::poly::PolyRNS; use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings; use math::ring::RingRNS;
use math::ring::{Ring, RingRNS};
use num_bigint::BigInt; use num_bigint::BigInt;
use num_bigint::Sign; use num_bigint::Sign;
use sampling::source::Source; use sampling::source::Source;
@@ -9,8 +8,7 @@ use sampling::source::Source;
fn rescaling_rns_u64() { fn rescaling_rns_u64() {
let n = 1 << 10; let n = 1 << 10;
let moduli: Vec<u64> = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64]; let moduli: Vec<u64> = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64];
let rings: Vec<Ring<u64>> = new_rings(n, moduli); let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings);
test_div_floor_by_last_modulus::<false>(&ring_rns); test_div_floor_by_last_modulus::<false>(&ring_rns);
test_div_floor_by_last_modulus::<true>(&ring_rns); test_div_floor_by_last_modulus::<true>(&ring_rns);