commit a074886b3e (parent 681268c28e)
Jean-Philippe Bossuat
2025-01-06 14:10:28 +01:00
29 changed files with 1650 additions and 928 deletions

View File

@@ -1,6 +1,6 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use math::{modulus::prime::Prime,dft::ntt::Table};
 use math::dft::DFT;
+use math::{dft::ntt::Table, modulus::prime::Prime};

 fn forward_inplace(c: &mut Criterion) {
     fn runner(prime_instance: Prime<u64>, nth_root: u64) -> Box<dyn FnMut()> {
@@ -9,21 +9,15 @@ fn forward_inplace(c: &mut Criterion) {
         for i in 0..a.len() {
             a[i] = i as u64;
         }
-        Box::new(move || {
-            ntt_table.forward_inplace::<false>(&mut a)
-        })
+        Box::new(move || ntt_table.forward_inplace::<false>(&mut a))
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("forward_inplace");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("forward_inplace");
     for log_nth_root in 11..18 {
         let prime_instance: Prime<u64> = Prime::<u64>::new(0x1fffffffffe00001, 1);
-        let runners = [
-            ("prime", {
-                runner(prime_instance, 1<<log_nth_root)
-            }),
-        ];
+        let runners = [("prime", { runner(prime_instance, 1 << log_nth_root) })];
         for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, 1 << (log_nth_root - 1));
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -38,21 +32,15 @@ fn forward_inplace_lazy(c: &mut Criterion) {
         for i in 0..a.len() {
             a[i] = i as u64;
         }
-        Box::new(move || {
-            ntt_table.forward_inplace_lazy(&mut a)
-        })
+        Box::new(move || ntt_table.forward_inplace_lazy(&mut a))
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("forward_inplace_lazy");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("forward_inplace_lazy");
     for log_nth_root in 11..17 {
         let prime_instance: Prime<u64> = Prime::<u64>::new(0x1fffffffffe00001, 1);
-        let runners = [
-            ("prime", {
-                runner(prime_instance, 1<<log_nth_root)
-            }),
-        ];
+        let runners = [("prime", { runner(prime_instance, 1 << log_nth_root) })];
         for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, 1 << (log_nth_root - 1));
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -67,21 +55,15 @@ fn backward_inplace(c: &mut Criterion) {
         for i in 0..a.len() {
             a[i] = i as u64;
         }
-        Box::new(move || {
-            ntt_table.backward_inplace::<false>(&mut a)
-        })
+        Box::new(move || ntt_table.backward_inplace::<false>(&mut a))
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("backward_inplace");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("backward_inplace");
     for log_nth_root in 11..18 {
         let prime_instance: Prime<u64> = Prime::<u64>::new(0x1fffffffffe00001, 1);
-        let runners = [
-            ("prime", {
-                runner(prime_instance, 1<<log_nth_root)
-            }),
-        ];
+        let runners = [("prime", { runner(prime_instance, 1 << log_nth_root) })];
         for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, 1 << (log_nth_root - 1));
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -96,21 +78,15 @@ fn backward_inplace_lazy(c: &mut Criterion) {
         for i in 0..a.len() {
             a[i] = i as u64;
         }
-        Box::new(move || {
-            ntt_table.backward_inplace::<true>(&mut a)
-        })
+        Box::new(move || ntt_table.backward_inplace::<true>(&mut a))
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("backward_inplace_lazy");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("backward_inplace_lazy");
     for log_nth_root in 11..17 {
         let prime_instance: Prime<u64> = Prime::<u64>::new(0x1fffffffffe00001, 1);
-        let runners = [
-            ("prime", {
-                runner(prime_instance, 1<<log_nth_root)
-            }),
-        ];
+        let runners = [("prime", { runner(prime_instance, 1 << log_nth_root) })];
        for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, 1 << (log_nth_root - 1));
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -118,5 +94,11 @@ fn backward_inplace_lazy(c: &mut Criterion) {
     }
 }
-criterion_group!(benches, forward_inplace, forward_inplace_lazy, backward_inplace, backward_inplace_lazy);
+criterion_group!(
+    benches,
+    forward_inplace,
+    forward_inplace_lazy,
+    backward_inplace,
+    backward_inplace_lazy
+);
 criterion_main!(benches);
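All four benches above share one shape: build the table and input once inside runner, return a boxed closure over them, and hand that closure to b.iter so only the transform itself is timed. A minimal self-contained sketch of that shape, with a trivial stand-in workload in place of the math crate's NTT table (illustrative only, not the crate's code):

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};

// Setup runs once per runner; only the returned closure is measured.
fn runner(n: usize) -> Box<dyn FnMut()> {
    let mut a: Vec<u64> = (0..n as u64).collect();
    Box::new(move || {
        // stand-in for ntt_table.forward_inplace::<false>(&mut a)
        a.iter_mut().for_each(|x| *x = x.wrapping_mul(3).wrapping_add(1));
    })
}

fn bench_shape(c: &mut Criterion) {
    let mut g = c.benchmark_group("shape");
    for log_n in 11..14 {
        let n = 1 << log_n;
        let mut run = runner(n);
        g.bench_with_input(BenchmarkId::new("prime", n), &(), |b, _| b.iter(&mut run));
    }
    g.finish();
}

criterion_group!(benches, bench_shape);
criterion_main!(benches);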

View File

@@ -1,13 +1,12 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use math::ring::Ring;
-use math::modulus::VectorOperations;
 use math::modulus::montgomery::Montgomery;
+use math::modulus::VectorOperations;
 use math::modulus::ONCE;
+use math::ring::Ring;
 use math::CHUNK;

 fn va_add_vb_into_vb(c: &mut Criterion) {
     fn runner(r: Ring<u64>) -> Box<dyn FnMut()> {
         let mut p0: math::poly::Poly<u64> = r.new_poly();
         let mut p1: math::poly::Poly<u64> = r.new_poly();
         for i in 0..p0.n() {
@@ -19,18 +18,14 @@ fn va_add_vb_into_vb(c: &mut Criterion) {
         })
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("va_add_vb_into_vb");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("va_add_vb_into_vb");
     for log_n in 11..17 {
         let n: usize = 1 << log_n as usize;
         let q_base: u64 = 0x1fffffffffe00001u64;
         let q_power: usize = 1usize;
         let r: Ring<u64> = Ring::<u64>::new(n, q_base, q_power);
-        let runners = [
-            ("prime", {
-                runner(r)
-            }),
-        ];
+        let runners = [("prime", { runner(r) })];
         for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, n);
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -40,7 +35,6 @@ fn va_add_vb_into_vb(c: &mut Criterion) {
 fn va_mont_mul_vb_into_vb(c: &mut Criterion) {
     fn runner(r: Ring<u64>) -> Box<dyn FnMut()> {
         let mut p0: math::poly::Poly<Montgomery<u64>> = r.new_poly();
-
         let mut p1: math::poly::Poly<u64> = r.new_poly();
         for i in 0..p0.n() {
@@ -48,22 +42,19 @@ fn va_mont_mul_vb_into_vb(c: &mut Criterion) {
             p1.0[i] = i as u64;
         }
         Box::new(move || {
-            r.modulus.va_mont_mul_vb_into_vb::<CHUNK, ONCE>(&p0.0, &mut p1.0);
+            r.modulus
+                .va_mont_mul_vb_into_vb::<CHUNK, ONCE>(&p0.0, &mut p1.0);
         })
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("va_mont_mul_vb_into_vb");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("va_mont_mul_vb_into_vb");
     for log_n in 11..17 {
         let n: usize = 1 << log_n as usize;
         let q_base: u64 = 0x1fffffffffe00001u64;
         let q_power: usize = 1usize;
         let r: Ring<u64> = Ring::<u64>::new(n, q_base, q_power);
-        let runners = [
-            ("prime", {
-                runner(r)
-            }),
-        ];
+        let runners = [("prime", { runner(r) })];
         for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, n);
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -73,7 +64,6 @@ fn va_mont_mul_vb_into_vb(c: &mut Criterion) {
 fn va_mont_mul_vb_into_vc(c: &mut Criterion) {
     fn runner(r: Ring<u64>) -> Box<dyn FnMut()> {
         let mut p0: math::poly::Poly<Montgomery<u64>> = r.new_poly();
-
         let mut p1: math::poly::Poly<u64> = r.new_poly();
         let mut p2: math::poly::Poly<u64> = r.new_poly();
@@ -82,22 +72,19 @@ fn va_mont_mul_vb_into_vc(c: &mut Criterion) {
             p1.0[i] = i as u64;
         }
         Box::new(move || {
-            r.modulus.va_mont_mul_vb_into_vc::<CHUNK,ONCE>(&p0.0, & p1.0, &mut p2.0);
+            r.modulus
+                .va_mont_mul_vb_into_vc::<CHUNK, ONCE>(&p0.0, &p1.0, &mut p2.0);
         })
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("va_mont_mul_vb_into_vc");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("va_mont_mul_vb_into_vc");
     for log_n in 11..17 {
         let n: usize = 1 << log_n as usize;
         let q_base: u64 = 0x1fffffffffe00001u64;
         let q_power: usize = 1usize;
         let r: Ring<u64> = Ring::<u64>::new(n, q_base, q_power);
-        let runners = [
-            ("prime", {
-                runner(r)
-            }),
-        ];
+        let runners = [("prime", { runner(r) })];
         for (name, mut runner) in runners {
             let id = BenchmarkId::new(name, n);
             b.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
@@ -105,5 +92,10 @@ fn va_mont_mul_vb_into_vc(c: &mut Criterion) {
     }
 }
-criterion_group!(benches, va_add_vb_into_vb, va_mont_mul_vb_into_vb, va_mont_mul_vb_into_vc);
+criterion_group!(
+    benches,
+    va_add_vb_into_vb,
+    va_mont_mul_vb_into_vb,
+    va_mont_mul_vb_into_vc
+);
 criterion_main!(benches);

View File

@@ -1,33 +1,33 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use math::ring::{Ring, RingRNS};
-use math::ring::impl_u64::ring_rns::new_rings;
 use math::poly::PolyRNS;
+use math::ring::impl_u64::ring_rns::new_rings;
+use math::ring::{Ring, RingRNS};

 fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
     fn runner(r: RingRNS<u64>) -> Box<dyn FnMut() + '_> {
         let a: PolyRNS<u64> = r.new_polyrns();
         let mut b: PolyRNS<u64> = r.new_polyrns();
         let mut c: PolyRNS<u64> = r.new_polyrns();
-        Box::new(move || {
-            r.div_floor_by_last_modulus::<true>(&a, &mut b, &mut c)
-        })
+        Box::new(move || r.div_floor_by_last_modulus::<true>(&a, &mut b, &mut c))
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("div_floor_by_last_modulus_ntt_true");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("div_floor_by_last_modulus_ntt_true");
     for log_n in 11..18 {
         let n = 1 << log_n;
-        let moduli: Vec<u64> = vec![0x1fffffffffe00001u64, 0x1fffffffffc80001u64, 0x1fffffffffb40001, 0x1fffffffff500001];
+        let moduli: Vec<u64> = vec![
+            0x1fffffffffe00001u64,
+            0x1fffffffffc80001u64,
+            0x1fffffffffb40001,
+            0x1fffffffff500001,
+        ];
         let rings: Vec<Ring<u64>> = new_rings(n, moduli);
         let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings);
-        let runners = [
-            (format!("prime/n={}/level={}", n, ring_rns.level()), {
-                runner(ring_rns)
-            }),
-        ];
+        let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), {
+            runner(ring_rns)
+        })];
         for (name, mut runner) in runners {
             b.bench_with_input(name, &(), |b, _| b.iter(&mut runner));

View File

@@ -1,12 +1,11 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use math::ring::{Ring, RingRNS};
-use math::ring::impl_u64::ring_rns::new_rings;
 use math::poly::PolyRNS;
+use math::ring::impl_u64::ring_rns::new_rings;
+use math::ring::{Ring, RingRNS};
 use sampling::source::Source;

 fn fill_uniform(c: &mut Criterion) {
     fn runner(r: RingRNS<u64>) -> Box<dyn FnMut() + '_> {
         let mut a: PolyRNS<u64> = r.new_polyrns();
         let seed: [u8; 32] = [0; 32];
         let mut source: Source = Source::new(seed);
@@ -16,19 +15,22 @@ fn fill_uniform(c: &mut Criterion) {
         })
     }
-    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fill_uniform");
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("fill_uniform");
     for log_n in 11..18 {
         let n = 1 << log_n;
-        let moduli: Vec<u64> = vec![0x1fffffffffe00001u64, 0x1fffffffffc80001u64, 0x1fffffffffb40001, 0x1fffffffff500001];
+        let moduli: Vec<u64> = vec![
+            0x1fffffffffe00001u64,
+            0x1fffffffffc80001u64,
+            0x1fffffffffb40001,
+            0x1fffffffff500001,
+        ];
         let rings: Vec<Ring<u64>> = new_rings(n, moduli);
         let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings);
-        let runners = [
-            (format!("prime/n={}/level={}", n, ring_rns.level()), {
-                runner(ring_rns)
-            }),
-        ];
+        let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), {
+            runner(ring_rns)
+        })];
         for (name, mut runner) in runners {
             b.bench_with_input(name, &(), |b, _| b.iter(&mut runner));

View File

@@ -1,6 +1,6 @@
-use math::ring::Ring;
-use math::modulus::prime::Prime;
 use math::dft::ntt::Table;
+use math::modulus::prime::Prime;
+use math::ring::Ring;

 fn main() {
     // Example usage of `Prime<u64>`
@@ -47,5 +47,4 @@ fn main() {
     r.automorphism(p0, (2 * r.n - 1) as u64, &mut p1);
     println!("{:?}", p1);
 }
-

View File

@@ -1,10 +1,10 @@
-use crate::modulus::montgomery::Montgomery;
+use crate::dft::DFT;
 use crate::modulus::barrett::Barrett;
+use crate::modulus::montgomery::Montgomery;
 use crate::modulus::prime::Prime;
 use crate::modulus::ReduceOnce;
 use crate::modulus::WordOps;
-use crate::modulus::{NONE, ONCE, BARRETT};
-use crate::dft::DFT;
+use crate::modulus::{BARRETT, NONE, ONCE};
 use itertools::izip;

 #[allow(dead_code)]
@@ -20,8 +20,11 @@ pub struct Table<O> {
 impl Table<u64> {
     pub fn new(prime: Prime<u64>, nth_root: u64) -> Self {
-        assert!(nth_root&(nth_root-1) == 0, "invalid argument: nth_root = {} is not a power of two", nth_root);
+        assert!(
+            nth_root & (nth_root - 1) == 0,
+            "invalid argument: nth_root = {} is not a power of two",
+            nth_root
+        );

         let psi: u64 = prime.primitive_nth_root(nth_root);
@@ -40,11 +43,14 @@ impl Table<u64> {
         let mut powers_backward: u64 = 1u64;
         for i in 1..(nth_root >> 1) as usize {
             let i_rev: usize = i.reverse_bits_msb(log_nth_root_half);
-            prime.montgomery.mul_external_assign::<ONCE>(psi_mont, &mut powers_forward);
-            prime.montgomery.mul_external_assign::<ONCE>(psi_inv_mont, &mut powers_backward);
+            prime
+                .montgomery
+                .mul_external_assign::<ONCE>(psi_mont, &mut powers_forward);
+            prime
+                .montgomery
+                .mul_external_assign::<ONCE>(psi_inv_mont, &mut powers_backward);
             psi_forward_rev[i_rev] = prime.barrett.prepare(powers_forward);
             psi_backward_rev[i_rev] = prime.barrett.prepare(powers_backward);
@@ -64,9 +70,7 @@ impl Table<u64> {
     }
 }

-
 impl DFT<u64> for Table<u64> {
-
     fn forward_inplace(&self, a: &mut [u64]) {
         self.forward_inplace::<false>(a)
     }
@@ -85,15 +89,20 @@ impl DFT<u64> for Table<u64> {
 }

 impl Table<u64> {
     pub fn forward_inplace<const LAZY: bool>(&self, a: &mut [u64]) {
         self.forward_inplace_core::<LAZY, 0, 0>(a);
     }

-    pub fn forward_inplace_core<const LAZY: bool, const SKIPSTART: u8, const SKIPEND: u8>(&self, a: &mut [u64]) {
+    pub fn forward_inplace_core<const LAZY: bool, const SKIPSTART: u8, const SKIPEND: u8>(
+        &self,
+        a: &mut [u64],
+    ) {
         let n: usize = a.len();
-        assert!(n & n-1 == 0, "invalid x.len()= {} must be a power of two", n);
+        assert!(
+            n & n - 1 == 0,
+            "invalid x.len()= {} must be a power of two",
+            n
+        );
         let log_n: u32 = usize::BITS - ((n as usize) - 1).leading_zeros();

         let start: u32 = SKIPSTART as u32;
@@ -104,23 +113,46 @@ impl Table<u64> {
             let t: usize = 2 * size;
             if layer == log_n - 1 {
                 if LAZY {
-                    izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| {
-                        let (a, b) = a.split_at_mut(size);
-                        self.dit_inplace::<false>(&mut a[0], &mut b[0], *psi);
-                        debug_assert!(a[0] < self.two_q, "forward_inplace_core::<LAZY=true> output {} > {} (2q-1)", a[0], self.two_q-1);
-                        debug_assert!(b[0] < self.two_q, "forward_inplace_core::<LAZY=true> output {} > {} (2q-1)", b[0], self.two_q-1);
-                    });
+                    izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(
+                        |(a, psi)| {
+                            let (a, b) = a.split_at_mut(size);
+                            self.dit_inplace::<false>(&mut a[0], &mut b[0], *psi);
+                            debug_assert!(
+                                a[0] < self.two_q,
+                                "forward_inplace_core::<LAZY=true> output {} > {} (2q-1)",
+                                a[0],
+                                self.two_q - 1
+                            );
+                            debug_assert!(
+                                b[0] < self.two_q,
+                                "forward_inplace_core::<LAZY=true> output {} > {} (2q-1)",
+                                b[0],
+                                self.two_q - 1
+                            );
+                        },
+                    );
                 } else {
-                    izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| {
-                        let (a, b) = a.split_at_mut(size);
-                        self.dit_inplace::<true>(&mut a[0], &mut b[0], *psi);
-                        self.prime.barrett.reduce_assign::<BARRETT>(&mut a[0]);
-                        self.prime.barrett.reduce_assign::<BARRETT>(&mut b[0]);
-                        debug_assert!(a[0] < self.q, "forward_inplace_core::<LAZY=false> output {} > {} (q-1)", a[0], self.q-1);
-                        debug_assert!(b[0] < self.q, "forward_inplace_core::<LAZY=false> output {} > {} (q-1)", b[0], self.q-1);
-                    });
+                    izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(
+                        |(a, psi)| {
+                            let (a, b) = a.split_at_mut(size);
+                            self.dit_inplace::<true>(&mut a[0], &mut b[0], *psi);
+                            self.prime.barrett.reduce_assign::<BARRETT>(&mut a[0]);
+                            self.prime.barrett.reduce_assign::<BARRETT>(&mut b[0]);
+                            debug_assert!(
+                                a[0] < self.q,
+                                "forward_inplace_core::<LAZY=false> output {} > {} (q-1)",
+                                a[0],
+                                self.q - 1
+                            );
+                            debug_assert!(
+                                b[0] < self.q,
+                                "forward_inplace_core::<LAZY=false> output {} > {} (q-1)",
+                                b[0],
+                                self.q - 1
+                            );
+                        },
+                    );
                 }
             } else if t >= 16 {
                 izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| {
                     let (a, b) = a.split_at_mut(size);
@@ -162,9 +194,16 @@ impl Table<u64> {
         self.backward_inplace_core::<LAZY, 0, 0>(a);
     }

-    pub fn backward_inplace_core<const LAZY:bool, const SKIPSTART: u8, const SKIPEND: u8>(&self, a: &mut [u64]) {
+    pub fn backward_inplace_core<const LAZY: bool, const SKIPSTART: u8, const SKIPEND: u8>(
+        &self,
+        a: &mut [u64],
+    ) {
         let n: usize = a.len();
-        assert!(n & n-1 == 0, "invalid x.len()= {} must be a power of two", n);
+        assert!(
+            n & n - 1 == 0,
+            "invalid x.len()= {} must be a power of two",
+            n
+        );
         let log_n = usize::BITS - ((n as usize) - 1).leading_zeros();

         let start: u32 = SKIPEND as u32;
@@ -174,12 +213,14 @@ impl Table<u64> {
             let (m, size) = (1 << layer, 1 << (log_n - layer - 1));
             let t: usize = 2 * size;
             if layer == 0 {
                 let n_inv: Barrett<u64> = self.prime.barrett.prepare(self.prime.inv(n as u64));
-                let psi: Barrett<u64> = self.prime.barrett.prepare(self.prime.barrett.mul_external::<ONCE>(n_inv, self.psi_backward_rev[1].0));
-                izip!(a.chunks_exact_mut(2 * size)).for_each(
-                    |a| {
+                let psi: Barrett<u64> = self.prime.barrett.prepare(
+                    self.prime
+                        .barrett
+                        .mul_external::<ONCE>(n_inv, self.psi_backward_rev[1].0),
+                );
+                izip!(a.chunks_exact_mut(2 * size)).for_each(|a| {
                     let (a, b) = a.split_at_mut(size);
                     izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| {
                         self.dif_last_inplace::<LAZY>(&mut a[0], &mut b[0], psi, n_inv);
@@ -191,12 +232,9 @@ impl Table<u64> {
                         self.dif_last_inplace::<LAZY>(&mut a[6], &mut b[6], psi, n_inv);
                         self.dif_last_inplace::<LAZY>(&mut a[7], &mut b[7], psi, n_inv);
                     });
-                    },
-                );
+                });
             } else if t >= 16 {
-                izip!(a.chunks_exact_mut(t), &self.psi_backward_rev[m..]).for_each(
-                    |(a, psi)| {
+                izip!(a.chunks_exact_mut(t), &self.psi_backward_rev[m..]).for_each(|(a, psi)| {
                     let (a, b) = a.split_at_mut(size);
                     izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| {
                         self.dif_inplace::<true>(&mut a[0], &mut b[0], *psi);
@@ -208,8 +246,7 @@ impl Table<u64> {
                         self.dif_inplace::<true>(&mut a[6], &mut b[6], *psi);
                         self.dif_inplace::<true>(&mut a[7], &mut b[7], *psi);
                     });
-                    },
-                );
+                });
             } else {
                 izip!(a.chunks_exact_mut(2 * size), &self.psi_backward_rev[m..]).for_each(
                     |(a, psi)| {
@@ -225,7 +262,10 @@ impl Table<u64> {
     fn dif_inplace<const LAZY: bool>(&self, a: &mut u64, b: &mut u64, t: Barrett<u64>) {
         debug_assert!(*a < self.two_q, "a:{} q:{}", a, self.two_q);
         debug_assert!(*b < self.two_q, "b:{} q:{}", b, self.two_q);
-        let d: u64 = self.prime.barrett.mul_external::<NONE>(t, *a + self.two_q - *b);
+        let d: u64 = self
+            .prime
+            .barrett
+            .mul_external::<NONE>(t, *a + self.two_q - *b);
         *a = *a + *b;
         a.reduce_once_assign(self.two_q);
         *b = d;
@@ -235,15 +275,27 @@ impl Table<u64> {
         }
     }

-    fn dif_last_inplace<const LAZY:bool>(&self, a: &mut u64, b: &mut u64, psi: Barrett<u64>, n_inv: Barrett<u64>){
+    fn dif_last_inplace<const LAZY: bool>(
+        &self,
+        a: &mut u64,
+        b: &mut u64,
+        psi: Barrett<u64>,
+        n_inv: Barrett<u64>,
+    ) {
         debug_assert!(*a < self.two_q);
         debug_assert!(*b < self.two_q);
         if LAZY {
-            let d: u64 = self.prime.barrett.mul_external::<NONE>(psi, *a + self.two_q - *b);
+            let d: u64 = self
+                .prime
+                .barrett
+                .mul_external::<NONE>(psi, *a + self.two_q - *b);
             *a = self.prime.barrett.mul_external::<NONE>(n_inv, *a + *b);
             *b = d;
         } else {
-            let d: u64 = self.prime.barrett.mul_external::<ONCE>(psi, *a + self.two_q - *b);
+            let d: u64 = self
+                .prime
+                .barrett
+                .mul_external::<ONCE>(psi, *a + self.two_q - *b);
             *a = self.prime.barrett.mul_external::<ONCE>(n_inv, *a + *b);
             *b = d;
         }
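For reference while reading the reflowed butterflies: dif_inplace above is a Gentleman-Sande butterfly over inputs kept lazily in [0, 2q). A condensed restatement of exactly its steps, with the Barrett multiply abstracted behind a closure (a reading aid, not the crate's code):

// Inputs a, b in [0, 2q); computes (a, b) <- (a + b mod 2q, psi * (a + 2q - b)).
// `mul_psi` stands in for self.prime.barrett.mul_external::<NONE>(t, ..).
fn dif(a: &mut u64, b: &mut u64, two_q: u64, mul_psi: impl Fn(u64) -> u64) {
    let d = mul_psi(*a + two_q - *b); // adding 2q keeps the subtraction from underflowing
    *a += *b;
    if *a >= two_q {
        *a -= two_q; // reduce_once_assign(self.two_q)
    }
    *b = d;
}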

View File

@@ -1,10 +1,10 @@
 #![feature(bigint_helper_methods)]
 #![feature(test)]

-pub mod modulus;
 pub mod dft;
-pub mod ring;
+pub mod modulus;
 pub mod poly;
+pub mod ring;
 pub mod scalar;

 pub const CHUNK: usize = 8;
@@ -13,9 +13,7 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_v {
         ($self:expr, $f:expr, $a:expr, $CHUNK:expr) => {
-
-
             match CHUNK {
                 8 => {
                     $a.chunks_exact_mut(8).for_each(|a| {
@@ -34,7 +32,7 @@ pub mod macros {
                     $a[m..].iter_mut().for_each(|a| {
                         $f(&$self, a);
                     });
-                },
+                }
                 _ => {
                     $a.iter_mut().for_each(|a| {
                         $f(&$self, a);
@@ -46,16 +44,21 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_vv {
         ($self:expr, $f:expr, $a:expr, $b:expr, $CHUNK:expr) => {
             let n: usize = $a.len();
-            debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n);
-            debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two");
+            debug_assert!(
+                $b.len() == n,
+                "invalid argument b: b.len() = {} != a.len() = {}",
+                $b.len(),
+                n
+            );
+            debug_assert!(
+                CHUNK & (CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
             match CHUNK {
                 8 => {
                     izip!($a.chunks_exact(8), $b.chunks_exact_mut(8)).for_each(|(a, b)| {
                         $f(&$self, &a[0], &mut b[0]);
                         $f(&$self, &a[1], &mut b[1]);
@@ -71,7 +74,7 @@ pub mod macros {
                     izip!($a[m..].iter(), $b[m..].iter_mut()).for_each(|(a, b)| {
                         $f(&$self, a, b);
                     });
-                },
+                }
                 _ => {
                     izip!($a.iter(), $b.iter_mut()).for_each(|(a, b)| {
                         $f(&$self, a, b);
@@ -83,18 +86,33 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_vvv {
         ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $CHUNK:expr) => {
             let n: usize = $a.len();
-            debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n);
-            debug_assert!($c.len() == n, "invalid argument c: c.len() = {} != a.len() = {}", $c.len(), n);
-            debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two");
+            debug_assert!(
+                $b.len() == n,
+                "invalid argument b: b.len() = {} != a.len() = {}",
+                $b.len(),
+                n
+            );
+            debug_assert!(
+                $c.len() == n,
+                "invalid argument c: c.len() = {} != a.len() = {}",
+                $c.len(),
+                n
+            );
+            debug_assert!(
+                CHUNK & (CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
             match CHUNK {
                 8 => {
-                    izip!($a.chunks_exact(8), $b.chunks_exact(8), $c.chunks_exact_mut(8)).for_each(|(a, b, c)| {
+                    izip!(
+                        $a.chunks_exact(8),
+                        $b.chunks_exact(8),
+                        $c.chunks_exact_mut(8)
+                    )
+                    .for_each(|(a, b, c)| {
                         $f(&$self, &a[0], &b[0], &mut c[0]);
                         $f(&$self, &a[1], &b[1], &mut c[1]);
                         $f(&$self, &a[2], &b[2], &mut c[2]);
@@ -106,10 +124,12 @@ pub mod macros {
                     });
                     let m = n - (n & 7);
-                    izip!($a[m..].iter(), $b[m..].iter(), $c[m..].iter_mut()).for_each(|(a, b, c)| {
-                        $f(&$self, a, b, c);
-                    });
-                },
+                    izip!($a[m..].iter(), $b[m..].iter(), $c[m..].iter_mut()).for_each(
+                        |(a, b, c)| {
+                            $f(&$self, a, b, c);
+                        },
+                    );
+                }
                 _ => {
                     izip!($a.iter(), $b.iter(), $c.iter_mut()).for_each(|(a, b, c)| {
                         $f(&$self, a, b, c);
@@ -121,16 +141,16 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_sv {
         ($self:expr, $f:expr, $a:expr, $b:expr, $CHUNK:expr) => {
             let n: usize = $b.len();
-            debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two");
+            debug_assert!(
+                CHUNK & (CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
             match CHUNK {
                 8 => {
                     izip!($b.chunks_exact_mut(8)).for_each(|b| {
                         $f(&$self, $a, &mut b[0]);
                         $f(&$self, $a, &mut b[1]);
@@ -146,7 +166,7 @@ pub mod macros {
                     izip!($b[m..].iter_mut()).for_each(|b| {
                         $f(&$self, $a, b);
                     });
-                },
+                }
                 _ => {
                     izip!($b.iter_mut()).for_each(|b| {
                         $f(&$self, $a, b);
@@ -158,16 +178,21 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_svv {
         ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $CHUNK:expr) => {
             let n: usize = $b.len();
-            debug_assert!($c.len() == n, "invalid argument c: c.len() = {} != b.len() = {}", $c.len(), n);
-            debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two");
+            debug_assert!(
+                $c.len() == n,
+                "invalid argument c: c.len() = {} != b.len() = {}",
+                $c.len(),
+                n
+            );
+            debug_assert!(
+                CHUNK & (CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
             match CHUNK {
                 8 => {
                     izip!($b.chunks_exact(8), $c.chunks_exact_mut(8)).for_each(|(b, c)| {
                         $f(&$self, $a, &b[0], &mut c[0]);
                         $f(&$self, $a, &b[1], &mut c[1]);
@@ -183,7 +208,7 @@ pub mod macros {
                     izip!($b[m..].iter(), $c[m..].iter_mut()).for_each(|(b, c)| {
                         $f(&$self, $a, b, c);
                     });
-                },
+                }
                 _ => {
                     izip!($b.iter(), $c.iter_mut()).for_each(|(b, c)| {
                         $f(&$self, $a, b, c);
@@ -195,18 +220,33 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_vvsv {
         ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $d:expr, $CHUNK:expr) => {
             let n: usize = $a.len();
-            debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n);
-            debug_assert!($d.len() == n, "invalid argument d: d.len() = {} != a.len() = {}", $d.len(), n);
-            debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two");
+            debug_assert!(
+                $b.len() == n,
+                "invalid argument b: b.len() = {} != a.len() = {}",
+                $b.len(),
+                n
+            );
+            debug_assert!(
+                $d.len() == n,
+                "invalid argument d: d.len() = {} != a.len() = {}",
+                $d.len(),
+                n
+            );
+            debug_assert!(
+                CHUNK & (CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
             match CHUNK {
                 8 => {
-                    izip!($a.chunks_exact(8), $b.chunks_exact(8), $d.chunks_exact_mut(8)).for_each(|(a, b, d)| {
+                    izip!(
+                        $a.chunks_exact(8),
+                        $b.chunks_exact(8),
+                        $d.chunks_exact_mut(8)
+                    )
+                    .for_each(|(a, b, d)| {
                         $f(&$self, &a[0], &b[0], $c, &mut d[0]);
                         $f(&$self, &a[1], &b[1], $c, &mut d[1]);
                         $f(&$self, &a[2], &b[2], $c, &mut d[2]);
@@ -218,10 +258,12 @@ pub mod macros {
                     });
                     let m = n - (n & 7);
-                    izip!($a[m..].iter(), $b[m..].iter(), $d[m..].iter_mut()).for_each(|(a, b, d)| {
-                        $f(&$self, a, b, $c, d);
-                    });
-                },
+                    izip!($a[m..].iter(), $b[m..].iter(), $d[m..].iter_mut()).for_each(
+                        |(a, b, d)| {
+                            $f(&$self, a, b, $c, d);
+                        },
+                    );
+                }
                 _ => {
                     izip!($a.iter(), $b.iter(), $d.iter_mut()).for_each(|(a, b, d)| {
                         $f(&$self, a, b, $c, d);
@@ -233,16 +275,21 @@ pub mod macros {
     #[macro_export]
     macro_rules! apply_vsv {
         ($self:expr, $f:expr, $a:expr, $c:expr, $b:expr, $CHUNK:expr) => {
             let n: usize = $a.len();
-            debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n);
-            debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two");
+            debug_assert!(
+                $b.len() == n,
+                "invalid argument b: b.len() = {} != a.len() = {}",
+                $b.len(),
+                n
+            );
+            debug_assert!(
+                CHUNK & (CHUNK - 1) == 0,
+                "invalid CHUNK const: not a power of two"
+            );
             match CHUNK {
                 8 => {
                     izip!($a.chunks_exact(8), $b.chunks_exact_mut(8)).for_each(|(a, b)| {
                         $f(&$self, &a[0], $c, &mut b[0]);
                         $f(&$self, &a[1], $c, &mut b[1]);
@@ -258,7 +305,7 @@ pub mod macros {
                     izip!($a[m..].iter(), $b[m..].iter_mut()).for_each(|(a, b)| {
                         $f(&$self, a, $c, b);
                     });
-                },
+                }
                 _ => {
                     izip!($a.iter(), $b.iter_mut()).for_each(|(a, b)| {
                         $f(&$self, a, $c, b);
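All the apply_* macros above expand to the same skeleton: an unrolled pass over chunks_exact(8), then an element-wise sweep of the n & 7 tail. The skeleton as an ordinary generic function, for readability (illustrative only; the macros write the eight calls out by hand rather than looping):

// Shape shared by apply_vv! and friends: unrolled chunks of 8, then the tail.
fn apply_vv<F: Fn(&u64, &mut u64)>(a: &[u64], b: &mut [u64], f: F) {
    let n = a.len();
    debug_assert!(b.len() == n);
    a.chunks_exact(8)
        .zip(b.chunks_exact_mut(8))
        .for_each(|(a, b)| {
            for i in 0..8 {
                f(&a[i], &mut b[i]); // the macros spell out these eight calls explicitly
            }
        });
    let m = n - (n & 7);
    a[m..].iter().zip(b[m..].iter_mut()).for_each(|(a, b)| f(a, b));
}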

View File

@@ -1,7 +1,7 @@
-pub mod prime;
 pub mod barrett;
 pub mod impl_u64;
 pub mod montgomery;
+pub mod prime;

 pub type REDUCEMOD = u8;
@@ -64,7 +64,6 @@ pub trait ReduceOnce<O> {
 }

 pub trait ScalarOperations<O> {
     // Applies a parameterized modular reduction.
     fn sa_reduce_into_sa<const REDUCE: REDUCEMOD>(&self, x: &mut O);
@@ -87,41 +86,83 @@ pub trait ScalarOperations<O> {
     fn sa_neg_into_sb<const REDUCE: REDUCEMOD>(&self, a: &O, b: &mut O);

     // Assigns a * 2^64 to b.
-    fn sa_prep_mont_into_sb<const REDUCE:REDUCEMOD>(&self, a: &O, b: &mut montgomery::Montgomery<O>);
+    fn sa_prep_mont_into_sb<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &O,
+        b: &mut montgomery::Montgomery<O>,
+    );

     // Assigns a * b to c.
-    fn sa_mont_mul_sb_into_sc<const REDUCE:REDUCEMOD>(&self, a:&montgomery::Montgomery<O>, b:&O, c: &mut O);
+    fn sa_mont_mul_sb_into_sc<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &montgomery::Montgomery<O>,
+        b: &O,
+        c: &mut O,
+    );

     // Assigns a * b to b.
-    fn sa_mont_mul_sb_into_sb<const REDUCE:REDUCEMOD>(&self, a:&montgomery::Montgomery<O>, b:&mut O);
+    fn sa_mont_mul_sb_into_sb<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &montgomery::Montgomery<O>,
+        b: &mut O,
+    );

     // Assigns a * b to c.
-    fn sa_barrett_mul_sb_into_sc<const REDUCE:REDUCEMOD>(&self, a: &barrett::Barrett<O>, b:&O, c: &mut O);
+    fn sa_barrett_mul_sb_into_sc<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &barrett::Barrett<O>,
+        b: &O,
+        c: &mut O,
+    );

     // Assigns a * b to b.
-    fn sa_barrett_mul_sb_into_sb<const REDUCE:REDUCEMOD>(&self, a:&barrett::Barrett<O>, b:&mut O);
+    fn sa_barrett_mul_sb_into_sb<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &barrett::Barrett<O>,
+        b: &mut O,
+    );

     // Assigns (a + 2q - b) * c to d.
-    fn sa_sub_sb_mul_sc_into_sd<const REDUCE:REDUCEMOD>(&self, a: &O, b: &O, c: &barrett::Barrett<O>, d: &mut O);
+    fn sa_sub_sb_mul_sc_into_sd<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &O,
+        b: &O,
+        c: &barrett::Barrett<O>,
+        d: &mut O,
+    );

     // Assigns (a + 2q - b) * c to b.
-    fn sa_sub_sb_mul_sc_into_sb<const REDUCE:REDUCEMOD>(&self, a: &u64, c: &barrett::Barrett<u64>, b: &mut u64);
+    fn sa_sub_sb_mul_sc_into_sb<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &u64,
+        c: &barrett::Barrett<u64>,
+        b: &mut u64,
+    );
 }

 pub trait VectorOperations<O> {
     // Applies a parameterized modular reduction.
     fn va_reduce_into_va<const CHUNK: usize, const REDUCE: REDUCEMOD>(&self, x: &mut [O]);

     // ADD
     // Assigns a[i] + b[i] to c[i]
-    fn va_add_vb_into_vc<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[O], b:&[O], c: &mut [O]);
+    fn va_add_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[O],
+        b: &[O],
+        c: &mut [O],
+    );

     // Assigns a[i] + b[i] to b[i]
     fn va_add_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(&self, a: &[O], b: &mut [O]);

     // Assigns a[i] + b to c[i]
-    fn va_add_sb_into_vc<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[O], b:&O, c:&mut [O]);
+    fn va_add_sb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[O],
+        b: &O,
+        c: &mut [O],
+    );

     // Assigns b[i] + a to b[i]
     fn sa_add_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(&self, a: &O, b: &mut [O]);
@@ -131,7 +172,12 @@ pub trait VectorOperations<O> {
     fn va_sub_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(&self, a: &[O], b: &mut [O]);

     // Assigns a[i] - b[i] to c[i]
-    fn va_sub_vb_into_vc<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[O], b:&[O], c: &mut [O]);
+    fn va_sub_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[O],
+        b: &[O],
+        c: &mut [O],
+    );

     // NEG
     // Assigns -a[i] to a[i].
@@ -142,29 +188,58 @@ pub trait VectorOperations<O> {
     // MUL MONTGOMERY
     // Assigns a * 2^64 to b.
-    fn va_prep_mont_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[O], b: &mut [montgomery::Montgomery<O>]);
+    fn va_prep_mont_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[O],
+        b: &mut [montgomery::Montgomery<O>],
+    );

     // Assigns a[i] * b[i] to c[i].
-    fn va_mont_mul_vb_into_vc<const CHUNK:usize,const REDUCE:REDUCEMOD>(&self, a:&[montgomery::Montgomery<O>], b:&[O], c: &mut [O]);
+    fn va_mont_mul_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[montgomery::Montgomery<O>],
+        b: &[O],
+        c: &mut [O],
+    );

     // Assigns a[i] * b[i] to b[i].
-    fn va_mont_mul_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a:&[montgomery::Montgomery<O>], b:&mut [O]);
+    fn va_mont_mul_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[montgomery::Montgomery<O>],
+        b: &mut [O],
+    );

     // MUL BARRETT
     // Assigns a * b[i] to b[i].
-    fn sa_barrett_mul_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a:& barrett::Barrett<u64>, b:&mut [u64]);
+    fn sa_barrett_mul_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &barrett::Barrett<u64>,
+        b: &mut [u64],
+    );

     // Assigns a * b[i] to c[i].
-    fn sa_barrett_mul_vb_into_vc<const CHUNK:usize,const REDUCE:REDUCEMOD>(&self, a:& barrett::Barrett<u64>, b:&[u64], c: &mut [u64]);
+    fn sa_barrett_mul_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &barrett::Barrett<u64>,
+        b: &[u64],
+        c: &mut [u64],
+    );

     // OTHERS
     // Assigns (a[i] + 2q - b[i]) * c to d[i].
-    fn va_sub_vb_mul_sc_into_vd<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b: &[u64], c: &barrett::Barrett<u64>, d: &mut [u64]);
+    fn va_sub_vb_mul_sc_into_vd<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &[u64],
+        c: &barrett::Barrett<u64>,
+        d: &mut [u64],
+    );

     // Assigns (a[i] + 2q - b[i]) * c to b[i].
-    fn va_sub_vb_mul_sc_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], c: &barrett::Barrett<u64>, b: &mut [u64]);
+    fn va_sub_vb_mul_sc_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        c: &barrett::Barrett<u64>,
+        b: &mut [u64],
+    );
 }
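A reading aid for the signatures being reflowed above, inferred from the trait's own comments: the s/v prefixes mark scalar versus vector (slice) operands, the letters name the operands in order, and the into_x suffix names the destination. So va_sub_vb_mul_sc_into_vd assigns (a[i] + 2q - b[i]) * c to d[i], with the REDUCE const parameter selecting which modular reduction is applied to the result.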

View File

@@ -2,7 +2,6 @@
 pub struct Barrett<O>(pub O, pub O);

 impl<O> Barrett<O> {
-
     #[inline(always)]
     pub fn value(&self) -> &O {
         &self.0
@@ -25,7 +24,6 @@ pub struct BarrettPrecomp<O> {
 }

 impl<O> BarrettPrecomp<O> {
-
     #[inline(always)]
     pub fn value_hi(&self) -> &O {
         &self.hi
@@ -36,4 +34,3 @@ impl<O> BarrettPrecomp<O> {
         &self.lo
     }
 }
-

View File

@@ -1,17 +1,24 @@
 use crate::modulus::barrett::{Barrett, BarrettPrecomp};
 use crate::modulus::ReduceOnce;
-use crate::modulus::{REDUCEMOD, NONE, ONCE, TWICE, FOURTIMES, BARRETT, BARRETTLAZY};
+use crate::modulus::{BARRETT, BARRETTLAZY, FOURTIMES, NONE, ONCE, REDUCEMOD, TWICE};
 use num_bigint::BigUint;
 use num_traits::cast::ToPrimitive;

 impl BarrettPrecomp<u64> {
     pub fn new(q: u64) -> BarrettPrecomp<u64> {
-        let big_r: BigUint = (BigUint::from(1 as usize)<<((u64::BITS<<1) as usize)) / BigUint::from(q);
+        let big_r: BigUint =
+            (BigUint::from(1 as usize) << ((u64::BITS << 1) as usize)) / BigUint::from(q);
         let lo: u64 = (&big_r & BigUint::from(u64::MAX)).to_u64().unwrap();
         let hi: u64 = (big_r >> u64::BITS).to_u64().unwrap();
-        let mut precomp: BarrettPrecomp<u64> = Self{q:q, two_q:q<<1, four_q:q<<2, lo:lo, hi:hi, one:Barrett(0,0)};
+        let mut precomp: BarrettPrecomp<u64> = Self {
+            q: q,
+            two_q: q << 1,
+            four_q: q << 2,
+            lo: lo,
+            hi: hi,
+            one: Barrett(0, 0),
+        };
         precomp.one = precomp.prepare(1);
         precomp
     }
@@ -24,20 +31,20 @@ impl BarrettPrecomp<u64> {
     #[inline(always)]
     pub fn reduce_assign<const REDUCE: REDUCEMOD>(&self, x: &mut u64) {
         match REDUCE {
-            NONE =>{},
-            ONCE =>{x.reduce_once_assign(self.q)},
-            TWICE=>{x.reduce_once_assign(self.two_q)},
-            FOURTIMES =>{x.reduce_once_assign(self.four_q)},
+            NONE => {}
+            ONCE => x.reduce_once_assign(self.q),
+            TWICE => x.reduce_once_assign(self.two_q),
+            FOURTIMES => x.reduce_once_assign(self.four_q),
             BARRETT => {
                 let (_, mhi) = x.widening_mul(self.hi);
                 *x = *x - mhi.wrapping_mul(self.q);
                 x.reduce_once_assign(self.q);
-            },
+            }
             BARRETTLAZY => {
                 let (_, mhi) = x.widening_mul(self.hi);
                 *x = *x - mhi.wrapping_mul(self.q)
-            },
-            _ => unreachable!("invalid REDUCE argument")
+            }
+            _ => unreachable!("invalid REDUCE argument"),
         }
     }
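The BARRETT arm above is the classic high-word quotient estimate: with hi the top 64 bits of floor(2^128 / q), the estimate floor(x * hi / 2^64) undershoots floor(x / q) by at most one, so a single conditional subtraction finishes the reduction. A standalone model in plain u128 arithmetic (assumes odd q < 2^62 as used throughout; a sketch, not the crate's code):

// Model of reduce_assign::<BARRETT>: quotient estimate from the high word.
fn barrett_reduce(x: u64, q: u64) -> u64 {
    // Top 64 bits of floor(2^128 / q); for odd q > 1 this equals
    // floor((2^128 - 1) / q) >> 64, which avoids a 129-bit value.
    let hi = ((u128::MAX / q as u128) >> 64) as u64;
    // High word of x * hi, i.e. an estimate of x / q (models widening_mul).
    let mhi = ((x as u128 * hi as u128) >> 64) as u64;
    let r = x.wrapping_sub(mhi.wrapping_mul(q));
    // The estimate is short by at most one q, so one subtraction suffices.
    if r >= q { r - q } else { r }
}

fn main() {
    let q = 0x1fffffffffe00001u64;
    let x = 0x5f876e514845cc8bu64;
    assert_eq!(barrett_reduce(x, q), x % q);
}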

View File

@@ -1,7 +1,7 @@
-pub mod prime;
 pub mod barrett;
 pub mod montgomery;
 pub mod operations;
+pub mod prime;

 use crate::modulus::ReduceOnce;

View File

@@ -1,19 +1,21 @@
-use crate::modulus::ReduceOnce;
-use crate::modulus::montgomery::{MontgomeryPrecomp, Montgomery};
 use crate::modulus::barrett::BarrettPrecomp;
-use crate::modulus::{REDUCEMOD, ONCE};
+use crate::modulus::montgomery::{Montgomery, MontgomeryPrecomp};
+use crate::modulus::ReduceOnce;
+use crate::modulus::{ONCE, REDUCEMOD};

 extern crate test;

 /// MontgomeryPrecomp is a set of methods implemented for MontgomeryPrecomp<u64>
 /// enabling Montgomery arithmetic over u64 values.
 impl MontgomeryPrecomp<u64> {
     /// Returns a new instance of MontgomeryPrecomp<u64>.
     /// This method will fail if gcd(q, 2^64) != 1.
     #[inline(always)]
     pub fn new(q: u64) -> MontgomeryPrecomp<u64> {
-        assert!(q & 1 != 0, "Invalid argument: gcd(q={}, radix=2^64) != 1", q);
+        assert!(
+            q & 1 != 0,
+            "Invalid argument: gcd(q={}, radix=2^64) != 1",
+            q
+        );
         let mut q_inv: u64 = 1;
         let mut q_pow = q;
         for _i in 0..63 {
@@ -80,7 +82,9 @@ impl MontgomeryPrecomp<u64> {
     #[inline(always)]
     pub fn prepare_assign<const REDUCE: REDUCEMOD>(&self, lhs: u64, rhs: &mut Montgomery<u64>) {
         let (_, mhi) = lhs.widening_mul(*self.barrett.value_lo());
-        *rhs = (lhs.wrapping_mul(*self.barrett.value_hi()).wrapping_add(mhi)).wrapping_mul(self.q).wrapping_neg();
+        *rhs = (lhs.wrapping_mul(*self.barrett.value_hi()).wrapping_add(mhi))
+            .wrapping_mul(self.q)
+            .wrapping_neg();
         self.reduce_assign::<REDUCE>(rhs);
     }
@@ -109,7 +113,11 @@ impl MontgomeryPrecomp<u64> {
     /// Assigns lhs * rhs * (2^{64})^-1 mod q to rhs.
     #[inline(always)]
-    pub fn mul_external_assign<const REDUCE:REDUCEMOD>(&self, lhs: Montgomery<u64>, rhs: &mut u64){
+    pub fn mul_external_assign<const REDUCE: REDUCEMOD>(
+        &self,
+        lhs: Montgomery<u64>,
+        rhs: &mut u64,
+    ) {
         let (mlo, mhi) = lhs.widening_mul(*rhs);
         let (_, hhi) = self.q.widening_mul(mlo.wrapping_mul(self.q_inv));
         *rhs = self.reduce::<REDUCE>(mhi.wrapping_sub(hhi).wrapping_add(self.q));
@@ -117,13 +125,21 @@ impl MontgomeryPrecomp<u64> {
     /// Returns lhs * rhs * (2^{64})^-1 mod q in range [0, 2q-1].
     #[inline(always)]
-    pub fn mul_internal<const REDUCE:REDUCEMOD>(&self, lhs: Montgomery<u64>, rhs: Montgomery<u64>) -> Montgomery<u64>{
+    pub fn mul_internal<const REDUCE: REDUCEMOD>(
+        &self,
+        lhs: Montgomery<u64>,
+        rhs: Montgomery<u64>,
+    ) -> Montgomery<u64> {
         self.mul_external::<REDUCE>(lhs, rhs)
     }

     /// Assigns lhs * rhs * (2^{64})^-1 mod q to rhs.
     #[inline(always)]
-    pub fn mul_internal_assign<const REDUCE:REDUCEMOD>(&self, lhs: Montgomery<u64>, rhs: &mut Montgomery<u64>){
+    pub fn mul_internal_assign<const REDUCE: REDUCEMOD>(
+        &self,
+        lhs: Montgomery<u64>,
+        rhs: &mut Montgomery<u64>,
+    ) {
         self.mul_external_assign::<REDUCE>(lhs, rhs);
     }
@@ -140,7 +156,11 @@ impl MontgomeryPrecomp<u64> {
     /// Assigns lhs + rhs - q if (lhs + rhs) >= q to rhs.
     #[inline(always)]
-    pub fn add_internal_reduce_once_assign<const LAZY:bool>(&self, lhs: Montgomery<u64>, rhs: &mut Montgomery<u64>){
+    pub fn add_internal_reduce_once_assign<const LAZY: bool>(
+        &self,
+        lhs: Montgomery<u64>,
+        rhs: &mut Montgomery<u64>,
+    ) {
         self.add_internal_lazy_assign(lhs, rhs);
         rhs.reduce_once_assign(self.q);
     }
@@ -166,8 +186,8 @@ impl MontgomeryPrecomp<u64> {
 #[cfg(test)]
 mod tests {
-    use crate::modulus::montgomery;
     use super::*;
+    use crate::modulus::montgomery;
     use test::Bencher;

     #[test]
@@ -177,7 +197,9 @@ mod tests {
         let x: u64 = 0x5f876e514845cc8b;
         let y: u64 = 0xad726f98f24a761a;
         let y_mont = m_precomp.prepare::<ONCE>(y);
-        assert!(m_precomp.mul_external::<ONCE>(y_mont, x) == (x as u128 * y as u128 % q as u128) as u64);
+        assert!(
+            m_precomp.mul_external::<ONCE>(y_mont, x) == (x as u128 * y as u128 % q as u128) as u64
+        );
     }

     #[bench]
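The visible pieces of mul_external above (the widening multiply, the q_inv correction term, and the final +q) assemble into one standard Montgomery REDC step. A standalone model (odd q; the Hensel inverse below replaces the crate's 63-iteration loop whose body is outside this hunk, and the modulus is the NTT prime used in the benches, not necessarily the one in the elided test setup):

// q^{-1} mod 2^64 for odd q by Hensel lifting: q*q = 1 mod 8, and each
// step inv <- inv * (2 - q*inv) doubles the number of correct low bits.
fn inv_mod_2_64(q: u64) -> u64 {
    let mut inv = q;
    for _ in 0..5 {
        inv = inv.wrapping_mul(2u64.wrapping_sub(q.wrapping_mul(inv)));
    }
    inv
}

// Model of mul_external::<ONCE>: returns lhs * rhs * 2^-64 mod q.
fn mont_mul(lhs: u64, rhs: u64, q: u64, q_inv: u64) -> u64 {
    let wide = lhs as u128 * rhs as u128;
    let (mlo, mhi) = (wide as u64, (wide >> 64) as u64);
    let hhi = ((q as u128 * mlo.wrapping_mul(q_inv) as u128) >> 64) as u64; // high word of q*m
    let r = mhi.wrapping_sub(hhi).wrapping_add(q); // in [0, 2q)
    if r >= q { r - q } else { r } // the REDUCE = ONCE step
}

fn main() {
    let (q, x, y) = (0x1fffffffffe00001u64, 0x5f876e514845cc8bu64, 0xad726f98f24a761au64);
    let y_mont = (((y as u128) << 64) % q as u128) as u64; // y * 2^64 mod q, i.e. prepare(y)
    assert_eq!(
        mont_mul(y_mont, x, q, inv_mod_2_64(q)),
        ((x as u128 * y as u128) % q as u128) as u64
    );
}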

View File

@@ -1,15 +1,13 @@
-use crate::modulus::{ScalarOperations, VectorOperations};
+use crate::modulus::barrett::Barrett;
+use crate::modulus::montgomery::Montgomery;
 use crate::modulus::prime::Prime;
 use crate::modulus::ReduceOnce;
-use crate::modulus::montgomery::Montgomery;
-use crate::modulus::barrett::Barrett;
 use crate::modulus::REDUCEMOD;
-use crate::{apply_v, apply_vv, apply_vvv, apply_sv, apply_svv, apply_vvsv, apply_vsv};
+use crate::modulus::{ScalarOperations, VectorOperations};
+use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv};
 use itertools::izip;

 impl ScalarOperations<u64> for Prime<u64> {
     /// Applies a modular reduction on x based on REDUCE:
     /// - LAZY: no modular reduction.
     /// - ONCE: subtracts q if x >= q.
@@ -62,7 +60,12 @@ impl ScalarOperations<u64> for Prime<u64> {
     }

     #[inline(always)]
-    fn sa_mont_mul_sb_into_sc<const REDUCE:REDUCEMOD>(&self, a: &Montgomery<u64>, b:&u64, c: &mut u64){
+    fn sa_mont_mul_sb_into_sc<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Montgomery<u64>,
+        b: &u64,
+        c: &mut u64,
+    ) {
         *c = self.montgomery.mul_external::<REDUCE>(*a, *b);
     }
@@ -72,7 +75,12 @@ impl ScalarOperations<u64> for Prime<u64> {
     }

     #[inline(always)]
-    fn sa_barrett_mul_sb_into_sc<const REDUCE:REDUCEMOD>(&self, a: &Barrett<u64>, b:&u64, c: &mut u64){
+    fn sa_barrett_mul_sb_into_sc<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Barrett<u64>,
+        b: &u64,
+        c: &mut u64,
+    ) {
         *c = self.barrett.mul_external::<REDUCE>(*a, *b);
     }
@@ -82,20 +90,30 @@ impl ScalarOperations<u64> for Prime<u64> {
     }

     #[inline(always)]
-    fn sa_sub_sb_mul_sc_into_sd<const REDUCE:REDUCEMOD>(&self, a: &u64, b: &u64, c: &Barrett<u64>, d: &mut u64){
+    fn sa_sub_sb_mul_sc_into_sd<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &u64,
+        b: &u64,
+        c: &Barrett<u64>,
+        d: &mut u64,
+    ) {
         *d = self.two_q.wrapping_sub(*b).wrapping_add(*a);
         self.barrett.mul_external_assign::<REDUCE>(*c, d);
     }

     #[inline(always)]
-    fn sa_sub_sb_mul_sc_into_sb<const REDUCE:REDUCEMOD>(&self, a: &u64, c: &Barrett<u64>, b: &mut u64){
+    fn sa_sub_sb_mul_sc_into_sb<const REDUCE: REDUCEMOD>(
+        &self,
+        a: &u64,
+        c: &Barrett<u64>,
+        b: &mut u64,
+    ) {
         *b = self.two_q.wrapping_sub(*b).wrapping_add(*a);
         self.barrett.mul_external_assign::<REDUCE>(*c, b);
     }
 }

 impl VectorOperations<u64> for Prime<u64> {
     /// Applies a modular reduction on x based on REDUCE:
     /// - LAZY: no modular reduction.
     /// - ONCE: subtracts q if x >= q.
@@ -109,32 +127,59 @@ impl VectorOperations<u64> for Prime<u64> {
     }

     #[inline(always)]
-    fn va_add_vb_into_vc<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b:&[u64], c:&mut [u64]){
+    fn va_add_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &[u64],
+        c: &mut [u64],
+    ) {
         apply_vvv!(self, Self::sa_add_sb_into_sc::<REDUCE>, a, b, c, CHUNK);
     }

     #[inline(always)]
-    fn va_add_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b:&mut [u64]){
+    fn va_add_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &mut [u64],
+    ) {
         apply_vv!(self, Self::sa_add_sb_into_sb::<REDUCE>, a, b, CHUNK);
     }

     #[inline(always)]
-    fn va_add_sb_into_vc<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b:&u64, c:&mut [u64]){
+    fn va_add_sb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &u64,
+        c: &mut [u64],
+    ) {
         apply_vsv!(self, Self::sa_add_sb_into_sc::<REDUCE>, a, b, c, CHUNK);
     }

     #[inline(always)]
-    fn sa_add_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a:&u64, b:&mut [u64]){
+    fn sa_add_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &u64,
+        b: &mut [u64],
+    ) {
         apply_sv!(self, Self::sa_add_sb_into_sb::<REDUCE>, a, b, CHUNK);
     }

     #[inline(always)]
-    fn va_sub_vb_into_vc<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b:&[u64], c:&mut [u64]){
+    fn va_sub_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &[u64],
+        c: &mut [u64],
+    ) {
         apply_vvv!(self, Self::sa_sub_sb_into_sc::<REDUCE>, a, b, c, CHUNK);
     }

     #[inline(always)]
-    fn va_sub_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b:&mut [u64]){
+    fn va_sub_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &mut [u64],
+    ) {
         apply_vv!(self, Self::sa_sub_sb_into_sb::<REDUCE>, a, b, CHUNK);
     }
@@ -144,40 +189,99 @@ impl VectorOperations<u64> for Prime<u64> {
     }

     #[inline(always)]
-    fn va_neg_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b: &mut [u64]){
+    fn va_neg_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &mut [u64],
+    ) {
         apply_vv!(self, Self::sa_neg_into_sb::<REDUCE>, a, b, CHUNK);
     }

     #[inline(always)]
-    fn va_prep_mont_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b: &mut [Montgomery<u64>]){
+    fn va_prep_mont_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &mut [Montgomery<u64>],
+    ) {
         apply_vv!(self, Self::sa_prep_mont_into_sb::<REDUCE>, a, b, CHUNK);
     }

     #[inline(always)]
-    fn va_mont_mul_vb_into_vc<const CHUNK:usize,const REDUCE:REDUCEMOD>(&self, a:& [Montgomery<u64>], b:&[u64], c: &mut [u64]){
+    fn va_mont_mul_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[Montgomery<u64>],
+        b: &[u64],
+        c: &mut [u64],
+    ) {
         apply_vvv!(self, Self::sa_mont_mul_sb_into_sc::<REDUCE>, a, b, c, CHUNK);
     }

     #[inline(always)]
-    fn va_mont_mul_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a:& [Montgomery<u64>], b:&mut [u64]){
+    fn va_mont_mul_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[Montgomery<u64>],
+        b: &mut [u64],
+    ) {
         apply_vv!(self, Self::sa_mont_mul_sb_into_sb::<REDUCE>, a, b, CHUNK);
     }

     #[inline(always)]
-    fn sa_barrett_mul_vb_into_vc<const CHUNK:usize,const REDUCE:REDUCEMOD>(&self, a:& Barrett<u64>, b:&[u64], c: &mut [u64]){
-        apply_svv!(self, Self::sa_barrett_mul_sb_into_sc::<REDUCE>, a, b, c, CHUNK);
+    fn sa_barrett_mul_vb_into_vc<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Barrett<u64>,
+        b: &[u64],
+        c: &mut [u64],
+    ) {
+        apply_svv!(
+            self,
+            Self::sa_barrett_mul_sb_into_sc::<REDUCE>,
+            a,
+            b,
+            c,
+            CHUNK
+        );
     }

     #[inline(always)]
-    fn sa_barrett_mul_vb_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a:& Barrett<u64>, b:&mut [u64]){
+    fn sa_barrett_mul_vb_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &Barrett<u64>,
+        b: &mut [u64],
+    ) {
         apply_sv!(self, Self::sa_barrett_mul_sb_into_sb::<REDUCE>, a, b, CHUNK);
     }

-    fn va_sub_vb_mul_sc_into_vd<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b: &[u64], c: &Barrett<u64>, d: &mut [u64]){
-        apply_vvsv!(self, Self::sa_sub_sb_mul_sc_into_sd::<REDUCE>, a, b, c, d, CHUNK);
+    fn va_sub_vb_mul_sc_into_vd<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &[u64],
+        c: &Barrett<u64>,
+        d: &mut [u64],
+    ) {
+        apply_vvsv!(
+            self,
+            Self::sa_sub_sb_mul_sc_into_sd::<REDUCE>,
+            a,
+            b,
+            c,
+            d,
+            CHUNK
+        );
     }

-    fn va_sub_vb_mul_sc_into_vb<const CHUNK:usize, const REDUCE:REDUCEMOD>(&self, a: &[u64], b: &Barrett<u64>, c: &mut [u64]){
-        apply_vsv!(self, Self::sa_sub_sb_mul_sc_into_sb::<REDUCE>, a, b, c, CHUNK);
+    fn va_sub_vb_mul_sc_into_vb<const CHUNK: usize, const REDUCE: REDUCEMOD>(
+        &self,
+        a: &[u64],
+        b: &Barrett<u64>,
+        c: &mut [u64],
+    ) {
+        apply_vsv!(
+            self,
+            Self::sa_sub_sb_mul_sc_into_sb::<REDUCE>,
a,
b,
c,
CHUNK
);
} }
} }
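The sa_sub_sb_* kernels above compute a - b branch-free as 2q - b + a. A minimal standalone sketch of that trick (not part of the crate; assumes inputs already reduced to [0, 2q) and q < 2^62 so the sum cannot overflow a u64):

fn sub_lazy(a: u64, b: u64, two_q: u64) -> u64 {
    // For a, b in [0, 2q): 2q - b is in (0, 2q], so the result lies in (0, 4q)
    // and never underflows; a later reduction (the Barrett multiply in the
    // real code) brings it back below q.
    two_q.wrapping_sub(b).wrapping_add(a)
}

fn main() {
    let q: u64 = 17;
    let (a, b) = (5u64, 9u64);
    // One explicit `% q` stands in for the Barrett step of the real code.
    assert_eq!(sub_lazy(a, b, 2 * q) % q, (a + q - b) % q); // -4 = 13 mod 17
}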

View File

@@ -1,12 +1,11 @@
use crate::modulus::barrett::BarrettPrecomp;
use crate::modulus::montgomery::{Montgomery, MontgomeryPrecomp};
use crate::modulus::prime::Prime;
use crate::modulus::ONCE;
use primality_test::is_prime;
use prime_factorization::Factorization;

impl Prime<u64> {
    /// Returns a new instance of Prime<u64>.
    /// Panics if q_base is not a prime > 2 and
    /// if q_base^q_power would overflow u64.

@@ -19,7 +18,6 @@ impl Prime<u64> {
    /// Does not check if q_base is a prime > 2.
    /// Panics if q_base^q_power would overflow u64.
    pub fn new_unchecked(q_base: u64, q_power: usize) -> Self {
        let mut q = q_base;
        for _i in 1..q_power {
            q *= q_base
        }

@@ -47,7 +45,6 @@ impl Prime<u64> {
        prime.check_factors();
        prime
    }

    pub fn q(&self) -> u64 {

@@ -70,13 +67,14 @@ impl Prime<u64> {
        let mut i: u64 = exponent;
        while i > 0 {
            if i & 1 == 1 {
                self.montgomery
                    .mul_internal_assign::<ONCE>(x_mont, &mut y_mont);
            }
            self.montgomery
                .mul_internal_assign::<ONCE>(x_mont, &mut x_mont);
            i >>= 1;
        }
        self.montgomery.unprepare::<ONCE>(y_mont)

@@ -93,19 +91,16 @@ impl Prime<u64> {
impl Prime<u64> {
    /// Returns the smallest nth primitive root of q_base.
    pub fn primitive_root(&self) -> u64 {
        let mut candidate: u64 = 1u64;
        let mut not_found: bool = true;
        while not_found {
            candidate += 1;
            for &factor in &self.factors {
                if pow(candidate, (self.q_base - 1) / factor, self.q_base) == 1 {
                    not_found = true;
                    break;
                }
                not_found = false;
            }
        }

@@ -120,8 +115,13 @@ impl Prime<u64> {
    /// Returns an nth primitive root of q = q_base^q_power in Montgomery.
    pub fn primitive_nth_root(&self, nth_root: u64) -> u64 {
        assert!(
            self.q & (nth_root - 1) == 1,
            "invalid prime: q = {} % nth_root = {} = {} != 1",
            self.q,
            nth_root,
            self.q & (nth_root - 1)
        );

        let psi: u64 = self.primitive_root();

@@ -131,8 +131,14 @@ impl Prime<u64> {
        // lifts nth primitive root mod q_base to q = q_base^q_power
        let psi_nth_q: u64 = self.hensel_lift(psi_nth_q_base, nth_root);

        assert!(
            self.pow(psi_nth_q, nth_root) == 1,
            "invalid nth primitive root: psi^nth_root != 1 mod q"
        );
        assert!(
            self.pow(psi_nth_q, nth_root >> 1) == self.q - 1,
            "invalid nth primitive root: psi^(nth_root/2) != -1 mod q"
        );

        psi_nth_q
    }

@@ -141,16 +147,13 @@ impl Prime<u64> {
    /// If not, factorizes q_base-1 and populates self.factors.
    /// If yes, checks that it contains the unique factors of q_base-1.
    pub fn check_factors(&mut self) {
        if self.factors.len() == 0 {
            let factors = Factorization::run(self.q_base - 1).prime_factor_repr();
            let mut distincts_factors: Vec<u64> = Vec::with_capacity(factors.len());
            for factor in factors.iter() {
                distincts_factors.push(factor.0)
            }
            self.factors = distincts_factors
        } else {
            let mut q_base: u64 = self.q_base;

@@ -173,22 +176,29 @@ impl Prime<u64> {
    /// Returns psi + a * q_base such that (psi + a * q_base)^{nth_root} = 1 mod q = q_base^q_power,
    /// given psi^{nth_root} = 1 mod q_base.
    /// Panics if psi^{nth_root} != 1 mod q_base.
    fn hensel_lift(&self, psi: u64, nth_root: u64) -> u64 {
        assert!(
            pow(psi, nth_root, self.q_base) == 1,
            "invalid argument psi: psi^nth_root = {} != 1",
            pow(psi, nth_root, self.q_base)
        );
        let mut psi_mont: Montgomery<u64> = self.montgomery.prepare::<ONCE>(psi);
        let nth_root_mont: Montgomery<u64> = self.montgomery.prepare::<ONCE>(nth_root);
        for _i in 1..self.q_power {
            let psi_pow: Montgomery<u64> = self.montgomery.pow(psi_mont, nth_root - 1);
            let num: Montgomery<u64> = self.montgomery.one() + self.q
                - self.montgomery.mul_internal::<ONCE>(psi_pow, psi_mont);
            let mut den: Montgomery<u64> =
                self.montgomery.mul_internal::<ONCE>(nth_root_mont, psi_pow);
            den = self.montgomery.pow(den, self.phi - 1);
            psi_mont = self
                .montgomery
                .add_internal(psi_mont, self.montgomery.mul_internal::<ONCE>(num, den));
        }
        self.montgomery.unprepare::<ONCE>(psi_mont)
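primitive_root above uses the standard test: g generates (Z/pZ)* iff g^((p-1)/f) != 1 mod p for every distinct prime factor f of p - 1. A self-contained sketch with its own modular pow (not the crate's; `factors` is assumed precomputed, as check_factors does via the prime_factorization crate):

fn pow_mod(base: u64, mut exp: u64, p: u64) -> u64 {
    let m = p as u128;
    let (mut acc, mut b) = (1u128, base as u128 % m);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * b % m; // multiply in when the current bit is set
        }
        b = b * b % m; // square for the next bit
        exp >>= 1;
    }
    acc as u64
}

fn primitive_root(p: u64, factors: &[u64]) -> u64 {
    (2..p)
        .find(|&g| factors.iter().all(|&f| pow_mod(g, (p - 1) / f, p) != 1))
        .expect("p must be an odd prime")
}

fn main() {
    // p = 2^16 + 1 is prime and p - 1 = 2^16, so 2 is the only factor to test;
    // 2 itself has order 32, and 3 is the smallest generator.
    assert_eq!(primitive_root(65537, &[2]), 3);
}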

View File

@@ -1,14 +1,16 @@
use crate::modulus::barrett::BarrettPrecomp;
use crate::modulus::montgomery::MontgomeryPrecomp;

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Prime<O> {
    /// q_base^q_power
    pub q: O,
    pub two_q: O,
    pub four_q: O,
    pub q_base: O,
    pub q_power: usize,
    /// distinct factors of q-1
    pub factors: Vec<O>,
    pub montgomery: MontgomeryPrecomp<O>,
    pub barrett: BarrettPrecomp<O>,
    pub phi: O,

View File

@@ -4,7 +4,8 @@ use std::cmp::PartialEq;
#[derive(Clone, Debug, Eq)]
pub struct Poly<O>(pub Vec<O>);

impl<O> Poly<O>
where
    O: Default + Clone + Copy,
{
    pub fn new(n: usize) -> Self {

@@ -12,11 +13,16 @@ impl<O> Poly<O> where
    }

    pub fn buffer_size(&self) -> usize {
        return self.0.len();
    }

    pub fn from_buffer(&mut self, n: usize, buf: &mut [O]) {
        assert!(
            buf.len() >= n,
            "invalid buffer: buf.len()={} < n={}",
            buf.len(),
            n
        );
        self.0 = Vec::from(&buf[..n]);
    }

@@ -42,7 +48,7 @@ impl<O> Poly<O> where
    pub fn copy_from(&mut self, other: &Poly<O>) {
        if std::ptr::eq(self, other) {
            return;
        }
        self.resize(other.n());
        self.0.copy_from_slice(&other.0)

@@ -64,10 +70,10 @@ impl<O> Default for Poly<O> {
#[derive(Clone, Debug, Eq)]
pub struct PolyRNS<O>(pub Vec<Poly<O>>);

impl<O> PolyRNS<O>
where
    O: Default + Clone + Copy,
{
    pub fn new(n: usize, level: usize) -> Self {
        let mut polyrns: PolyRNS<O> = PolyRNS::<O>::default();
        let mut buf: Vec<O> = vec![O::default(); polyrns.buffer_size(n, level)];

@@ -92,7 +98,12 @@ impl<O> PolyRNS<O> where
    }

    pub fn from_buffer(&mut self, n: usize, level: usize, buf: &mut [O]) {
        assert!(
            buf.len() >= n * (level + 1),
            "invalid buffer: buf.len()={} < n * (level+1)={}",
            buf.len(),
            n * (level + 1)
        );
        self.0.clear();
        for chunk in buf.chunks_mut(n).take(level + 1) {
            let mut poly: Poly<O> = Poly(Vec::new());

@@ -110,7 +121,12 @@ impl<O> PolyRNS<O> where
    }

    pub fn at(&self, level: usize) -> &Poly<O> {
        assert!(
            level <= self.level(),
            "invalid argument level: level={} > self.level()={}",
            level,
            self.level()
        );
        &self.0[level]
    }

@@ -128,15 +144,25 @@ impl<O> PolyRNS<O> where
    pub fn copy(&mut self, other: &PolyRNS<O>) {
        if std::ptr::eq(self, other) {
            return;
        }
        self.resize(other.level());
        self.copy_level(other.level(), other);
    }

    pub fn copy_level(&mut self, level: usize, other: &PolyRNS<O>) {
        assert!(
            level <= self.level(),
            "invalid argument level: level={} > self.level()={}",
            level,
            self.level()
        );
        assert!(
            level <= other.level(),
            "invalid argument level: level={} > other.level()={}",
            level,
            other.level()
        );
        (0..level + 1).for_each(|i| self.at_mut(i).copy_from(other.at(i)))
    }
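A hypothetical usage sketch of the shape these types encode, assuming the crate is importable as `math` and that the accessors keep the signatures shown above: a PolyRNS at level L is (L + 1) limbs (one Poly per RNS modulus) of n coefficients each, which is exactly what buffer_size/from_buffer count.

use math::poly::PolyRNS;

fn main() {
    let p: PolyRNS<u64> = PolyRNS::new(1024, 2); // n = 1024, level = 2
    assert_eq!(p.level(), 2); // three limbs: q0, q1, q2
    assert_eq!(p.at(0).n(), 1024); // each limb holds n coefficients
}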

View File

@@ -1,10 +1,9 @@
pub mod impl_u64;

use crate::dft::DFT;
use crate::modulus::prime::Prime;
use crate::poly::{Poly, PolyRNS};
use num::traits::Unsigned;

pub struct Ring<O: Unsigned> {
    pub n: usize,

@@ -14,7 +13,7 @@ pub struct Ring<O: Unsigned> {
impl<O: Unsigned> Ring<O> {
    pub fn n(&self) -> usize {
        return self.n;
    }

    pub fn new_poly(&self) -> Poly<u64> {

@@ -25,7 +24,6 @@ impl<O: Unsigned> Ring<O> {
pub struct RingRNS<'a, O: Unsigned>(pub &'a [Ring<O>]);

impl<O: Unsigned> RingRNS<'_, O> {
    pub fn n(&self) -> usize {
        self.0[0].n()
    }

View File

@@ -1,14 +1,23 @@
use crate::modulus::WordOps;
use crate::poly::Poly;
use crate::ring::Ring;

/// Returns a lookup table for the automorphism X^{i} -> X^{i * k mod nth_root}.
/// Method will panic if n or nth_root are not power-of-two.
/// Method will panic if gal_el is not coprime with nth_root.
pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec<u64> {
    assert!(n & (n - 1) == 0, "invalid n={}: not a power-of-two", n);
    assert!(
        nth_root & (nth_root - 1) == 0,
        "invalid nth_root={}: not a power-of-two",
        nth_root
    );
    assert!(
        gal_el & 1 == 1,
        "invalid gal_el={}: not coprime with nth_root={}",
        gal_el,
        nth_root
    );

    let mask = nth_root - 1;
    let log_nth_root: u32 = nth_root.log2() as u32;

@@ -23,7 +32,12 @@ pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec<u64> {
impl Ring<u64> {
    pub fn automorphism(&self, a: Poly<u64>, gal_el: u64, b: &mut Poly<u64>) {
        debug_assert!(
            a.n() == b.n(),
            "invalid inputs: a.n() = {} != b.n() = {}",
            a.n(),
            b.n()
        );
        debug_assert!(gal_el & 1 == 1, "invalid gal_el = {}: not odd", gal_el);

        let n: usize = a.n();
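The table above works in the NTT domain; a coefficient-domain sketch (standalone, not the crate's API) shows the map it encodes on Z_q[X]/(X^n + 1): exponents are taken mod 2n, and X^n = -1 negates coefficients that wrap past n.

fn automorphism_naive(a: &[u64], gal_el: usize, q: u64) -> Vec<u64> {
    let n = a.len();
    let mut b = vec![0u64; n];
    for (i, &c) in a.iter().enumerate() {
        let j = (i * gal_el) % (2 * n); // exponent of X^{i * gal_el} mod 2n
        if j < n {
            b[j] = (b[j] + c) % q;
        } else {
            b[j - n] = (b[j - n] + q - c % q) % q; // X^n = -1 flips the sign
        }
    }
    b
}

fn main() {
    // n = 4, q = 17, gal_el = 3: X -> X^3, X^2 -> X^6 = -X^2, X^3 -> X^9 = X.
    assert_eq!(automorphism_naive(&[1, 2, 3, 4], 3, 17), vec![1, 4, 14, 2]);
}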

View File

@@ -1,5 +1,5 @@
pub mod automorphism;
pub mod rescaling_rns;
pub mod ring;
pub mod ring_rns;
pub mod sampling;

View File

@@ -1,17 +1,31 @@
use crate::modulus::barrett::Barrett;
use crate::modulus::ONCE;
use crate::poly::PolyRNS;
use crate::ring::Ring;
use crate::ring::RingRNS;
use crate::scalar::ScalarRNS;

extern crate test;

impl RingRNS<'_, u64> {
    /// Updates b to floor(a / q[b.level()]).
    pub fn div_floor_by_last_modulus<const NTT: bool>(
        &self,
        a: &PolyRNS<u64>,
        buf: &mut PolyRNS<u64>,
        b: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            self.level() <= a.level(),
            "invalid input a: self.level()={} > a.level()={}",
            self.level(),
            a.level()
        );
        debug_assert!(
            b.level() >= a.level() - 1,
            "invalid input b: b.level()={} < a.level()-1={}",
            b.level(),
            a.level() - 1
        );

        let level = self.level();
        let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();

@@ -21,19 +35,38 @@ impl RingRNS<'_, u64> {
            self.0[level].intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
            for (i, r) in self.0[0..level].iter().enumerate() {
                r.ntt::<false>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
                r.sum_aqqmb_prod_c_scalar_barrett::<ONCE>(
                    &buf_ntt_qi_scaling[0],
                    a.at(i),
                    &rescaling_constants.0[i],
                    b.at_mut(i),
                );
            }
        } else {
            for (i, r) in self.0[0..level].iter().enumerate() {
                r.sum_aqqmb_prod_c_scalar_barrett::<ONCE>(
                    a.at(level),
                    a.at(i),
                    &rescaling_constants.0[i],
                    b.at_mut(i),
                );
            }
        }
    }

    /// Updates a to floor(a / q[a.level()]).
    /// Expects a to be in the NTT domain.
    pub fn div_floor_by_last_modulus_inplace<const NTT: bool>(
        &self,
        buf: &mut PolyRNS<u64>,
        a: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            self.level() <= a.level(),
            "invalid input a: self.level()={} > a.level()={}",
            self.level(),
            a.level()
        );

        let level = self.level();
        let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();

@@ -43,22 +76,50 @@ impl RingRNS<'_, u64> {
            self.0[level].intt::<true>(a.at(level), &mut buf_ntt_q_scaling[0]);
            for (i, r) in self.0[0..level].iter().enumerate() {
                r.ntt::<true>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
                r.sum_aqqmb_prod_c_scalar_barrett_inplace::<ONCE>(
                    &buf_ntt_qi_scaling[0],
                    &rescaling_constants.0[i],
                    a.at_mut(i),
                );
            }
        } else {
            let (a_i, a_level) = buf.0.split_at_mut(level);
            for (i, r) in self.0[0..level].iter().enumerate() {
                r.sum_aqqmb_prod_c_scalar_barrett_inplace::<ONCE>(
                    &a_level[0],
                    &rescaling_constants.0[i],
                    &mut a_i[i],
                );
            }
        }
    }

    /// Updates c to floor(a / prod_{i = level - nb_moduli + 1}^{level} q[i]).
    pub fn div_floor_by_last_moduli<const NTT: bool>(
        &self,
        nb_moduli: usize,
        a: &PolyRNS<u64>,
        buf: &mut PolyRNS<u64>,
        c: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            self.level() <= a.level(),
            "invalid input a: self.level()={} > a.level()={}",
            self.level(),
            a.level()
        );
        debug_assert!(
            c.level() >= a.level() - 1,
            "invalid input c: c.level()={} < a.level()-1={}",
            c.level(),
            a.level() - 1
        );
        debug_assert!(
            nb_moduli <= a.level(),
            "invalid input nb_moduli: nb_moduli={} > a.level()={}",
            nb_moduli,
            a.level()
        );

        if nb_moduli == 0 {
            if a != c {

@@ -67,33 +128,78 @@ impl RingRNS<'_, u64> {
        } else {
            if NTT {
                self.intt::<false>(a, buf);
                (0..nb_moduli).for_each(|i| {
                    self.at_level(self.level() - i)
                        .div_floor_by_last_modulus_inplace::<false>(
                            &mut PolyRNS::<u64>::default(),
                            buf,
                        )
                });
                self.at_level(self.level() - nb_moduli).ntt::<false>(buf, c);
            } else {
                self.div_floor_by_last_modulus::<false>(a, buf, c);
                (1..nb_moduli).for_each(|i| {
                    self.at_level(self.level() - i)
                        .div_floor_by_last_modulus_inplace::<false>(buf, c)
                });
            }
        }
    }

    /// Updates a to floor(a / prod_{i = level - nb_moduli + 1}^{level} q[i]).
    pub fn div_floor_by_last_moduli_inplace<const NTT: bool>(
        &self,
        nb_moduli: usize,
        buf: &mut PolyRNS<u64>,
        a: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            self.level() <= a.level(),
            "invalid input a: self.level()={} > a.level()={}",
            self.level(),
            a.level()
        );
        debug_assert!(
            nb_moduli <= a.level(),
            "invalid input nb_moduli: nb_moduli={} > a.level()={}",
            nb_moduli,
            a.level()
        );

        if NTT {
            self.intt::<false>(a, buf);
            (0..nb_moduli).for_each(|i| {
                self.at_level(self.level() - i)
                    .div_floor_by_last_modulus_inplace::<false>(&mut PolyRNS::<u64>::default(), buf)
            });
            self.at_level(self.level() - nb_moduli).ntt::<false>(buf, a);
        } else {
            (0..nb_moduli).for_each(|i| {
                self.at_level(self.level() - i)
                    .div_floor_by_last_modulus_inplace::<false>(buf, a)
            });
        }
    }

    /// Updates b to round(a / q[b.level()]).
    /// Expects b to be in the NTT domain.
    pub fn div_round_by_last_modulus<const NTT: bool>(
        &self,
        a: &PolyRNS<u64>,
        buf: &mut PolyRNS<u64>,
        b: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            self.level() <= a.level(),
            "invalid input a: self.level()={} > a.level()={}",
            self.level(),
            a.level()
        );
        debug_assert!(
            b.level() >= a.level() - 1,
            "invalid input b: b.level()={} < a.level()-1={}",
            b.level(),
            a.level() - 1
        );

        let level: usize = self.level();
        let r_last: &Ring<u64> = &self.0[level];

@@ -105,20 +211,36 @@ impl RingRNS<'_, u64> {
            r_last.intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
            r_last.add_scalar_inplace::<ONCE>(&q_level_half, &mut buf_ntt_q_scaling[0]);
            for (i, r) in self.0[0..level].iter().enumerate() {
                r_last.add_scalar::<ONCE>(
                    &buf_ntt_q_scaling[0],
                    &q_level_half,
                    &mut buf_ntt_qi_scaling[0],
                );
                r.ntt_inplace::<false>(&mut buf_ntt_qi_scaling[0]);
                r.sum_aqqmb_prod_c_scalar_barrett::<ONCE>(
                    &buf_ntt_qi_scaling[0],
                    a.at(i),
                    &rescaling_constants.0[i],
                    b.at_mut(i),
                );
            }
        } else {
        }
    }

    /// Updates a to round(a / q[a.level()]).
    /// Expects a to be in the NTT domain.
    pub fn div_round_by_last_modulus_inplace<const NTT: bool>(
        &self,
        buf: &mut PolyRNS<u64>,
        a: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            self.level() <= a.level(),
            "invalid input a: self.level()={} > a.level()={}",
            self.level(),
            a.level()
        );

        let level = self.level();
        let r_last: &Ring<u64> = &self.0[level];

@@ -130,13 +252,18 @@ impl RingRNS<'_, u64> {
            r_last.intt::<true>(a.at(level), &mut buf_ntt_q_scaling[0]);
            r_last.add_scalar_inplace::<ONCE>(&q_level_half, &mut buf_ntt_q_scaling[0]);
            for (i, r) in self.0[0..level].iter().enumerate() {
                r_last.add_scalar::<ONCE>(
                    &buf_ntt_q_scaling[0],
                    &q_level_half,
                    &mut buf_ntt_qi_scaling[0],
                );
                r.ntt::<true>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
                r.sum_aqqmb_prod_c_scalar_barrett_inplace::<ONCE>(
                    &buf_ntt_qi_scaling[0],
                    &rescaling_constants.0[i],
                    a.at_mut(i),
                );
            }
        }
    }
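These functions apply, limb by limb, the scalar identity behind RNS rescaling: for coprime q0 and q1 (q1 being the modulus dropped), a - (a mod q1) is exactly divisible by q1, so floor(a / q1) mod q0 = (a0 - a1) * q1^{-1} mod q0 with a0 = a mod q0 and a1 = a mod q1. A standalone numeric sketch (outside the crate; Fermat inversion assumes q0 prime):

fn pow_mod(base: u64, mut exp: u64, p: u64) -> u64 {
    let m = p as u128;
    let (mut acc, mut b) = (1u128, base as u128 % m);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * b % m;
        }
        b = b * b % m;
        exp >>= 1;
    }
    acc as u64
}

fn main() {
    let (q0, q1) = (97u64, 13u64);
    let a = 1000u64; // the RNS-encoded value
    let (a0, a1) = (a % q0, a % q1);
    let q1_inv = pow_mod(q1, q0 - 2, q0); // q1^{-1} mod q0 via Fermat
    let got = (a0 + q0 - a1) % q0 * q1_inv % q0;
    assert_eq!(got, (a / q1) % q0); // floor(1000 / 13) = 76
}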

View File

@@ -1,14 +1,14 @@
use crate::dft::ntt::Table;
use crate::modulus::barrett::Barrett;
use crate::modulus::montgomery::Montgomery;
use crate::modulus::prime::Prime;
use crate::modulus::VectorOperations;
use crate::modulus::{BARRETT, REDUCEMOD};
use crate::poly::Poly;
use crate::ring::Ring;
use crate::CHUNK;
use num_bigint::BigInt;
use num_traits::ToPrimitive;

impl Ring<u64> {
    pub fn new(n: usize, q_base: u64, q_power: usize) -> Self {

@@ -21,10 +21,23 @@ impl Ring<u64> {
    }

    pub fn from_bigint(&self, coeffs: &[BigInt], step: usize, a: &mut Poly<u64>) {
        assert!(
            step <= a.n(),
            "invalid step: step={} > a.n()={}",
            step,
            a.n()
        );
        assert!(
            coeffs.len() <= a.n() / step,
            "invalid coeffs: coeffs.len()={} > a.n()/step={}",
            coeffs.len(),
            a.n() / step
        );
        let q_big: BigInt = BigInt::from(self.modulus.q);
        a.0.iter_mut()
            .step_by(step)
            .enumerate()
            .for_each(|(i, v)| *v = (&coeffs[i] % &q_big).to_u64().unwrap());
    }

@@ -32,14 +45,14 @@ impl Ring<u64> {
    pub fn ntt_inplace<const LAZY: bool>(&self, poly: &mut Poly<u64>) {
        match LAZY {
            true => self.dft.forward_inplace_lazy(&mut poly.0),
            false => self.dft.forward_inplace(&mut poly.0),
        }
    }

    pub fn intt_inplace<const LAZY: bool>(&self, poly: &mut Poly<u64>) {
        match LAZY {
            true => self.dft.backward_inplace_lazy(&mut poly.0),
            false => self.dft.backward_inplace(&mut poly.0),
        }
    }

@@ -47,7 +60,7 @@ impl Ring<u64> {
        poly_out.0.copy_from_slice(&poly_in.0);
        match LAZY {
            true => self.dft.forward_inplace_lazy(&mut poly_out.0),
            false => self.dft.forward_inplace(&mut poly_out.0),
        }
    }

@@ -55,18 +68,18 @@ impl Ring<u64> {
        poly_out.0.copy_from_slice(&poly_in.0);
        match LAZY {
            true => self.dft.backward_inplace_lazy(&mut poly_out.0),
            false => self.dft.backward_inplace(&mut poly_out.0),
        }
    }
}

impl Ring<u64> {
    #[inline(always)]
    pub fn add_inplace<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus
            .va_add_vb_into_vb::<CHUNK, REDUCE>(&a.0, &mut b.0);
    }

    #[inline(always)]

@@ -74,7 +87,8 @@ impl Ring<u64> {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
        self.modulus
            .va_add_vb_into_vc::<CHUNK, REDUCE>(&a.0, &b.0, &mut c.0);
    }

    #[inline(always)]

@@ -87,14 +101,16 @@ impl Ring<u64> {
    pub fn add_scalar<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
        self.modulus
            .va_add_sb_into_vc::<CHUNK, REDUCE>(&a.0, b, &mut c.0);
    }

    #[inline(always)]
    pub fn sub_inplace<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus
            .va_sub_vb_into_vb::<CHUNK, REDUCE>(&a.0, &mut b.0);
    }

    #[inline(always)]

@@ -102,7 +118,8 @@ impl Ring<u64> {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
        self.modulus
            .va_sub_vb_into_vc::<CHUNK, REDUCE>(&a.0, &b.0, &mut c.0);
    }

    #[inline(always)]

@@ -119,57 +136,102 @@ impl Ring<u64> {
    }

    #[inline(always)]
    pub fn mul_montgomery_external<const REDUCE: REDUCEMOD>(
        &self,
        a: &Poly<Montgomery<u64>>,
        b: &Poly<u64>,
        c: &mut Poly<u64>,
    ) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
        self.modulus
            .va_mont_mul_vb_into_vc::<CHUNK, REDUCE>(&a.0, &b.0, &mut c.0);
    }

    #[inline(always)]
    pub fn mul_montgomery_external_inplace<const REDUCE: REDUCEMOD>(
        &self,
        a: &Poly<Montgomery<u64>>,
        b: &mut Poly<u64>,
    ) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus
            .va_mont_mul_vb_into_vb::<CHUNK, REDUCE>(&a.0, &mut b.0);
    }

    #[inline(always)]
    pub fn mul_scalar<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
        self.modulus.sa_barrett_mul_vb_into_vc::<CHUNK, REDUCE>(
            &self.modulus.barrett.prepare(*b),
            &a.0,
            &mut c.0,
        );
    }

    #[inline(always)]
    pub fn mul_scalar_inplace<const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut Poly<u64>) {
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus.sa_barrett_mul_vb_into_vb::<CHUNK, REDUCE>(
            &self
                .modulus
                .barrett
                .prepare(self.modulus.barrett.reduce::<BARRETT>(a)),
            &mut b.0,
        );
    }

    #[inline(always)]
    pub fn mul_scalar_barrett_inplace<const REDUCE: REDUCEMOD>(
        &self,
        a: &Barrett<u64>,
        b: &mut Poly<u64>,
    ) {
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus
            .sa_barrett_mul_vb_into_vb::<CHUNK, REDUCE>(a, &mut b.0);
    }

    #[inline(always)]
    pub fn mul_scalar_barrett<const REDUCE: REDUCEMOD>(
        &self,
        a: &Barrett<u64>,
        b: &Poly<u64>,
        c: &mut Poly<u64>,
    ) {
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus
            .sa_barrett_mul_vb_into_vc::<CHUNK, REDUCE>(a, &b.0, &mut c.0);
    }

    #[inline(always)]
    pub fn sum_aqqmb_prod_c_scalar_barrett<const REDUCE: REDUCEMOD>(
        &self,
        a: &Poly<u64>,
        b: &Poly<u64>,
        c: &Barrett<u64>,
        d: &mut Poly<u64>,
    ) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n());
        self.modulus
            .va_sub_vb_mul_sc_into_vd::<CHUNK, REDUCE>(&a.0, &b.0, c, &mut d.0);
    }

    #[inline(always)]
    pub fn sum_aqqmb_prod_c_scalar_barrett_inplace<const REDUCE: REDUCEMOD>(
        &self,
        a: &Poly<u64>,
        c: &Barrett<u64>,
        b: &mut Poly<u64>,
    ) {
        debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
        debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
        self.modulus
            .va_sub_vb_mul_sc_into_vb::<CHUNK, REDUCE>(&a.0, c, &mut b.0);
    }
}
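A hypothetical round-trip using the wrappers above, assuming the crate builds as `math` and that Ring::new(16, 0x1fffffffffe00001, 1) wires up an NTT table for n = 16 (this prime supports 2^17-th roots of unity, as in the benchmarks): intt(ntt(a)) should give a back, the wrappers only dispatching between the lazy and fully-reduced variants.

use math::poly::Poly;
use math::ring::Ring;

fn main() {
    let r: Ring<u64> = Ring::new(16, 0x1fffffffffe00001, 1);
    let mut a: Poly<u64> = r.new_poly();
    a.0.iter_mut().enumerate().for_each(|(i, v)| *v = i as u64);
    let before = a.clone();
    r.ntt_inplace::<false>(&mut a);
    r.intt_inplace::<false>(&mut a);
    assert_eq!(a.0, before.0); // the round trip is the identity
}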

View File

@@ -1,9 +1,9 @@
use crate::modulus::barrett::Barrett;
use crate::modulus::montgomery::Montgomery;
use crate::modulus::REDUCEMOD;
use crate::poly::PolyRNS;
use crate::ring::{Ring, RingRNS};
use crate::scalar::ScalarRNS;
use num_bigint::BigInt;

pub fn new_rings(n: usize, moduli: Vec<u64>) -> Vec<Ring<u64>> {

@@ -12,7 +12,7 @@ pub fn new_rings(n: usize, moduli: Vec<u64>) -> Vec<Ring<u64>> {
        .into_iter()
        .map(|prime| Ring::new(n, prime, 1))
        .collect();
    return rings;
}

impl<'a> RingRNS<'a, u64> {

@@ -22,25 +22,52 @@ impl<'a> RingRNS<'a, u64> {
    pub fn modulus(&self) -> BigInt {
        let mut modulus = BigInt::from(1);
        self.0
            .iter()
            .for_each(|r| modulus *= BigInt::from(r.modulus.q));
        modulus
    }

    pub fn rescaling_constant(&self) -> ScalarRNS<Barrett<u64>> {
        let level = self.level();
        let q_scale: u64 = self.0[level].modulus.q;
        ScalarRNS(
            (0..level)
                .map(|i| {
                    self.0[i]
                        .modulus
                        .barrett
                        .prepare(self.0[i].modulus.q - self.0[i].modulus.inv(q_scale))
                })
                .collect(),
        )
    }

    pub fn from_bigint_inplace(&self, coeffs: &[BigInt], step: usize, a: &mut PolyRNS<u64>) {
        let level = self.level();
        assert!(
            level <= a.level(),
            "invalid level: level={} > a.level()={}",
            level,
            a.level()
        );
        (0..level).for_each(|i| self.0[i].from_bigint(coeffs, step, a.at_mut(i)));
    }

    pub fn to_bigint_inplace(&self, a: &PolyRNS<u64>, step: usize, coeffs: &mut [BigInt]) {
        assert!(
            step <= a.n(),
            "invalid step: step={} > a.n()={}",
            step,
            a.n()
        );
        assert!(
            coeffs.len() <= a.n() / step,
            "invalid coeffs: coeffs.len()={} > a.n()/step={}",
            coeffs.len(),
            a.n() / step
        );
        let mut inv_crt: Vec<BigInt> = vec![BigInt::default(); self.level() + 1];
        let q_big: BigInt = self.modulus();

@@ -67,92 +94,260 @@ impl<'a> RingRNS<'a, u64> {
impl RingRNS<'_, u64> {
    pub fn ntt_inplace<const LAZY: bool>(&self, a: &mut PolyRNS<u64>) {
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.ntt_inplace::<LAZY>(&mut a.0[i]));
    }

    pub fn intt_inplace<const LAZY: bool>(&self, a: &mut PolyRNS<u64>) {
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.intt_inplace::<LAZY>(&mut a.0[i]));
    }

    pub fn ntt<const LAZY: bool>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.ntt::<LAZY>(&a.0[i], &mut b.0[i]));
    }

    pub fn intt<const LAZY: bool>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.intt::<LAZY>(&a.0[i], &mut b.0[i]));
    }
}

impl RingRNS<'_, u64> {
    #[inline(always)]
    pub fn add<const REDUCE: REDUCEMOD>(
        &self,
        a: &PolyRNS<u64>,
        b: &PolyRNS<u64>,
        c: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        debug_assert!(
            c.level() >= self.level(),
            "c.level()={} < self.level()={}",
            c.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.add::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
    }

    #[inline(always)]
    pub fn add_inplace<const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.add_inplace::<REDUCE>(&a.0[i], &mut b.0[i]));
    }

    #[inline(always)]
    pub fn sub<const REDUCE: REDUCEMOD>(
        &self,
        a: &PolyRNS<u64>,
        b: &PolyRNS<u64>,
        c: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        debug_assert!(
            c.level() >= self.level(),
            "c.level()={} < self.level()={}",
            c.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.sub::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
    }

    #[inline(always)]
    pub fn sub_inplace<const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.sub_inplace::<REDUCE>(&a.0[i], &mut b.0[i]));
    }

    #[inline(always)]
    pub fn neg<const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.neg::<REDUCE>(&a.0[i], &mut b.0[i]));
    }

    #[inline(always)]
    pub fn neg_inplace<const REDUCE: REDUCEMOD>(&self, a: &mut PolyRNS<u64>) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.neg_inplace::<REDUCE>(&mut a.0[i]));
    }

    #[inline(always)]
    pub fn mul_montgomery_external<const REDUCE: REDUCEMOD>(
        &self,
        a: &PolyRNS<Montgomery<u64>>,
        b: &PolyRNS<u64>,
        c: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        debug_assert!(
            c.level() >= self.level(),
            "c.level()={} < self.level()={}",
            c.level(),
            self.level()
        );
        self.0.iter().enumerate().for_each(|(i, ring)| {
            ring.mul_montgomery_external::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i])
        });
    }

    #[inline(always)]
    pub fn mul_montgomery_external_inplace<const REDUCE: REDUCEMOD>(
        &self,
        a: &PolyRNS<Montgomery<u64>>,
        b: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        self.0.iter().enumerate().for_each(|(i, ring)| {
            ring.mul_montgomery_external_inplace::<REDUCE>(&a.0[i], &mut b.0[i])
        });
    }

    #[inline(always)]
    pub fn mul_scalar<const REDUCE: REDUCEMOD>(
        &self,
        a: &PolyRNS<u64>,
        b: &u64,
        c: &mut PolyRNS<u64>,
    ) {
        debug_assert!(
            a.level() >= self.level(),
            "a.level()={} < self.level()={}",
            a.level(),
            self.level()
        );
        debug_assert!(
            c.level() >= self.level(),
            "c.level()={} < self.level()={}",
            c.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.mul_scalar::<REDUCE>(&a.0[i], b, &mut c.0[i]));
    }

    #[inline(always)]
    pub fn mul_scalar_inplace<const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut PolyRNS<u64>) {
        debug_assert!(
            b.level() >= self.level(),
            "b.level()={} < self.level()={}",
            b.level(),
            self.level()
        );
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, ring)| ring.mul_scalar_inplace::<REDUCE>(a, &mut b.0[i]));
    }
}
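Every RingRNS operation above can work one limb at a time because of the CRT: for pairwise-coprime moduli q_i, the map x -> (x mod q_0, ..., x mod q_L) is a ring isomorphism, so adding or multiplying mod prod(q_i) is adding or multiplying each residue independently. A standalone numeric sketch:

fn main() {
    let moduli = [97u64, 193, 257];
    let q: u64 = moduli.iter().product();
    let (x, y) = (123_456u64 % q, 654_321u64 % q);
    for &qi in &moduli {
        // Residue-wise results match the big-modulus results, limb by limb.
        assert_eq!((x % qi + y % qi) % qi, ((x + y) % q) % qi);
        assert_eq!((x % qi) * (y % qi) % qi, (x * y % q) % qi);
    }
}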

View File

@@ -1,18 +1,22 @@
use crate::modulus::WordOps;
use crate::poly::{Poly, PolyRNS};
use crate::ring::{Ring, RingRNS};
use sampling::source::Source;

impl Ring<u64> {
    pub fn fill_uniform(&self, source: &mut Source, a: &mut Poly<u64>) {
        let max: u64 = self.modulus.q;
        let mask: u64 = max.mask();
        a.0.iter_mut()
            .for_each(|a| *a = source.next_u64n(max, mask));
    }
}

impl RingRNS<'_, u64> {
    pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS<u64>) {
        self.0
            .iter()
            .enumerate()
            .for_each(|(i, r)| r.fill_uniform(source, a.at_mut(i)));
    }
}
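A standalone sketch of the mask-and-reject sampling that next_u64n relies on, under the assumption that WordOps::mask() returns the all-ones mask just covering max's bit width (a xorshift stands in for the ChaCha8-backed Source so the sketch is self-contained):

fn next_u64n(next: &mut impl FnMut() -> u64, max: u64, mask: u64) -> u64 {
    loop {
        let v = next() & mask; // truncate to the mask: no modulo bias
        if v < max {
            return v; // accept; otherwise redraw
        }
    }
}

fn main() {
    let mut s: u64 = 0x9e3779b97f4a7c15;
    let mut rng = move || {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    };
    let max = 97u64;
    let mask = max.next_power_of_two() - 1; // 0x7f: each draw keeps 7 bits
    for _ in 0..1_000 {
        assert!(next_u64n(&mut rng, max, mask) < max);
    }
}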

View File

@@ -1,8 +1,8 @@
use math::poly::PolyRNS;
use math::ring::impl_u64::ring_rns::new_rings;
use math::ring::{Ring, RingRNS};
use num_bigint::BigInt;
use num_bigint::Sign;
use sampling::source::Source;

#[test]

@@ -17,7 +17,6 @@ fn rescaling_rns_u64() {
}

fn test_div_floor_by_last_modulus<const NTT: bool>(ring_rns: &RingRNS<u64>) {
    let seed: [u8; 32] = [0; 32];
    let mut source: Source = Source::new(seed);

@@ -29,8 +28,10 @@ fn test_div_floor_by_last_modulus<const NTT: bool>(ring_rns: &RingRNS<u64>) {
    ring_rns.fill_uniform(&mut source, &mut a);

    // Maps PolyRNS to [BigInt]
    let mut coeffs_a: Vec<BigInt> = (0..a.n()).map(|i| BigInt::from(i)).collect();
    ring_rns
        .at_level(a.level())
        .to_bigint_inplace(&a, 1, &mut coeffs_a);

    // Performs c = intt(ntt(a) / q_level)
    if NTT {

@@ -45,7 +46,9 @@ fn test_div_floor_by_last_modulus<const NTT: bool>(ring_rns: &RingRNS<u64>) {
    // Exports c to coeffs_c
    let mut coeffs_c = vec![BigInt::from(0); c.n()];
    ring_rns
        .at_level(c.level())
        .to_bigint_inplace(&c, 1, &mut coeffs_c);

    // Performs floor division on a
    let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q);

View File

@@ -1,6 +1,6 @@
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_core::RngCore;

const MAXF64: f64 = 9007199254740992.0;

@@ -10,7 +10,9 @@ pub struct Source {
impl Source {
    pub fn new(seed: [u8; 32]) -> Source {
        Source {
            source: ChaCha8Rng::from_seed(seed),
        }
    }

    pub fn new_seed(&mut self) -> [u8; 32] {