refactored RingRNS

This commit is contained in:
Jean-Philippe Bossuat
2025-01-06 14:40:03 +01:00
parent a074886b3e
commit c69bd6985a
8 changed files with 28 additions and 34 deletions

View File

@@ -1,10 +1,9 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings;
use math::ring::{Ring, RingRNS};
use math::ring::RingRNS;
fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut() + '_> {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> {
let a: PolyRNS<u64> = r.new_polyrns();
let mut b: PolyRNS<u64> = r.new_polyrns();
let mut c: PolyRNS<u64> = r.new_polyrns();
@@ -22,8 +21,8 @@ fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
0x1fffffffffb40001,
0x1fffffffff500001,
];
let rings: Vec<Ring<u64>> = new_rings(n, moduli);
let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings);
let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), {
runner(ring_rns)

View File

@@ -1,11 +1,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings;
use math::ring::{Ring, RingRNS};
use math::ring::RingRNS;
use sampling::source::Source;
fn fill_uniform(c: &mut Criterion) {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut() + '_> {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> {
let mut a: PolyRNS<u64> = r.new_polyrns();
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
@@ -25,8 +24,8 @@ fn fill_uniform(c: &mut Criterion) {
0x1fffffffffb40001,
0x1fffffffff500001,
];
let rings: Vec<Ring<u64>> = new_rings(n, moduli);
let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings);
let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), {
runner(ring_rns)

View File

@@ -19,7 +19,7 @@ pub struct Table<O> {
}
impl Table<u64> {
pub fn new(prime: Prime<u64>, nth_root: u64) -> Self {
pub fn new(prime: Prime<u64>, nth_root: u64) -> Table<u64> {
assert!(
nth_root & (nth_root - 1) == 0,
"invalid argument: nth_root = {} is not a power of two",

View File

@@ -4,6 +4,7 @@ use crate::dft::DFT;
use crate::modulus::prime::Prime;
use crate::poly::{Poly, PolyRNS};
use num::traits::Unsigned;
use std::sync::Arc;
pub struct Ring<O: Unsigned> {
pub n: usize,
@@ -21,9 +22,9 @@ impl<O: Unsigned> Ring<O> {
}
}
pub struct RingRNS<'a, O: Unsigned>(pub &'a [Ring<O>]);
pub struct RingRNS<O: Unsigned>(pub Vec<Arc<Ring<O>>>);
impl<O: Unsigned> RingRNS<'_, O> {
impl<O: Unsigned> RingRNS<O> {
pub fn n(&self) -> usize {
self.0[0].n()
}
@@ -42,6 +43,6 @@ impl<O: Unsigned> RingRNS<'_, O> {
pub fn at_level(&self, level: usize) -> RingRNS<O> {
assert!(level <= self.0.len());
RingRNS(&self.0[..level + 1])
RingRNS(self.0[..level + 1].to_vec())
}
}

View File

@@ -6,7 +6,7 @@ use crate::ring::RingRNS;
use crate::scalar::ScalarRNS;
extern crate test;
impl RingRNS<'_, u64> {
impl RingRNS<u64> {
/// Updates b to floor(a / q[b.level()]).
pub fn div_floor_by_last_modulus<const NTT: bool>(
&self,

View File

@@ -5,19 +5,16 @@ use crate::poly::PolyRNS;
use crate::ring::{Ring, RingRNS};
use crate::scalar::ScalarRNS;
use num_bigint::BigInt;
use std::sync::Arc;
pub fn new_rings(n: usize, moduli: Vec<u64>) -> Vec<Ring<u64>> {
impl RingRNS<u64> {
pub fn new(n: usize, moduli: Vec<u64>) -> Self {
assert!(!moduli.is_empty(), "moduli cannot be empty");
let rings: Vec<Ring<u64>> = moduli
let rings: Vec<Arc<Ring<u64>>> = moduli
.into_iter()
.map(|prime| Ring::new(n, prime, 1))
.map(|prime| Arc::new(Ring::new(n, prime, 1)))
.collect();
return rings;
}
impl<'a> RingRNS<'a, u64> {
pub fn new(rings: &'a [Ring<u64>]) -> Self {
RingRNS(rings)
return RingRNS(rings);
}
pub fn modulus(&self) -> BigInt {
@@ -92,7 +89,7 @@ impl<'a> RingRNS<'a, u64> {
}
}
impl RingRNS<'_, u64> {
impl RingRNS<u64> {
pub fn ntt_inplace<const LAZY: bool>(&self, a: &mut PolyRNS<u64>) {
self.0
.iter()
@@ -122,7 +119,7 @@ impl RingRNS<'_, u64> {
}
}
impl RingRNS<'_, u64> {
impl RingRNS<u64> {
#[inline(always)]
pub fn add<const REDUCE: REDUCEMOD>(
&self,

View File

@@ -12,7 +12,7 @@ impl Ring<u64> {
}
}
impl RingRNS<'_, u64> {
impl RingRNS<u64> {
pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS<u64>) {
self.0
.iter()

View File

@@ -1,6 +1,5 @@
use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings;
use math::ring::{Ring, RingRNS};
use math::ring::RingRNS;
use num_bigint::BigInt;
use num_bigint::Sign;
use sampling::source::Source;
@@ -9,8 +8,7 @@ use sampling::source::Source;
fn rescaling_rns_u64() {
let n = 1 << 10;
let moduli: Vec<u64> = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64];
let rings: Vec<Ring<u64>> = new_rings(n, moduli);
let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings);
let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
test_div_floor_by_last_modulus::<false>(&ring_rns);
test_div_floor_by_last_modulus::<true>(&ring_rns);