diff --git a/math/benches/ntt.rs b/math/benches/ntt.rs index c21328c..bb530d5 100644 --- a/math/benches/ntt.rs +++ b/math/benches/ntt.rs @@ -1,31 +1,25 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use math::{modulus::prime::Prime,dft::ntt::Table}; use math::dft::DFT; +use math::{dft::ntt::Table, modulus::prime::Prime}; fn forward_inplace(c: &mut Criterion) { fn runner(prime_instance: Prime, nth_root: u64) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; - for i in 0..a.len(){ + for i in 0..a.len() { a[i] = i as u64; } - Box::new(move || { - ntt_table.forward_inplace::(&mut a) - }) + Box::new(move || ntt_table.forward_inplace::(&mut a)) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("forward_inplace"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("forward_inplace"); for log_nth_root in 11..18 { - let prime_instance: Prime = Prime::::new(0x1fffffffffe00001, 1); - let runners = [ - ("prime", { - runner(prime_instance, 1<, nth_root: u64) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; - for i in 0..a.len(){ + for i in 0..a.len() { a[i] = i as u64; } - Box::new(move || { - ntt_table.forward_inplace_lazy(&mut a) - }) + Box::new(move || ntt_table.forward_inplace_lazy(&mut a)) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("forward_inplace_lazy"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("forward_inplace_lazy"); for log_nth_root in 11..17 { - let prime_instance: Prime = Prime::::new(0x1fffffffffe00001, 1); - let runners = [ - ("prime", { - runner(prime_instance, 1<, nth_root: u64) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; - for i in 0..a.len(){ + for i in 0..a.len() { a[i] = i as u64; } - Box::new(move || { - ntt_table.backward_inplace::(&mut a) - }) + Box::new(move || ntt_table.backward_inplace::(&mut a)) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("backward_inplace"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("backward_inplace"); for log_nth_root in 11..18 { - let prime_instance: Prime = Prime::::new(0x1fffffffffe00001, 1); - let runners = [ - ("prime", { - runner(prime_instance, 1<, nth_root: u64) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; - for i in 0..a.len(){ + for i in 0..a.len() { a[i] = i as u64; } - Box::new(move || { - ntt_table.backward_inplace::(&mut a) - }) + Box::new(move || ntt_table.backward_inplace::(&mut a)) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("backward_inplace_lazy"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("backward_inplace_lazy"); for log_nth_root in 11..17 { - let prime_instance: Prime = Prime::::new(0x1fffffffffe00001, 1); - let runners = [ - ("prime", { - runner(prime_instance, 1<) -> Box { - let mut p0: math::poly::Poly = r.new_poly(); let mut p1: math::poly::Poly = r.new_poly(); - for i in 0..p0.n(){ + for i in 0..p0.n() { p0.0[i] = i as u64; p1.0[i] = i as u64; } @@ -19,18 +18,14 @@ fn 
va_add_vb_into_vb(c: &mut Criterion) { }) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("va_add_vb_into_vb"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("va_add_vb_into_vb"); for log_n in 11..17 { - - let n: usize = 1< = Ring::::new(n, q_base, q_power); - let runners = [ - ("prime", { - runner(r) - }), - ]; + let runners = [("prime", { runner(r) })]; for (name, mut runner) in runners { let id = BenchmarkId::new(name, n); b.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); @@ -40,30 +35,26 @@ fn va_add_vb_into_vb(c: &mut Criterion) { fn va_mont_mul_vb_into_vb(c: &mut Criterion) { fn runner(r: Ring) -> Box { - let mut p0: math::poly::Poly> = r.new_poly(); let mut p1: math::poly::Poly = r.new_poly(); - for i in 0..p0.n(){ + for i in 0..p0.n() { p0.0[i] = r.modulus.montgomery.prepare::(i as u64); p1.0[i] = i as u64; } Box::new(move || { - r.modulus.va_mont_mul_vb_into_vb::(&p0.0, &mut p1.0); + r.modulus + .va_mont_mul_vb_into_vb::(&p0.0, &mut p1.0); }) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("va_mont_mul_vb_into_vb"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("va_mont_mul_vb_into_vb"); for log_n in 11..17 { - - let n: usize = 1< = Ring::::new(n, q_base, q_power); - let runners = [ - ("prime", { - runner(r) - }), - ]; + let runners = [("prime", { runner(r) })]; for (name, mut runner) in runners { let id = BenchmarkId::new(name, n); b.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); @@ -73,31 +64,27 @@ fn va_mont_mul_vb_into_vb(c: &mut Criterion) { fn va_mont_mul_vb_into_vc(c: &mut Criterion) { fn runner(r: Ring) -> Box { - let mut p0: math::poly::Poly> = r.new_poly(); let mut p1: math::poly::Poly = r.new_poly(); let mut p2: math::poly::Poly = r.new_poly(); - for i in 0..p0.n(){ + for i in 0..p0.n() { p0.0[i] = r.modulus.montgomery.prepare::(i as u64); p1.0[i] = i as u64; } Box::new(move || { - r.modulus.va_mont_mul_vb_into_vc::(&p0.0, & p1.0, &mut p2.0); + r.modulus + .va_mont_mul_vb_into_vc::(&p0.0, &p1.0, &mut p2.0); }) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("va_mont_mul_vb_into_vc"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("va_mont_mul_vb_into_vc"); for log_n in 11..17 { - - let n: usize = 1< = Ring::::new(n, q_base, q_power); - let runners = [ - ("prime", { - runner(r) - }), - ]; + let runners = [("prime", { runner(r) })]; for (name, mut runner) in runners { let id = BenchmarkId::new(name, n); b.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); @@ -105,5 +92,10 @@ fn va_mont_mul_vb_into_vc(c: &mut Criterion) { } } -criterion_group!(benches, va_add_vb_into_vb, va_mont_mul_vb_into_vb, va_mont_mul_vb_into_vc); +criterion_group!( + benches, + va_add_vb_into_vb, + va_mont_mul_vb_into_vb, + va_mont_mul_vb_into_vc +); criterion_main!(benches); diff --git a/math/benches/ring_rns.rs b/math/benches/ring_rns.rs index c506ca3..52279ef 100644 --- a/math/benches/ring_rns.rs +++ b/math/benches/ring_rns.rs @@ -1,33 +1,33 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use math::ring::{Ring, RingRNS}; -use math::ring::impl_u64::ring_rns::new_rings; use math::poly::PolyRNS; +use math::ring::impl_u64::ring_rns::new_rings; +use math::ring::{Ring, RingRNS}; fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { fn runner(r: 
RingRNS) -> Box { - let a: PolyRNS = r.new_polyrns(); let mut b: PolyRNS = r.new_polyrns(); let mut c: PolyRNS = r.new_polyrns(); - Box::new(move || { - r.div_floor_by_last_modulus::(&a, &mut b, &mut c) - }) + Box::new(move || r.div_floor_by_last_modulus::(&a, &mut b, &mut c)) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("div_floor_by_last_modulus_ntt_true"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("div_floor_by_last_modulus_ntt_true"); for log_n in 11..18 { - - let n = 1< = vec![0x1fffffffffe00001u64, 0x1fffffffffc80001u64, 0x1fffffffffb40001, 0x1fffffffff500001]; + let n = 1 << log_n; + let moduli: Vec = vec![ + 0x1fffffffffe00001u64, + 0x1fffffffffc80001u64, + 0x1fffffffffb40001, + 0x1fffffffff500001, + ]; let rings: Vec> = new_rings(n, moduli); let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); - let runners = [ - (format!("prime/n={}/level={}", n, ring_rns.level()), { - runner(ring_rns) - }), - ]; + let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), { + runner(ring_rns) + })]; for (name, mut runner) in runners { b.bench_with_input(name, &(), |b, _| b.iter(&mut runner)); diff --git a/math/benches/sampling.rs b/math/benches/sampling.rs index dae8997..a695c0b 100644 --- a/math/benches/sampling.rs +++ b/math/benches/sampling.rs @@ -1,14 +1,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use math::ring::{Ring, RingRNS}; -use math::ring::impl_u64::ring_rns::new_rings; use math::poly::PolyRNS; +use math::ring::impl_u64::ring_rns::new_rings; +use math::ring::{Ring, RingRNS}; use sampling::source::Source; fn fill_uniform(c: &mut Criterion) { fn runner(r: RingRNS) -> Box { - let mut a: PolyRNS = r.new_polyrns(); - let seed: [u8; 32] = [0;32]; + let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); Box::new(move || { @@ -16,19 +15,22 @@ fn fill_uniform(c: &mut Criterion) { }) } - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("fill_uniform"); + let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = + c.benchmark_group("fill_uniform"); for log_n in 11..18 { - - let n = 1< = vec![0x1fffffffffe00001u64, 0x1fffffffffc80001u64, 0x1fffffffffb40001, 0x1fffffffff500001]; + let n = 1 << log_n; + let moduli: Vec = vec![ + 0x1fffffffffe00001u64, + 0x1fffffffffc80001u64, + 0x1fffffffffb40001, + 0x1fffffffff500001, + ]; let rings: Vec> = new_rings(n, moduli); let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); - let runners = [ - (format!("prime/n={}/level={}", n, ring_rns.level()), { - runner(ring_rns) - }), - ]; + let runners = [(format!("prime/n={}/level={}", n, ring_rns.level()), { + runner(ring_rns) + })]; for (name, mut runner) in runners { b.bench_with_input(name, &(), |b, _| b.iter(&mut runner)); diff --git a/math/examples/main.rs b/math/examples/main.rs index 8b932af..bdd9b74 100644 --- a/math/examples/main.rs +++ b/math/examples/main.rs @@ -1,13 +1,13 @@ -use math::ring::Ring; -use math::modulus::prime::Prime; use math::dft::ntt::Table; +use math::modulus::prime::Prime; +use math::ring::Ring; fn main() { // Example usage of `Prime` - let q_base: u64 = 65537; // Example prime base - let q_power: usize = 1; // Example power + let q_base: u64 = 65537; // Example prime base + let q_power: usize = 1; // Example power let prime_instance: Prime = Prime::::new(q_base, q_power); - + // Display the fields of `Prime` to verify println!("Prime instance 
created:"); println!("q: {}", prime_instance.q()); @@ -15,13 +15,13 @@ fn main() { println!("q_power: {}", prime_instance.q_power()); let n: u64 = 32; - let nth_root: u64 = n<<1; + let nth_root: u64 = n << 1; let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; - for i in 0..a.len(){ + for i in 0..a.len() { a[i] = i as u64; } @@ -35,17 +35,16 @@ fn main() { println!("{:?}", a); - let r : Ring = Ring::::new(n as usize, q_base, q_power); + let r: Ring = Ring::::new(n as usize, q_base, q_power); let mut p0: math::poly::Poly = r.new_poly(); let mut p1: math::poly::Poly = r.new_poly(); - for i in 0..p0.n(){ + for i in 0..p0.n() { p0.0[i] = i as u64 } - r.automorphism(p0, (2*r.n-1) as u64, &mut p1); + r.automorphism(p0, (2 * r.n - 1) as u64, &mut p1); println!("{:?}", p1); - -} \ No newline at end of file +} diff --git a/math/src/dft.rs b/math/src/dft.rs index 4c414c0..bc81149 100644 --- a/math/src/dft.rs +++ b/math/src/dft.rs @@ -5,4 +5,4 @@ pub trait DFT { fn forward_inplace(&self, x: &mut [O]); fn backward_inplace_lazy(&self, x: &mut [O]); fn backward_inplace(&self, x: &mut [O]); -} \ No newline at end of file +} diff --git a/math/src/dft/ntt.rs b/math/src/dft/ntt.rs index ff9ec6c..db0dc67 100644 --- a/math/src/dft/ntt.rs +++ b/math/src/dft/ntt.rs @@ -1,127 +1,159 @@ -use crate::modulus::montgomery::Montgomery; +use crate::dft::DFT; use crate::modulus::barrett::Barrett; +use crate::modulus::montgomery::Montgomery; use crate::modulus::prime::Prime; use crate::modulus::ReduceOnce; use crate::modulus::WordOps; -use crate::modulus::{NONE, ONCE, BARRETT}; -use crate::dft::DFT; +use crate::modulus::{BARRETT, NONE, ONCE}; use itertools::izip; #[allow(dead_code)] -pub struct Table{ - prime:Prime, +pub struct Table { + prime: Prime, psi: O, - psi_forward_rev:Vec>, + psi_forward_rev: Vec>, psi_backward_rev: Vec>, - q:O, - two_q:O, - four_q:O, + q: O, + two_q: O, + four_q: O, } -impl Table< u64> { - pub fn new(prime: Prime, nth_root: u64)->Self{ - - assert!(nth_root&(nth_root-1) == 0, "invalid argument: nth_root = {} is not a power of two", nth_root); +impl Table { + pub fn new(prime: Prime, nth_root: u64) -> Self { + assert!( + nth_root & (nth_root - 1) == 0, + "invalid argument: nth_root = {} is not a power of two", + nth_root + ); let psi: u64 = prime.primitive_nth_root(nth_root); let psi_mont: Montgomery = prime.montgomery.prepare::(psi); - let psi_inv_mont: Montgomery = prime.montgomery.pow(psi_mont, prime.phi-1); - + let psi_inv_mont: Montgomery = prime.montgomery.pow(psi_mont, prime.phi - 1); + let mut psi_forward_rev: Vec> = vec![Barrett(0, 0); (nth_root >> 1) as usize]; let mut psi_backward_rev: Vec> = vec![Barrett(0, 0); (nth_root >> 1) as usize]; psi_forward_rev[0] = prime.barrett.prepare(1); psi_backward_rev[0] = prime.barrett.prepare(1); - let log_nth_root_half: u32 = (nth_root>>1).log2() as _; + let log_nth_root_half: u32 = (nth_root >> 1).log2() as _; let mut powers_forward: u64 = 1u64; let mut powers_backward: u64 = 1u64; - for i in 1..(nth_root>>1) as usize{ - + for i in 1..(nth_root >> 1) as usize { let i_rev: usize = i.reverse_bits_msb(log_nth_root_half); - prime.montgomery.mul_external_assign::(psi_mont, &mut powers_forward); - prime.montgomery.mul_external_assign::(psi_inv_mont, &mut powers_backward); + prime + .montgomery + .mul_external_assign::(psi_mont, &mut powers_forward); + prime + .montgomery + .mul_external_assign::(psi_inv_mont, &mut powers_backward); psi_forward_rev[i_rev] = 
prime.barrett.prepare(powers_forward); - psi_backward_rev[i_rev] = prime.barrett.prepare(powers_backward); + psi_backward_rev[i_rev] = prime.barrett.prepare(powers_backward); } let q: u64 = prime.q(); - Self{ - prime: prime, - psi:psi, - psi_forward_rev: psi_forward_rev, + Self { + prime: prime, + psi: psi, + psi_forward_rev: psi_forward_rev, psi_backward_rev: psi_backward_rev, - q:q, - two_q:q<<1, - four_q:q<<2, + q: q, + two_q: q << 1, + four_q: q << 2, } } } - -impl DFT for Table{ - - fn forward_inplace(&self, a: &mut [u64]){ +impl DFT for Table { + fn forward_inplace(&self, a: &mut [u64]) { self.forward_inplace::(a) } - fn forward_inplace_lazy(&self, a: &mut [u64]){ + fn forward_inplace_lazy(&self, a: &mut [u64]) { self.forward_inplace::(a) } - fn backward_inplace(&self, a: &mut [u64]){ + fn backward_inplace(&self, a: &mut [u64]) { self.backward_inplace::(a) } - fn backward_inplace_lazy(&self, a: &mut [u64]){ + fn backward_inplace_lazy(&self, a: &mut [u64]) { self.backward_inplace::(a) } } -impl Table{ - - pub fn forward_inplace(&self, a: &mut [u64]){ +impl Table { + pub fn forward_inplace(&self, a: &mut [u64]) { self.forward_inplace_core::(a); } - pub fn forward_inplace_core(&self, a: &mut [u64]) { - + pub fn forward_inplace_core( + &self, + a: &mut [u64], + ) { let n: usize = a.len(); - assert!(n & n-1 == 0, "invalid x.len()= {} must be a power of two", n); - let log_n: u32 = usize::BITS - ((n as usize)-1).leading_zeros(); + assert!( + n & n - 1 == 0, + "invalid x.len()= {} must be a power of two", + n + ); + let log_n: u32 = usize::BITS - ((n as usize) - 1).leading_zeros(); let start: u32 = SKIPSTART as u32; let end: u32 = log_n - (SKIPEND as u32); for layer in start..end { let (m, size) = (1 << layer, 1 << (log_n - layer - 1)); - let t: usize = 2*size; + let t: usize = 2 * size; if layer == log_n - 1 { - if LAZY{ - izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| { - let (a, b) = a.split_at_mut(size); - self.dit_inplace::(&mut a[0], &mut b[0], *psi); - debug_assert!(a[0] < self.two_q, "forward_inplace_core:: output {} > {} (2q-1)", a[0], self.two_q-1); - debug_assert!(b[0] < self.two_q, "forward_inplace_core:: output {} > {} (2q-1)", b[0], self.two_q-1); - }); - }else{ - izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| { - let (a, b) = a.split_at_mut(size); - self.dit_inplace::(&mut a[0], &mut b[0], *psi); - self.prime.barrett.reduce_assign::(&mut a[0]); - self.prime.barrett.reduce_assign::(&mut b[0]); - debug_assert!(a[0] < self.q, "forward_inplace_core:: output {} > {} (q-1)", a[0], self.q-1); - debug_assert!(b[0] < self.q, "forward_inplace_core:: output {} > {} (q-1)", b[0], self.q-1); - }); + if LAZY { + izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each( + |(a, psi)| { + let (a, b) = a.split_at_mut(size); + self.dit_inplace::(&mut a[0], &mut b[0], *psi); + debug_assert!( + a[0] < self.two_q, + "forward_inplace_core:: output {} > {} (2q-1)", + a[0], + self.two_q - 1 + ); + debug_assert!( + b[0] < self.two_q, + "forward_inplace_core:: output {} > {} (2q-1)", + b[0], + self.two_q - 1 + ); + }, + ); + } else { + izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each( + |(a, psi)| { + let (a, b) = a.split_at_mut(size); + self.dit_inplace::(&mut a[0], &mut b[0], *psi); + self.prime.barrett.reduce_assign::(&mut a[0]); + self.prime.barrett.reduce_assign::(&mut b[0]); + debug_assert!( + a[0] < self.q, + "forward_inplace_core:: output {} > {} (q-1)", + a[0], + self.q - 1 + ); + debug_assert!( + b[0] < 
self.q, + "forward_inplace_core:: output {} > {} (q-1)", + b[0], + self.q - 1 + ); + }, + ); } - - } else if t >= 16{ + } else if t >= 16 { izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| { let (a, b) = a.split_at_mut(size); izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| { @@ -135,7 +167,7 @@ impl Table{ self.dit_inplace::(&mut a[7], &mut b[7], *psi); }); }); - }else{ + } else { izip!(a.chunks_exact_mut(t), &self.psi_forward_rev[m..]).for_each(|(a, psi)| { let (a, b) = a.split_at_mut(size); izip!(a, b).for_each(|(a, b)| self.dit_inplace::(a, b, *psi)); @@ -150,7 +182,7 @@ impl Table{ debug_assert!(*b < self.four_q, "b:{} q:{}", b, self.four_q); a.reduce_once_assign(self.two_q); let bt: u64 = self.prime.barrett.mul_external::(t, *b); - *b = *a + self.two_q-bt; + *b = *a + self.two_q - bt; *a += bt; if !LAZY { a.reduce_once_assign(self.two_q); @@ -158,58 +190,63 @@ impl Table{ } } - pub fn backward_inplace(&self, a: &mut [u64]){ + pub fn backward_inplace(&self, a: &mut [u64]) { self.backward_inplace_core::(a); } - pub fn backward_inplace_core(&self, a: &mut [u64]) { + pub fn backward_inplace_core( + &self, + a: &mut [u64], + ) { let n: usize = a.len(); - assert!(n & n-1 == 0, "invalid x.len()= {} must be a power of two", n); - let log_n = usize::BITS - ((n as usize)-1).leading_zeros(); + assert!( + n & n - 1 == 0, + "invalid x.len()= {} must be a power of two", + n + ); + let log_n = usize::BITS - ((n as usize) - 1).leading_zeros(); let start: u32 = SKIPEND as u32; let end: u32 = log_n - (SKIPSTART as u32); for layer in (start..end).rev() { let (m, size) = (1 << layer, 1 << (log_n - layer - 1)); - let t: usize = 2*size; + let t: usize = 2 * size; if layer == 0 { - let n_inv: Barrett = self.prime.barrett.prepare(self.prime.inv(n as u64)); - let psi: Barrett = self.prime.barrett.prepare(self.prime.barrett.mul_external::(n_inv, self.psi_backward_rev[1].0)); - - izip!(a.chunks_exact_mut(2 * size)).for_each( - |a| { - let (a, b) = a.split_at_mut(size); - izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| { - self.dif_last_inplace::(&mut a[0], &mut b[0], psi, n_inv); - self.dif_last_inplace::(&mut a[1], &mut b[1], psi, n_inv); - self.dif_last_inplace::(&mut a[2], &mut b[2], psi, n_inv); - self.dif_last_inplace::(&mut a[3], &mut b[3], psi, n_inv); - self.dif_last_inplace::(&mut a[4], &mut b[4], psi, n_inv); - self.dif_last_inplace::(&mut a[5], &mut b[5], psi, n_inv); - self.dif_last_inplace::(&mut a[6], &mut b[6], psi, n_inv); - self.dif_last_inplace::(&mut a[7], &mut b[7], psi, n_inv); - }); - }, + let psi: Barrett = self.prime.barrett.prepare( + self.prime + .barrett + .mul_external::(n_inv, self.psi_backward_rev[1].0), ); - } else if t >= 16{ - izip!(a.chunks_exact_mut(t), &self.psi_backward_rev[m..]).for_each( - |(a, psi)| { - let (a, b) = a.split_at_mut(size); - izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| { - self.dif_inplace::(&mut a[0], &mut b[0], *psi); - self.dif_inplace::(&mut a[1], &mut b[1], *psi); - self.dif_inplace::(&mut a[2], &mut b[2], *psi); - self.dif_inplace::(&mut a[3], &mut b[3], *psi); - self.dif_inplace::(&mut a[4], &mut b[4], *psi); - self.dif_inplace::(&mut a[5], &mut b[5], *psi); - self.dif_inplace::(&mut a[6], &mut b[6], *psi); - self.dif_inplace::(&mut a[7], &mut b[7], *psi); - }); - }, - ); + izip!(a.chunks_exact_mut(2 * size)).for_each(|a| { + let (a, b) = a.split_at_mut(size); + izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| { + 
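+                        // Final backward layer, unrolled in blocks of 8 butterflies;
+                        // dif_last_inplace also folds the n^-1 normalization into the
+                        // twiddle multiplications, so no separate scaling pass is needed.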
self.dif_last_inplace::(&mut a[0], &mut b[0], psi, n_inv); + self.dif_last_inplace::(&mut a[1], &mut b[1], psi, n_inv); + self.dif_last_inplace::(&mut a[2], &mut b[2], psi, n_inv); + self.dif_last_inplace::(&mut a[3], &mut b[3], psi, n_inv); + self.dif_last_inplace::(&mut a[4], &mut b[4], psi, n_inv); + self.dif_last_inplace::(&mut a[5], &mut b[5], psi, n_inv); + self.dif_last_inplace::(&mut a[6], &mut b[6], psi, n_inv); + self.dif_last_inplace::(&mut a[7], &mut b[7], psi, n_inv); + }); + }); + } else if t >= 16 { + izip!(a.chunks_exact_mut(t), &self.psi_backward_rev[m..]).for_each(|(a, psi)| { + let (a, b) = a.split_at_mut(size); + izip!(a.chunks_exact_mut(8), b.chunks_exact_mut(8)).for_each(|(a, b)| { + self.dif_inplace::(&mut a[0], &mut b[0], *psi); + self.dif_inplace::(&mut a[1], &mut b[1], *psi); + self.dif_inplace::(&mut a[2], &mut b[2], *psi); + self.dif_inplace::(&mut a[3], &mut b[3], *psi); + self.dif_inplace::(&mut a[4], &mut b[4], *psi); + self.dif_inplace::(&mut a[5], &mut b[5], *psi); + self.dif_inplace::(&mut a[6], &mut b[6], *psi); + self.dif_inplace::(&mut a[7], &mut b[7], *psi); + }); + }); } else { izip!(a.chunks_exact_mut(2 * size), &self.psi_backward_rev[m..]).for_each( |(a, psi)| { @@ -225,7 +262,10 @@ impl Table{ fn dif_inplace(&self, a: &mut u64, b: &mut u64, t: Barrett) { debug_assert!(*a < self.two_q, "a:{} q:{}", a, self.two_q); debug_assert!(*b < self.two_q, "b:{} q:{}", b, self.two_q); - let d: u64 = self.prime.barrett.mul_external::(t, *a + self.two_q - *b); + let d: u64 = self + .prime + .barrett + .mul_external::(t, *a + self.two_q - *b); *a = *a + *b; a.reduce_once_assign(self.two_q); *b = d; @@ -235,15 +275,27 @@ impl Table{ } } - fn dif_last_inplace(&self, a: &mut u64, b: &mut u64, psi: Barrett, n_inv: Barrett){ + fn dif_last_inplace( + &self, + a: &mut u64, + b: &mut u64, + psi: Barrett, + n_inv: Barrett, + ) { debug_assert!(*a < self.two_q); debug_assert!(*b < self.two_q); - if LAZY{ - let d: u64 = self.prime.barrett.mul_external::(psi, *a + self.two_q - *b); + if LAZY { + let d: u64 = self + .prime + .barrett + .mul_external::(psi, *a + self.two_q - *b); *a = self.prime.barrett.mul_external::(n_inv, *a + *b); *b = d; - }else{ - let d: u64 = self.prime.barrett.mul_external::(psi, *a + self.two_q - *b); + } else { + let d: u64 = self + .prime + .barrett + .mul_external::(psi, *a + self.two_q - *b); *a = self.prime.barrett.mul_external::(n_inv, *a + *b); *b = d; } @@ -260,10 +312,10 @@ mod tests { let q_power: usize = 1; let prime_instance: Prime = Prime::::new(q_base, q_power); let n: u64 = 32; - let two_nth_root: u64 = n<<1; + let two_nth_root: u64 = n << 1; let ntt_table: Table = Table::::new(prime_instance, two_nth_root); let mut a: Vec = vec![0; n as usize]; - for i in 0..a.len(){ + for i in 0..a.len() { a[i] = i as u64; } @@ -272,4 +324,4 @@ mod tests { ntt_table.backward_inplace::(&mut a); assert!(a == b); } -} \ No newline at end of file +} diff --git a/math/src/lib.rs b/math/src/lib.rs index 7fdcca5..f8e16be 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -1,22 +1,20 @@ #![feature(bigint_helper_methods)] #![feature(test)] -pub mod modulus; pub mod dft; -pub mod ring; +pub mod modulus; pub mod poly; +pub mod ring; pub mod scalar; -pub const CHUNK: usize= 8; +pub const CHUNK: usize = 8; + +pub mod macros { -pub mod macros{ - #[macro_export] macro_rules! 
apply_v { - ($self:expr, $f:expr, $a:expr, $CHUNK:expr) => { - - match CHUNK{ + match CHUNK { 8 => { $a.chunks_exact_mut(8).for_each(|a| { $f(&$self, &mut a[0]); @@ -30,12 +28,12 @@ pub mod macros{ }); let n: usize = $a.len(); - let m = n - (n&(CHUNK-1)); + let m = n - (n & (CHUNK - 1)); $a[m..].iter_mut().for_each(|a| { $f(&$self, a); }); - }, - _=>{ + } + _ => { $a.iter_mut().for_each(|a| { $f(&$self, a); }); @@ -46,16 +44,21 @@ pub mod macros{ #[macro_export] macro_rules! apply_vv { - ($self:expr, $f:expr, $a:expr, $b:expr, $CHUNK:expr) => { - let n: usize = $a.len(); - debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n); - debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + debug_assert!( + $b.len() == n, + "invalid argument b: b.len() = {} != a.len() = {}", + $b.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); - match CHUNK{ + match CHUNK { 8 => { - izip!($a.chunks_exact(8), $b.chunks_exact_mut(8)).for_each(|(a, b)| { $f(&$self, &a[0], &mut b[0]); $f(&$self, &a[1], &mut b[1]); @@ -67,12 +70,12 @@ pub mod macros{ $f(&$self, &a[7], &mut b[7]); }); - let m = n - (n&(CHUNK-1)); + let m = n - (n & (CHUNK - 1)); izip!($a[m..].iter(), $b[m..].iter_mut()).for_each(|(a, b)| { $f(&$self, a, b); }); - }, - _=>{ + } + _ => { izip!($a.iter(), $b.iter_mut()).for_each(|(a, b)| { $f(&$self, a, b); }); @@ -83,18 +86,33 @@ pub mod macros{ #[macro_export] macro_rules! apply_vvv { - ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $CHUNK:expr) => { - let n: usize = $a.len(); - debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n); - debug_assert!($c.len() == n, "invalid argument c: b.len() = {} != a.len() = {}", $c.len(), n); - debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + debug_assert!( + $b.len() == n, + "invalid argument b: b.len() = {} != a.len() = {}", + $b.len(), + n + ); + debug_assert!( + $c.len() == n, + "invalid argument c: c.len() = {} != a.len() = {}", + $c.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); - match CHUNK{ + match CHUNK { 8 => { - - izip!($a.chunks_exact(8), $b.chunks_exact(8), $c.chunks_exact_mut(8)).for_each(|(a, b, c)| { + izip!( + $a.chunks_exact(8), + $b.chunks_exact(8), + $c.chunks_exact_mut(8) + ) + .for_each(|(a, b, c)| { $f(&$self, &a[0], &b[0], &mut c[0]); $f(&$self, &a[1], &b[1], &mut c[1]); $f(&$self, &a[2], &b[2], &mut c[2]); @@ -105,12 +123,14 @@ pub mod macros{ $f(&$self, &a[7], &b[7], &mut c[7]); }); - let m = n - (n&7); - izip!($a[m..].iter(), $b[m..].iter(), $c[m..].iter_mut()).for_each(|(a, b, c)| { - $f(&$self, a, b, c); - }); - }, - _=>{ + let m = n - (n & 7); + izip!($a[m..].iter(), $b[m..].iter(), $c[m..].iter_mut()).for_each( + |(a, b, c)| { + $f(&$self, a, b, c); + }, + ); + } + _ => { izip!($a.iter(), $b.iter(), $c.iter_mut()).for_each(|(a, b, c)| { $f(&$self, a, b, c); }); @@ -121,16 +141,16 @@ pub mod macros{ #[macro_export] macro_rules!
apply_sv { - ($self:expr, $f:expr, $a:expr, $b:expr, $CHUNK:expr) => { - let n: usize = $b.len(); - debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); - match CHUNK{ + match CHUNK { 8 => { - izip!($b.chunks_exact_mut(8)).for_each(|b| { $f(&$self, $a, &mut b[0]); $f(&$self, $a, &mut b[1]); @@ -142,12 +162,12 @@ pub mod macros{ $f(&$self, $a, &mut b[7]); }); - let m = n - (n&7); + let m = n - (n & 7); izip!($b[m..].iter_mut()).for_each(|b| { $f(&$self, $a, b); }); - }, - _=>{ + } + _ => { izip!($b.iter_mut()).for_each(|b| { $f(&$self, $a, b); }); @@ -158,16 +178,21 @@ pub mod macros{ #[macro_export] macro_rules! apply_svv { - ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $CHUNK:expr) => { - let n: usize = $b.len(); - debug_assert!($c.len() == n, "invalid argument c: c.len() = {} != b.len() = {}", $c.len(), n); - debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + debug_assert!( + $c.len() == n, + "invalid argument c: c.len() = {} != b.len() = {}", + $c.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); - match CHUNK{ + match CHUNK { 8 => { - izip!($b.chunks_exact(8), $c.chunks_exact_mut(8)).for_each(|(b, c)| { $f(&$self, $a, &b[0], &mut c[0]); $f(&$self, $a, &b[1], &mut c[1]); @@ -179,12 +204,12 @@ pub mod macros{ $f(&$self, $a, &b[7], &mut c[7]); }); - let m = n - (n&7); + let m = n - (n & 7); izip!($b[m..].iter(), $c[m..].iter_mut()).for_each(|(b, c)| { $f(&$self, $a, b, c); }); - }, - _=>{ + } + _ => { izip!($b.iter(), $c.iter_mut()).for_each(|(b, c)| { $f(&$self, $a, b, c); }); @@ -195,18 +220,33 @@ pub mod macros{ #[macro_export] macro_rules! apply_vvsv { - ($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $d:expr, $CHUNK:expr) => { - let n: usize = $a.len(); - debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n); - debug_assert!($d.len() == n, "invalid argument d: d.len() = {} != a.len() = {}", $d.len(), n); - debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + debug_assert!( + $b.len() == n, + "invalid argument b: b.len() = {} != a.len() = {}", + $b.len(), + n + ); + debug_assert!( + $d.len() == n, + "invalid argument d: d.len() = {} != a.len() = {}", + $d.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); - match CHUNK{ + match CHUNK { 8 => { - - izip!($a.chunks_exact(8), $b.chunks_exact(8), $d.chunks_exact_mut(8)).for_each(|(a, b, d)| { + izip!( + $a.chunks_exact(8), + $b.chunks_exact(8), + $d.chunks_exact_mut(8) + ) + .for_each(|(a, b, d)| { $f(&$self, &a[0], &b[0], $c, &mut d[0]); $f(&$self, &a[1], &b[1], $c, &mut d[1]); $f(&$self, &a[2], &b[2], $c, &mut d[2]); @@ -217,12 +257,14 @@ pub mod macros{ $f(&$self, &a[7], &b[7], $c, &mut d[7]); }); - let m = n - (n&7); - izip!($a[m..].iter(), $b[m..].iter(), $d[m..].iter_mut()).for_each(|(a, b, d)| { - $f(&$self, a, b, $c, d); - }); - }, - _=>{ + let m = n - (n & 7); + izip!($a[m..].iter(), $b[m..].iter(), $d[m..].iter_mut()).for_each( + |(a, b, d)| { + $f(&$self, a, b, $c, d); + }, + ); + } + _ => { izip!($a.iter(), $b.iter(), $d.iter_mut()).for_each(|(a, b, d)| { $f(&$self, a, b, $c, d); }); @@ -233,16 +275,21 @@ pub mod macros{ #[macro_export] macro_rules! 
apply_vsv { - ($self:expr, $f:expr, $a:expr, $c:expr, $b:expr, $CHUNK:expr) => { - let n: usize = $a.len(); - debug_assert!($b.len() == n, "invalid argument b: b.len() = {} != a.len() = {}", $b.len(), n); - debug_assert!(CHUNK&(CHUNK-1) == 0, "invalid CHUNK const: not a power of two"); + debug_assert!( + $b.len() == n, + "invalid argument b: b.len() = {} != a.len() = {}", + $b.len(), + n + ); + debug_assert!( + CHUNK & (CHUNK - 1) == 0, + "invalid CHUNK const: not a power of two" + ); - match CHUNK{ + match CHUNK { 8 => { - izip!($a.chunks_exact(8), $b.chunks_exact_mut(8)).for_each(|(a, b)| { $f(&$self, &a[0], $c, &mut b[0]); $f(&$self, &a[1], $c, &mut b[1]); @@ -254,12 +301,12 @@ pub mod macros{ $f(&$self, &a[7], $c, &mut b[7]); }); - let m = n - (n&7); + let m = n - (n & 7); izip!($a[m..].iter(), $b[m..].iter_mut()).for_each(|(a, b)| { $f(&$self, a, $c, b); }); - }, - _=>{ + } + _ => { izip!($a.iter(), $b.iter_mut()).for_each(|(a, b)| { $f(&$self, a, $c, b); }); @@ -267,4 +314,4 @@ pub mod macros{ } }; } -} \ No newline at end of file +} diff --git a/math/src/modulus.rs b/math/src/modulus.rs index ae46d51..305eee4 100644 --- a/math/src/modulus.rs +++ b/math/src/modulus.rs @@ -1,7 +1,7 @@ -pub mod prime; pub mod barrett; -pub mod montgomery; pub mod impl_u64; +pub mod montgomery; +pub mod prime; pub type REDUCEMOD = u8; @@ -12,159 +12,234 @@ pub const FOURTIMES: REDUCEMOD = 3; pub const BARRETT: REDUCEMOD = 4; pub const BARRETTLAZY: REDUCEMOD = 5; -pub trait WordOps{ +pub trait WordOps { fn log2(self) -> O; - fn reverse_bits_msb(self, n:u32) -> O; + fn reverse_bits_msb(self, n: u32) -> O; fn mask(self) -> O; } -impl WordOps for u64{ +impl WordOps for u64 { #[inline(always)] - fn log2(self) -> u64{ - (u64::BITS - (self-1).leading_zeros()) as _ + fn log2(self) -> u64 { + (u64::BITS - (self - 1).leading_zeros()) as _ } #[inline(always)] - fn reverse_bits_msb(self, n: u32) -> u64{ + fn reverse_bits_msb(self, n: u32) -> u64 { self.reverse_bits() >> (usize::BITS - n) } #[inline(always)] - fn mask(self) -> u64{ - (1< u64 { + (1 << self.log2()) - 1 } } -impl WordOps for usize{ +impl WordOps for usize { #[inline(always)] - fn log2(self) -> usize{ - (usize::BITS - (self-1).leading_zeros()) as _ + fn log2(self) -> usize { + (usize::BITS - (self - 1).leading_zeros()) as _ } #[inline(always)] - fn reverse_bits_msb(self, n: u32) -> usize{ + fn reverse_bits_msb(self, n: u32) -> usize { self.reverse_bits() >> (usize::BITS - n) } #[inline(always)] - fn mask(self) -> usize{ - (1< usize { + (1 << self.log2()) - 1 } } -pub trait ReduceOnce{ +pub trait ReduceOnce { /// Assigns self-q to self if self >= q in constant time. /// User must ensure that 2q fits in O. fn reduce_once_constant_time_assign(&mut self, q: O); /// Returns self-q if self >= q else self in constant time. /// /// User must ensure that 2q fits in O. - fn reduce_once_constant_time(&self, q:O) -> O; + fn reduce_once_constant_time(&self, q: O) -> O; /// Assigns self-q to self if self >= q. /// /// User must ensure that 2q fits in O. fn reduce_once_assign(&mut self, q: O); /// Returns self-q if self >= q else self. /// /// User must ensure that 2q fits in O. - fn reduce_once(&self, q:O) -> O; + fn reduce_once(&self, q: O) -> O; } -pub trait ScalarOperations{ - +pub trait ScalarOperations { // Applies a parameterized modular reduction. - fn sa_reduce_into_sa(&self, x: &mut O); + fn sa_reduce_into_sa(&self, x: &mut O); // Assigns a + b to c. 
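+    // (Naming convention throughout this trait: a leading `s` marks a scalar
+    // operand and `v` a slice, while the `into_*` suffix names the destination;
+    // e.g. `sa_add_sb_into_sc` computes c = a + b on scalars.)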
- fn sa_add_sb_into_sc(&self, a: &O, b:&O, c: &mut O); + fn sa_add_sb_into_sc(&self, a: &O, b: &O, c: &mut O); // Assigns a + b to b. - fn sa_add_sb_into_sb(&self, a: &O, b: &mut O); + fn sa_add_sb_into_sb(&self, a: &O, b: &mut O); // Assigns a - b to c. - fn sa_sub_sb_into_sc(&self, a: &O, b:&O, c: &mut O); + fn sa_sub_sb_into_sc(&self, a: &O, b: &O, c: &mut O); // Assigns a - b to b. - fn sa_sub_sb_into_sb(&self, a: &O, b: &mut O); + fn sa_sub_sb_into_sb(&self, a: &O, b: &mut O); // Assigns -a to a. - fn sa_neg_into_sa(&self, a:&mut O); + fn sa_neg_into_sa(&self, a: &mut O); // Assigns -a to b. - fn sa_neg_into_sb(&self, a: &O, b:&mut O); + fn sa_neg_into_sb(&self, a: &O, b: &mut O); // Assigns a * 2^64 to b. - fn sa_prep_mont_into_sb(&self, a: &O, b: &mut montgomery::Montgomery); + fn sa_prep_mont_into_sb( + &self, + a: &O, + b: &mut montgomery::Montgomery, + ); // Assigns a * b to c. - fn sa_mont_mul_sb_into_sc(&self, a:&montgomery::Montgomery, b:&O, c: &mut O); + fn sa_mont_mul_sb_into_sc( + &self, + a: &montgomery::Montgomery, + b: &O, + c: &mut O, + ); // Assigns a * b to b. - fn sa_mont_mul_sb_into_sb(&self, a:&montgomery::Montgomery, b:&mut O); + fn sa_mont_mul_sb_into_sb( + &self, + a: &montgomery::Montgomery, + b: &mut O, + ); // Assigns a * b to c. - fn sa_barrett_mul_sb_into_sc(&self, a: &barrett::Barrett, b:&O, c: &mut O); + fn sa_barrett_mul_sb_into_sc( + &self, + a: &barrett::Barrett, + b: &O, + c: &mut O, + ); // Assigns a * b to b. - fn sa_barrett_mul_sb_into_sb(&self, a:&barrett::Barrett, b:&mut O); + fn sa_barrett_mul_sb_into_sb( + &self, + a: &barrett::Barrett, + b: &mut O, + ); // Assigns (a + 2q - b) * c to d. - fn sa_sub_sb_mul_sc_into_sd(&self, a: &O, b: &O, c: &barrett::Barrett, d: &mut O); + fn sa_sub_sb_mul_sc_into_sd( + &self, + a: &O, + b: &O, + c: &barrett::Barrett, + d: &mut O, + ); // Assigns (a + 2q - b) * c to b. - fn sa_sub_sb_mul_sc_into_sb(&self, a: &u64, c: &barrett::Barrett, b: &mut u64); + fn sa_sub_sb_mul_sc_into_sb( + &self, + a: &u64, + c: &barrett::Barrett, + b: &mut u64, + ); } -pub trait VectorOperations{ - +pub trait VectorOperations { // Applies a parameterized modular reduction. - fn va_reduce_into_va(&self, x: &mut [O]); + fn va_reduce_into_va(&self, x: &mut [O]); // ADD // Assigns a[i] + b[i] to c[i] - fn va_add_vb_into_vc(&self, a: &[O], b:&[O], c: &mut [O]); + fn va_add_vb_into_vc( + &self, + a: &[O], + b: &[O], + c: &mut [O], + ); // Assigns a[i] + b[i] to b[i] - fn va_add_vb_into_vb(&self, a: &[O], b: &mut [O]); + fn va_add_vb_into_vb(&self, a: &[O], b: &mut [O]); // Assigns a[i] + b to c[i] - fn va_add_sb_into_vc(&self, a: &[O], b:&O, c:&mut [O]); + fn va_add_sb_into_vc( + &self, + a: &[O], + b: &O, + c: &mut [O], + ); // Assigns b[i] + a to b[i] - fn sa_add_vb_into_vb(&self, a:&O, b:&mut [O]); + fn sa_add_vb_into_vb(&self, a: &O, b: &mut [O]); // SUB // Assigns a[i] - b[i] to b[i] - fn va_sub_vb_into_vb(&self, a: &[O], b: &mut [O]); + fn va_sub_vb_into_vb(&self, a: &[O], b: &mut [O]); // Assigns a[i] - b[i] to c[i] - fn va_sub_vb_into_vc(&self, a: &[O], b:&[O], c: &mut [O]); + fn va_sub_vb_into_vc( + &self, + a: &[O], + b: &[O], + c: &mut [O], + ); // NEG // Assigns -a[i] to a[i]. - fn va_neg_into_va(&self, a: &mut [O]); + fn va_neg_into_va(&self, a: &mut [O]); // Assigns -a[i] to b[i]. - fn va_neg_into_vb(&self, a: &[O], b: &mut [O]); + fn va_neg_into_vb(&self, a: &[O], b: &mut [O]); // MUL MONTGOMERY // Assigns a * 2^64 to b.
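+    // (Montgomery form stores a value a as a * 2^64 mod q, so that later
+    // products can be reduced with multiplications and shifts instead of
+    // division; `prep` maps into that form.)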
- fn va_prep_mont_into_vb(&self, a: &[O], b: &mut [montgomery::Montgomery]); + fn va_prep_mont_into_vb( + &self, + a: &[O], + b: &mut [montgomery::Montgomery], + ); // Assigns a[i] * b[i] to c[i]. - fn va_mont_mul_vb_into_vc(&self, a:&[montgomery::Montgomery], b:&[O], c: &mut [O]); + fn va_mont_mul_vb_into_vc( + &self, + a: &[montgomery::Montgomery], + b: &[O], + c: &mut [O], + ); // Assigns a[i] * b[i] to b[i]. - fn va_mont_mul_vb_into_vb(&self, a:&[montgomery::Montgomery], b:&mut [O]); + fn va_mont_mul_vb_into_vb( + &self, + a: &[montgomery::Montgomery], + b: &mut [O], + ); // MUL BARRETT // Assigns a * b[i] to b[i]. - fn sa_barrett_mul_vb_into_vb(&self, a:& barrett::Barrett, b:&mut [u64]); + fn sa_barrett_mul_vb_into_vb( + &self, + a: &barrett::Barrett, + b: &mut [u64], + ); // Assigns a * b[i] to c[i]. - fn sa_barrett_mul_vb_into_vc(&self, a:& barrett::Barrett, b:&[u64], c: &mut [u64]); + fn sa_barrett_mul_vb_into_vc( + &self, + a: &barrett::Barrett, + b: &[u64], + c: &mut [u64], + ); // OTHERS // Assigns (a[i] + 2q - b[i]) * c to d[i]. - fn va_sub_vb_mul_sc_into_vd(&self, a: &[u64], b: &[u64], c: &barrett::Barrett, d: &mut [u64]); + fn va_sub_vb_mul_sc_into_vd( + &self, + a: &[u64], + b: &[u64], + c: &barrett::Barrett, + d: &mut [u64], + ); // Assigns (a[i] + 2q - b[i]) * c to b[i]. - fn va_sub_vb_mul_sc_into_vb(&self, a: &[u64], c: &barrett::Barrett, b: &mut [u64]); + fn va_sub_vb_mul_sc_into_vb( + &self, + a: &[u64], + c: &barrett::Barrett, + b: &mut [u64], + ); } - - - - diff --git a/math/src/modulus/barrett.rs b/math/src/modulus/barrett.rs index efaee2e..2f7dfa1 100644 --- a/math/src/modulus/barrett.rs +++ b/math/src/modulus/barrett.rs @@ -2,7 +2,6 @@ pub struct Barrett(pub O, pub O); impl Barrett { - #[inline(always)] pub fn value(&self) -> &O { &self.0 @@ -15,25 +14,23 @@ impl Barrett { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct BarrettPrecomp{ +pub struct BarrettPrecomp { pub q: O, - pub two_q:O, - pub four_q:O, - pub lo:O, - pub hi:O, + pub two_q: O, + pub four_q: O, + pub lo: O, + pub hi: O, pub one: Barrett, } -impl BarrettPrecomp{ - +impl BarrettPrecomp { #[inline(always)] - pub fn value_hi(&self) -> &O{ + pub fn value_hi(&self) -> &O { &self.hi } #[inline(always)] - pub fn value_lo(&self) -> &O{ + pub fn value_lo(&self) -> &O { &self.lo } } - diff --git a/math/src/modulus/impl_u64/barrett.rs b/math/src/modulus/impl_u64/barrett.rs index 0cec9f6..36bbb1b 100644 --- a/math/src/modulus/impl_u64/barrett.rs +++ b/math/src/modulus/impl_u64/barrett.rs @@ -1,17 +1,24 @@ use crate::modulus::barrett::{Barrett, BarrettPrecomp}; use crate::modulus::ReduceOnce; -use crate::modulus::{REDUCEMOD, NONE, ONCE, TWICE, FOURTIMES, BARRETT, BARRETTLAZY}; +use crate::modulus::{BARRETT, BARRETTLAZY, FOURTIMES, NONE, ONCE, REDUCEMOD, TWICE}; use num_bigint::BigUint; use num_traits::cast::ToPrimitive; -impl BarrettPrecomp{ - +impl BarrettPrecomp { pub fn new(q: u64) -> BarrettPrecomp { - let big_r: BigUint = (BigUint::from(1 as usize)<<((u64::BITS<<1) as usize)) / BigUint::from(q); + let big_r: BigUint = + (BigUint::from(1 as usize) << ((u64::BITS << 1) as usize)) / BigUint::from(q); let lo: u64 = (&big_r & BigUint::from(u64::MAX)).to_u64().unwrap(); let hi: u64 = (big_r >> u64::BITS).to_u64().unwrap(); - let mut precomp: BarrettPrecomp = Self{q:q, two_q:q<<1, four_q:q<<2, lo:lo, hi:hi, one:Barrett(0,0)}; + let mut precomp: BarrettPrecomp = Self { + q: q, + two_q: q << 1, + four_q: q << 2, + lo: lo, + hi: hi, + one: Barrett(0, 0), + }; precomp.one = precomp.prepare(1); precomp } @@ 
-22,27 +29,27 @@ impl BarrettPrecomp{ } #[inline(always)] - pub fn reduce_assign(&self, x: &mut u64){ + pub fn reduce_assign(&self, x: &mut u64) { match REDUCE { - NONE =>{}, - ONCE =>{x.reduce_once_assign(self.q)}, - TWICE=>{x.reduce_once_assign(self.two_q)}, - FOURTIMES =>{x.reduce_once_assign(self.four_q)}, - BARRETT =>{ + NONE => {} + ONCE => x.reduce_once_assign(self.q), + TWICE => x.reduce_once_assign(self.two_q), + FOURTIMES => x.reduce_once_assign(self.four_q), + BARRETT => { let (_, mhi) = x.widening_mul(self.hi); *x = *x - mhi.wrapping_mul(self.q); x.reduce_once_assign(self.q); - }, - BARRETTLAZY =>{ + } + BARRETTLAZY => { let (_, mhi) = x.widening_mul(self.hi); *x = *x - mhi.wrapping_mul(self.q) - }, - _ => unreachable!("invalid REDUCE argument") + } + _ => unreachable!("invalid REDUCE argument"), } } #[inline(always)] - pub fn reduce(&self, x: &u64) -> u64{ + pub fn reduce(&self, x: &u64) -> u64 { let mut r = *x; self.reduce_assign::(&mut r); r @@ -56,16 +63,16 @@ impl BarrettPrecomp{ } #[inline(always)] - pub fn mul_external(&self, lhs: Barrett, rhs: u64) -> u64 { + pub fn mul_external(&self, lhs: Barrett, rhs: u64) -> u64 { let mut r: u64 = rhs; self.mul_external_assign::(lhs, &mut r); r } #[inline(always)] - pub fn mul_external_assign(&self, lhs: Barrett, rhs: &mut u64){ + pub fn mul_external_assign(&self, lhs: Barrett, rhs: &mut u64) { let t: u64 = ((*lhs.quotient() as u128 * *rhs as u128) >> 64) as _; *rhs = (rhs.wrapping_mul(*lhs.value())).wrapping_sub(self.q.wrapping_mul(t)); self.reduce_assign::(rhs); } -} \ No newline at end of file +} diff --git a/math/src/modulus/impl_u64/mod.rs b/math/src/modulus/impl_u64/mod.rs index 6d8942a..8e38806 100644 --- a/math/src/modulus/impl_u64/mod.rs +++ b/math/src/modulus/impl_u64/mod.rs @@ -1,32 +1,32 @@ -pub mod prime; pub mod barrett; pub mod montgomery; pub mod operations; +pub mod prime; use crate::modulus::ReduceOnce; -impl ReduceOnce for u64{ +impl ReduceOnce for u64 { #[inline(always)] - fn reduce_once_constant_time_assign(&mut self, q: u64){ + fn reduce_once_constant_time_assign(&mut self, q: u64) { debug_assert!(q < 0x8000000000000000, "2q >= 2^64"); - *self -= (q.wrapping_sub(*self)>>63)*q; - } - - #[inline(always)] - fn reduce_once_constant_time(&self, q:u64) -> u64{ - debug_assert!(q < 0x8000000000000000, "2q >= 2^64"); - self - (q.wrapping_sub(*self)>>63)*q + *self -= (q.wrapping_sub(*self) >> 63) * q; } #[inline(always)] - fn reduce_once_assign(&mut self, q: u64){ + fn reduce_once_constant_time(&self, q: u64) -> u64 { + debug_assert!(q < 0x8000000000000000, "2q >= 2^64"); + self - (q.wrapping_sub(*self) >> 63) * q + } + + #[inline(always)] + fn reduce_once_assign(&mut self, q: u64) { debug_assert!(q < 0x8000000000000000, "2q >= 2^64"); *self = *self.min(&mut self.wrapping_sub(q)) } #[inline(always)] - fn reduce_once(&self, q:u64) -> u64{ + fn reduce_once(&self, q: u64) -> u64 { debug_assert!(q < 0x8000000000000000, "2q >= 2^64"); *self.min(&mut self.wrapping_sub(q)) } -} \ No newline at end of file +} diff --git a/math/src/modulus/impl_u64/montgomery.rs b/math/src/modulus/impl_u64/montgomery.rs index 99395d0..6bceebe 100644 --- a/math/src/modulus/impl_u64/montgomery.rs +++ b/math/src/modulus/impl_u64/montgomery.rs @@ -1,50 +1,52 @@ - -use crate::modulus::ReduceOnce; -use crate::modulus::montgomery::{MontgomeryPrecomp, Montgomery}; use crate::modulus::barrett::BarrettPrecomp; -use crate::modulus::{REDUCEMOD, ONCE}; +use crate::modulus::montgomery::{Montgomery, MontgomeryPrecomp}; +use crate::modulus::ReduceOnce; +use 
crate::modulus::{ONCE, REDUCEMOD}; extern crate test; /// MontgomeryPrecomp provides the precomputations and methods /// enabling Montgomery arithmetic over u64 values. -impl MontgomeryPrecomp{ - +impl MontgomeryPrecomp { /// Returns a new instance of MontgomeryPrecomp. /// This method panics if gcd(q, 2^64) != 1. #[inline(always)] - pub fn new(q: u64) -> MontgomeryPrecomp{ - assert!(q & 1 != 0, "Invalid argument: gcd(q={}, radix=2^64) != 1", q); + pub fn new(q: u64) -> MontgomeryPrecomp { + assert!( + q & 1 != 0, + "Invalid argument: gcd(q={}, radix=2^64) != 1", + q + ); let mut q_inv: u64 = 1; let mut q_pow = q; - for _i in 0..63{ + for _i in 0..63 { q_inv = q_inv.wrapping_mul(q_pow); q_pow = q_pow.wrapping_mul(q_pow); } - let mut precomp = Self{ + let mut precomp = Self { q: q, - two_q: q<<1, - four_q: q<<2, - barrett: BarrettPrecomp::new(q), + two_q: q << 1, + four_q: q << 2, + barrett: BarrettPrecomp::new(q), q_inv: q_inv, one: 0, - minus_one:0, + minus_one: 0, }; precomp.one = precomp.prepare::(1); - precomp.minus_one = q-precomp.one; + precomp.minus_one = q - precomp.one; precomp } /// Returns 2^64 mod q as a Montgomery. #[inline(always)] - pub fn one(&self) -> Montgomery{ + pub fn one(&self) -> Montgomery { self.one } /// Returns (q-1) * 2^64 mod q as a Montgomery. #[inline(always)] - pub fn minus_one(&self) -> Montgomery{ + pub fn minus_one(&self) -> Montgomery { self.minus_one } @@ -53,7 +55,7 @@ impl MontgomeryPrecomp{ /// - ONCE: subtracts q if x >= q. /// - FULL: maps x to x mod q using Barrett reduction. #[inline(always)] - pub fn reduce(&self, x: u64) -> u64{ + pub fn reduce(&self, x: u64) -> u64 { let mut r: u64 = x; self.reduce_assign::(&mut r); r @@ -64,13 +66,13 @@ impl MontgomeryPrecomp{ /// - ONCE: subtracts q if x >= q. /// - FULL: maps x to x mod q using Barrett reduction. #[inline(always)] - pub fn reduce_assign(&self, x: &mut u64){ + pub fn reduce_assign(&self, x: &mut u64) { self.barrett.reduce_assign::(x); } /// Returns lhs * 2^64 mod q as a Montgomery. #[inline(always)] - pub fn prepare(&self, lhs: u64) -> Montgomery{ + pub fn prepare(&self, lhs: u64) -> Montgomery { let mut rhs: u64 = 0; self.prepare_assign::(lhs, &mut rhs); rhs @@ -78,15 +80,17 @@ impl MontgomeryPrecomp{ /// Assigns lhs * 2^64 mod q to rhs. #[inline(always)] - pub fn prepare_assign(&self, lhs: u64, rhs: &mut Montgomery){ + pub fn prepare_assign(&self, lhs: u64, rhs: &mut Montgomery) { let (_, mhi) = lhs.widening_mul(*self.barrett.value_lo()); - *rhs = (lhs.wrapping_mul(*self.barrett.value_hi()).wrapping_add(mhi)).wrapping_mul(self.q).wrapping_neg(); - self.reduce_assign::(rhs); + *rhs = (lhs.wrapping_mul(*self.barrett.value_hi()).wrapping_add(mhi)) + .wrapping_mul(self.q) + .wrapping_neg(); + self.reduce_assign::(rhs); } /// Returns lhs * (2^64)^-1 mod q as a u64. #[inline(always)] - pub fn unprepare(&self, lhs: Montgomery) -> u64{ + pub fn unprepare(&self, lhs: Montgomery) -> u64 { let mut rhs = 0u64; self.unprepare_assign::(lhs, &mut rhs); rhs @@ -94,14 +98,14 @@ impl MontgomeryPrecomp{ /// Assigns lhs * (2^64)^-1 mod q to rhs. #[inline(always)] - pub fn unprepare_assign(&self, lhs: Montgomery, rhs: &mut u64){ + pub fn unprepare_assign(&self, lhs: Montgomery, rhs: &mut u64) { let (_, r) = self.q.widening_mul(lhs.wrapping_mul(self.q_inv)); - *rhs = self.reduce::(self.q.wrapping_sub(r)); + *rhs = self.reduce::(self.q.wrapping_sub(r)); } /// Returns lhs * rhs * (2^{64})^-1 mod q.
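+    /// The REDC step used below, sketched: with x = lhs * rhs split into 64-bit
+    /// halves (x_lo, x_hi), compute m = x_lo * q_inv mod 2^64 (where q_inv =
+    /// q^-1 mod 2^64); then x_hi - hi64(m * q) + q is congruent to x * 2^-64
+    /// mod q and lies in (0, 2q), after which the REDUCE parameter is applied.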
#[inline(always)] - pub fn mul_external(&self, lhs: Montgomery, rhs: u64) -> u64{ + pub fn mul_external(&self, lhs: Montgomery, rhs: u64) -> u64 { let mut r: u64 = rhs; self.mul_external_assign::(lhs, &mut r); r @@ -109,7 +113,11 @@ impl MontgomeryPrecomp{ /// Assigns lhs * rhs * (2^{64})^-1 mod q to rhs. #[inline(always)] - pub fn mul_external_assign(&self, lhs: Montgomery, rhs: &mut u64){ + pub fn mul_external_assign( + &self, + lhs: Montgomery, + rhs: &mut u64, + ) { let (mlo, mhi) = lhs.widening_mul(*rhs); let (_, hhi) = self.q.widening_mul(mlo.wrapping_mul(self.q_inv)); *rhs = self.reduce::(mhi.wrapping_sub(hhi).wrapping_add(self.q)); @@ -117,42 +125,54 @@ impl MontgomeryPrecomp{ /// Returns lhs * rhs * (2^{64})^-1 mod q in range [0, 2q-1]. #[inline(always)] - pub fn mul_internal(&self, lhs: Montgomery, rhs: Montgomery) -> Montgomery{ + pub fn mul_internal( + &self, + lhs: Montgomery, + rhs: Montgomery, + ) -> Montgomery { self.mul_external::(lhs, rhs) } /// Assigns lhs * rhs * (2^{64})^-1 mod q to rhs. #[inline(always)] - pub fn mul_internal_assign(&self, lhs: Montgomery, rhs: &mut Montgomery){ + pub fn mul_internal_assign( + &self, + lhs: Montgomery, + rhs: &mut Montgomery, + ) { self.mul_external_assign::(lhs, rhs); } #[inline(always)] - pub fn add_internal(&self, lhs: Montgomery, rhs: Montgomery) -> Montgomery{ + pub fn add_internal(&self, lhs: Montgomery, rhs: Montgomery) -> Montgomery { rhs + lhs } /// Assigns lhs + rhs to rhs. #[inline(always)] - pub fn add_internal_lazy_assign(&self, lhs: Montgomery, rhs: &mut Montgomery){ + pub fn add_internal_lazy_assign(&self, lhs: Montgomery, rhs: &mut Montgomery) { *rhs += lhs } /// Assigns lhs + rhs - q if (lhs + rhs) >= q to rhs. #[inline(always)] - pub fn add_internal_reduce_once_assign(&self, lhs: Montgomery, rhs: &mut Montgomery){ + pub fn add_internal_reduce_once_assign( + &self, + lhs: Montgomery, + rhs: &mut Montgomery, + ) { self.add_internal_lazy_assign(lhs, rhs); rhs.reduce_once_assign(self.q); } /// Returns (x^exponent) * 2^64 mod q. 
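+    /// Uses binary (square-and-multiply) exponentiation: the exponent bits are
+    /// scanned from least significant upward, multiplying the accumulator by
+    /// the running power of x whenever a bit is set, i.e. O(log exponent)
+    /// Montgomery products.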
#[inline(always)] - pub fn pow(&self, x: Montgomery, exponent:u64) -> Montgomery{ + pub fn pow(&self, x: Montgomery, exponent: u64) -> Montgomery { let mut y: Montgomery = self.one(); let mut x_mut: Montgomery = x; let mut i: u64 = exponent; - while i > 0{ - if i & 1 == 1{ + while i > 0 { + if i & 1 == 1 { self.mul_internal_assign::(x_mut, &mut y); } self.mul_internal_assign::(x_mut, &mut x_mut); @@ -166,27 +186,29 @@ impl MontgomeryPrecomp{ #[cfg(test)] mod tests { - use crate::modulus::montgomery; use super::*; + use crate::modulus::montgomery; use test::Bencher; #[test] fn test_mul_external() { let q: u64 = 0x1fffffffffe00001; - let m_precomp = montgomery::MontgomeryPrecomp::new(q); + let m_precomp = montgomery::MontgomeryPrecomp::new(q); let x: u64 = 0x5f876e514845cc8b; let y: u64 = 0xad726f98f24a761a; let y_mont = m_precomp.prepare::(y); - assert!(m_precomp.mul_external::(y_mont, x) == (x as u128 * y as u128 % q as u128) as u64); + assert!( + m_precomp.mul_external::(y_mont, x) == (x as u128 * y as u128 % q as u128) as u64 + ); } #[bench] - fn bench_mul_external(b: &mut Bencher){ + fn bench_mul_external(b: &mut Bencher) { let q: u64 = 0x1fffffffffe00001; - let m_precomp = montgomery::MontgomeryPrecomp::new(q); + let m_precomp = montgomery::MontgomeryPrecomp::new(q); let mut x: u64 = 0x5f876e514845cc8b; let y: u64 = 0xad726f98f24a761a; let y_mont = m_precomp.prepare::(y); b.iter(|| m_precomp.mul_external_assign::(y_mont, &mut x)); } -} \ No newline at end of file +} diff --git a/math/src/modulus/impl_u64/operations.rs b/math/src/modulus/impl_u64/operations.rs index ac218c6..c9804fb 100644 --- a/math/src/modulus/impl_u64/operations.rs +++ b/math/src/modulus/impl_u64/operations.rs @@ -1,15 +1,13 @@ - -use crate::modulus::{ScalarOperations, VectorOperations}; +use crate::modulus::barrett::Barrett; +use crate::modulus::montgomery::Montgomery; use crate::modulus::prime::Prime; use crate::modulus::ReduceOnce; -use crate::modulus::montgomery::Montgomery; -use crate::modulus::barrett::Barrett; use crate::modulus::REDUCEMOD; -use crate::{apply_v, apply_vv, apply_vvv, apply_sv, apply_svv, apply_vvsv, apply_vsv}; +use crate::modulus::{ScalarOperations, VectorOperations}; +use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv}; use itertools::izip; -impl ScalarOperations for Prime{ - +impl ScalarOperations for Prime { /// Applies a modular reduction on x based on REDUCE: /// - LAZY: no modular reduction. /// - ONCE: subtracts q if x >= q. @@ -18,84 +16,104 @@ impl ScalarOperations for Prime{ /// - BARRETT: maps x to x mod q using Barrett reduction. /// - BARRETTLAZY: maps x to x mod q using Barrett reduction with values in [0, 2q-1]. 
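+    /// For example, with q = 17 and x = 40: ONCE yields 23 (one conditional
+    /// subtraction), TWICE yields 6 (since 40 >= 2q = 34), and BARRETT yields
+    /// 40 mod 17 = 6.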
#[inline(always)] - fn sa_reduce_into_sa(&self, a: &mut u64){ + fn sa_reduce_into_sa(&self, a: &mut u64) { self.montgomery.reduce_assign::(a); } #[inline(always)] - fn sa_add_sb_into_sc(&self, a: &u64, b: &u64, c: &mut u64){ + fn sa_add_sb_into_sc(&self, a: &u64, b: &u64, c: &mut u64) { *c = a.wrapping_add(*b); self.sa_reduce_into_sa::(c); } #[inline(always)] - fn sa_add_sb_into_sb(&self, a: &u64, b: &mut u64){ + fn sa_add_sb_into_sb(&self, a: &u64, b: &mut u64) { *b = a.wrapping_add(*b); self.sa_reduce_into_sa::(b); } #[inline(always)] - fn sa_sub_sb_into_sc(&self, a: &u64, b: &u64, c: &mut u64){ + fn sa_sub_sb_into_sc(&self, a: &u64, b: &u64, c: &mut u64) { *c = a.wrapping_add(self.q.wrapping_sub(*b)).reduce_once(self.q); } #[inline(always)] - fn sa_sub_sb_into_sb(&self, a: &u64, b: &mut u64){ + fn sa_sub_sb_into_sb(&self, a: &u64, b: &mut u64) { *b = a.wrapping_add(self.q.wrapping_sub(*b)).reduce_once(self.q); } #[inline(always)] - fn sa_neg_into_sa(&self, a: &mut u64){ + fn sa_neg_into_sa(&self, a: &mut u64) { *a = self.q.wrapping_sub(*a); self.sa_reduce_into_sa::(a) } #[inline(always)] - fn sa_neg_into_sb(&self, a: &u64, b: &mut u64){ + fn sa_neg_into_sb(&self, a: &u64, b: &mut u64) { *b = self.q.wrapping_sub(*a); self.sa_reduce_into_sa::(b) } #[inline(always)] - fn sa_prep_mont_into_sb(&self, a: &u64, b: &mut Montgomery){ + fn sa_prep_mont_into_sb(&self, a: &u64, b: &mut Montgomery) { self.montgomery.prepare_assign::(*a, b); } #[inline(always)] - fn sa_mont_mul_sb_into_sc(&self, a: &Montgomery, b:&u64, c: &mut u64){ + fn sa_mont_mul_sb_into_sc( + &self, + a: &Montgomery, + b: &u64, + c: &mut u64, + ) { *c = self.montgomery.mul_external::(*a, *b); } #[inline(always)] - fn sa_mont_mul_sb_into_sb(&self, a:&Montgomery, b:&mut u64){ + fn sa_mont_mul_sb_into_sb(&self, a: &Montgomery, b: &mut u64) { self.montgomery.mul_external_assign::(*a, b); } #[inline(always)] - fn sa_barrett_mul_sb_into_sc(&self, a: &Barrett, b:&u64, c: &mut u64){ + fn sa_barrett_mul_sb_into_sc( + &self, + a: &Barrett, + b: &u64, + c: &mut u64, + ) { *c = self.barrett.mul_external::(*a, *b); } #[inline(always)] - fn sa_barrett_mul_sb_into_sb(&self, a:&Barrett, b:&mut u64){ + fn sa_barrett_mul_sb_into_sb(&self, a: &Barrett, b: &mut u64) { self.barrett.mul_external_assign::(*a, b); } #[inline(always)] - fn sa_sub_sb_mul_sc_into_sd(&self, a: &u64, b: &u64, c: &Barrett, d: &mut u64){ + fn sa_sub_sb_mul_sc_into_sd( + &self, + a: &u64, + b: &u64, + c: &Barrett, + d: &mut u64, + ) { *d = self.two_q.wrapping_sub(*b).wrapping_add(*a); self.barrett.mul_external_assign::(*c, d); } #[inline(always)] - fn sa_sub_sb_mul_sc_into_sb(&self, a: &u64, c: &Barrett, b: &mut u64){ + fn sa_sub_sb_mul_sc_into_sb( + &self, + a: &u64, + c: &Barrett, + b: &mut u64, + ) { *b = self.two_q.wrapping_sub(*b).wrapping_add(*a); self.barrett.mul_external_assign::(*c, b); } } -impl VectorOperations for Prime{ - +impl VectorOperations for Prime { /// Applies a modular reduction on x based on REDUCE: /// - LAZY: no modular reduction. /// - ONCE: subtracts q if x >= q. @@ -104,80 +122,166 @@ impl VectorOperations for Prime{ /// - BARRETT: maps x to x mod q using Barrett reduction. /// - BARRETTLAZY: maps x to x mod q using Barrett reduction with values in [0, 2q-1]. 
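+    /// The vector variants below fan the scalar operations out over slices via
+    /// the apply_* macros, which unroll the main loop in chunks of CHUNK = 8
+    /// and handle the remaining elements one by one.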
#[inline(always)] - fn va_reduce_into_va(&self, a: &mut [u64]){ + fn va_reduce_into_va(&self, a: &mut [u64]) { apply_v!(self, Self::sa_reduce_into_sa::, a, CHUNK); } #[inline(always)] - fn va_add_vb_into_vc(&self, a: &[u64], b:&[u64], c:&mut [u64]){ + fn va_add_vb_into_vc( + &self, + a: &[u64], + b: &[u64], + c: &mut [u64], + ) { apply_vvv!(self, Self::sa_add_sb_into_sc::, a, b, c, CHUNK); } #[inline(always)] - fn va_add_vb_into_vb(&self, a: &[u64], b:&mut [u64]){ + fn va_add_vb_into_vb( + &self, + a: &[u64], + b: &mut [u64], + ) { apply_vv!(self, Self::sa_add_sb_into_sb::, a, b, CHUNK); } #[inline(always)] - fn va_add_sb_into_vc(&self, a: &[u64], b:&u64, c:&mut [u64]){ + fn va_add_sb_into_vc( + &self, + a: &[u64], + b: &u64, + c: &mut [u64], + ) { apply_vsv!(self, Self::sa_add_sb_into_sc::, a, b, c, CHUNK); } #[inline(always)] - fn sa_add_vb_into_vb(&self, a:&u64, b:&mut [u64]){ + fn sa_add_vb_into_vb( + &self, + a: &u64, + b: &mut [u64], + ) { apply_sv!(self, Self::sa_add_sb_into_sb::, a, b, CHUNK); } #[inline(always)] - fn va_sub_vb_into_vc(&self, a: &[u64], b:&[u64], c:&mut [u64]){ + fn va_sub_vb_into_vc( + &self, + a: &[u64], + b: &[u64], + c: &mut [u64], + ) { apply_vvv!(self, Self::sa_sub_sb_into_sc::, a, b, c, CHUNK); } - + #[inline(always)] - fn va_sub_vb_into_vb(&self, a: &[u64], b:&mut [u64]){ + fn va_sub_vb_into_vb( + &self, + a: &[u64], + b: &mut [u64], + ) { apply_vv!(self, Self::sa_sub_sb_into_sb::, a, b, CHUNK); } #[inline(always)] - fn va_neg_into_va(&self, a: &mut [u64]){ + fn va_neg_into_va(&self, a: &mut [u64]) { apply_v!(self, Self::sa_neg_into_sa::, a, CHUNK); } #[inline(always)] - fn va_neg_into_vb(&self, a: &[u64], b: &mut [u64]){ + fn va_neg_into_vb( + &self, + a: &[u64], + b: &mut [u64], + ) { apply_vv!(self, Self::sa_neg_into_sb::, a, b, CHUNK); } #[inline(always)] - fn va_prep_mont_into_vb(&self, a: &[u64], b: &mut [Montgomery]){ + fn va_prep_mont_into_vb( + &self, + a: &[u64], + b: &mut [Montgomery], + ) { apply_vv!(self, Self::sa_prep_mont_into_sb::, a, b, CHUNK); } #[inline(always)] - fn va_mont_mul_vb_into_vc(&self, a:& [Montgomery], b:&[u64], c: &mut [u64]){ + fn va_mont_mul_vb_into_vc( + &self, + a: &[Montgomery], + b: &[u64], + c: &mut [u64], + ) { apply_vvv!(self, Self::sa_mont_mul_sb_into_sc::, a, b, c, CHUNK); } #[inline(always)] - fn va_mont_mul_vb_into_vb(&self, a:& [Montgomery], b:&mut [u64]){ + fn va_mont_mul_vb_into_vb( + &self, + a: &[Montgomery], + b: &mut [u64], + ) { apply_vv!(self, Self::sa_mont_mul_sb_into_sb::, a, b, CHUNK); } #[inline(always)] - fn sa_barrett_mul_vb_into_vc(&self, a:& Barrett, b:&[u64], c: &mut [u64]){ - apply_svv!(self, Self::sa_barrett_mul_sb_into_sc::, a, b, c, CHUNK); + fn sa_barrett_mul_vb_into_vc( + &self, + a: &Barrett, + b: &[u64], + c: &mut [u64], + ) { + apply_svv!( + self, + Self::sa_barrett_mul_sb_into_sc::, + a, + b, + c, + CHUNK + ); } #[inline(always)] - fn sa_barrett_mul_vb_into_vb(&self, a:& Barrett, b:&mut [u64]){ + fn sa_barrett_mul_vb_into_vb( + &self, + a: &Barrett, + b: &mut [u64], + ) { apply_sv!(self, Self::sa_barrett_mul_sb_into_sb::, a, b, CHUNK); } - fn va_sub_vb_mul_sc_into_vd(&self, a: &[u64], b: &[u64], c: &Barrett, d: &mut [u64]){ - apply_vvsv!(self, Self::sa_sub_sb_mul_sc_into_sd::, a, b, c, d, CHUNK); + fn va_sub_vb_mul_sc_into_vd( + &self, + a: &[u64], + b: &[u64], + c: &Barrett, + d: &mut [u64], + ) { + apply_vvsv!( + self, + Self::sa_sub_sb_mul_sc_into_sd::, + a, + b, + c, + d, + CHUNK + ); } - fn va_sub_vb_mul_sc_into_vb(&self, a: &[u64], b: &Barrett, c: &mut [u64]){ - apply_vsv!(self, 
Self::sa_sub_sb_mul_sc_into_sb::, a, b, c, CHUNK); + fn va_sub_vb_mul_sc_into_vb( + &self, + a: &[u64], + b: &Barrett, + c: &mut [u64], + ) { + apply_vsv!( + self, + Self::sa_sub_sb_mul_sc_into_sb::, + a, + b, + c, + CHUNK + ); } } diff --git a/math/src/modulus/impl_u64/prime.rs b/math/src/modulus/impl_u64/prime.rs index 13b27eb..1662d0c 100644 --- a/math/src/modulus/impl_u64/prime.rs +++ b/math/src/modulus/impl_u64/prime.rs @@ -1,117 +1,112 @@ -use crate::modulus::prime::Prime; -use crate::modulus::montgomery::{Montgomery, MontgomeryPrecomp}; use crate::modulus::barrett::BarrettPrecomp; +use crate::modulus::montgomery::{Montgomery, MontgomeryPrecomp}; +use crate::modulus::prime::Prime; use crate::modulus::ONCE; use primality_test::is_prime; use prime_factorization::Factorization; -impl Prime{ - +impl Prime { /// Returns a new instance of Prime. /// Panics if q_base is not a prime > 2 or /// if q_base^q_power would overflow u64. - pub fn new(q_base: u64, q_power: usize) -> Self{ - assert!(is_prime(q_base) && q_base > 2); + pub fn new(q_base: u64, q_power: usize) -> Self { + assert!(is_prime(q_base) && q_base > 2); Self::new_unchecked(q_base, q_power) - } + } /// Returns a new instance of Prime. /// Does not check if q_base is a prime > 2. /// Panics if q_base^q_power would overflow u64. - pub fn new_unchecked(q_base: u64, q_power: usize) -> Self { - + pub fn new_unchecked(q_base: u64, q_power: usize) -> Self { let mut q = q_base; - for _i in 1..q_power{ + for _i in 1..q_power { q *= q_base } assert!(q.next_power_of_two().ilog2() <= 61); - let mut phi = q_base-1; - for _i in 1..q_power{ + let mut phi = q_base - 1; + for _i in 1..q_power { phi *= q_base } let mut prime: Prime = Self { - q:q, - two_q:q<<1, - four_q:q<<2, - q_base:q_base, - q_power:q_power, + q: q, + two_q: q << 1, + four_q: q << 2, + q_base: q_base, + q_power: q_power, factors: Vec::new(), - montgomery:MontgomeryPrecomp::new(q), - barrett:BarrettPrecomp::new(q), - phi:phi, + montgomery: MontgomeryPrecomp::new(q), + barrett: BarrettPrecomp::new(q), + phi: phi, }; prime.check_factors(); prime - } - pub fn q(&self) -> u64{ + pub fn q(&self) -> u64 { self.q } - pub fn q_base(&self) -> u64{ + pub fn q_base(&self) -> u64 { self.q_base } - pub fn q_power(&self) -> usize{ + pub fn q_power(&self) -> usize { self.q_power } /// Returns x^exponent mod q. #[inline(always)] - pub fn pow(&self, x: u64, exponent: u64) -> u64{ + pub fn pow(&self, x: u64, exponent: u64) -> u64 { let mut y_mont: Montgomery = self.montgomery.one(); let mut x_mont: Montgomery = self.montgomery.prepare::(x); let mut i: u64 = exponent; - while i > 0{ - if i & 1 == 1{ - self.montgomery.mul_internal_assign::(x_mont, &mut y_mont); + while i > 0 { + if i & 1 == 1 { + self.montgomery + .mul_internal_assign::(x_mont, &mut y_mont); } - self.montgomery.mul_internal_assign::(x_mont, &mut x_mont); + self.montgomery + .mul_internal_assign::(x_mont, &mut x_mont); i >>= 1; - } - + self.montgomery.unprepare::(y_mont) } /// Returns x^-1 mod q. /// User must ensure that x is not divisible by q_base. #[inline(always)] - pub fn inv(&self, x: u64) -> u64{ - self.pow(x, self.phi-1) + pub fn inv(&self, x: u64) -> u64 { + self.pow(x, self.phi - 1) } } -impl Prime{ +impl Prime { /// Returns the smallest primitive root of q_base.
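Aside: `Prime::pow` above is textbook square-and-multiply run in the Montgomery domain; the accumulator picks up the current square whenever the low exponent bit is set. The same ladder over plain integers, using u128 widening instead of the crate's Montgomery routines (standalone sketch, not part of the patch):

// Standalone square-and-multiply: returns x^exponent mod q (q > 0).
// Same loop shape as Prime::pow, with u128 widening in place of
// Montgomery multiplication; illustrative, not optimized.
pub fn pow_mod(x: u64, exponent: u64, q: u64) -> u64 {
    let q = q as u128;
    let (mut x, mut y) = (x as u128 % q, 1u128);
    let mut e = exponent;
    while e > 0 {
        if e & 1 == 1 {
            y = y * x % q; // low bit set: fold the current square into y
        }
        x = x * x % q; // square for the next bit
        e >>= 1;
    }
    y as u64
}
// pow_mod(3, 5, 7) == 5, since 3^5 = 243 = 34 * 7 + 5.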
- pub fn primitive_root(&self) -> u64{ - + pub fn primitive_root(&self) -> u64 { let mut candidate: u64 = 1u64; let mut not_found: bool = true; - while not_found{ - + while not_found { candidate += 1; - for &factor in &self.factors{ - - if pow(candidate, (self.q_base-1)/factor, self.q_base) == 1{ + for &factor in &self.factors { + if pow(candidate, (self.q_base - 1) / factor, self.q_base) == 1 { not_found = true; - break + break; } not_found = false; } } - if not_found{ + if not_found { panic!("failed to find a primitive root for q_base={}", self.q_base) } @@ -119,20 +114,31 @@ impl Prime{ } /// Returns an nth primitive root of q = q_base^q_power. - pub fn primitive_nth_root(&self, nth_root:u64) -> u64{ - - assert!(self.q & (nth_root-1) == 1, "invalid prime: q = {} % nth_root = {} = {} != 1", self.q, nth_root, self.q & (nth_root-1)); + pub fn primitive_nth_root(&self, nth_root: u64) -> u64 { + assert!( + self.q & (nth_root - 1) == 1, + "invalid prime: q = {} % nth_root = {} = {} != 1", + self.q, + nth_root, + self.q & (nth_root - 1) + ); let psi: u64 = self.primitive_root(); // nth primitive root mod q_base: psi^((q_base-1)/nth_root) mod q_base - let psi_nth_q_base: u64 = pow(psi, (self.q_base-1)/nth_root, self.q_base); + let psi_nth_q_base: u64 = pow(psi, (self.q_base - 1) / nth_root, self.q_base); // lifts nth primitive root mod q_base to q = q_base^q_power let psi_nth_q: u64 = self.hensel_lift(psi_nth_q_base, nth_root); - assert!(self.pow(psi_nth_q, nth_root) == 1, "invalid nth primitive root: psi^nth_root != 1 mod q"); - assert!(self.pow(psi_nth_q, nth_root>>1) == self.q-1, "invalid nth primitive root: psi^(nth_root/2) != -1 mod q"); + assert!( + self.pow(psi_nth_q, nth_root) == 1, + "invalid nth primitive root: psi^nth_root != 1 mod q" + ); + assert!( + self.pow(psi_nth_q, nth_root >> 1) == self.q - 1, + "invalid nth primitive root: psi^(nth_root/2) != -1 mod q" + ); psi_nth_q } @@ -140,72 +146,76 @@ impl Prime{ /// Checks if the field self.factors is populated. /// If not, factorizes q_base-1 and populates self.factors. /// If yes, checks that it contains the distinct prime factors of q_base-1. - pub fn check_factors(&mut self){ - - if self.factors.len() == 0{ - - let factors = Factorization::run(self.q_base-1).prime_factor_repr(); + pub fn check_factors(&mut self) { + if self.factors.len() == 0 { + let factors = Factorization::run(self.q_base - 1).prime_factor_repr(); + let mut distincts_factors: Vec = Vec::with_capacity(factors.len()); - for factor in factors.iter(){ + for factor in factors.iter() { distincts_factors.push(factor.0) } self.factors = distincts_factors - - }else{ + } else { let mut q_base: u64 = self.q_base - 1; - for &factor in &self.factors{ - if !is_prime(factor){ + for &factor in &self.factors { + if !is_prime(factor) { panic!("invalid factor list: factor {} is not prime", factor) } - - while q_base%factor != 0{ + + while q_base % factor == 0 { q_base /= factor } } - - if q_base != 1{ + + if q_base != 1 { panic!("invalid factor list: does not fully divide q_base-1: remaining cofactor = {}", q_base) } - } + } } /// Returns psi + a * q_base such that (psi + a * q_base)^{nth_root} = 1 mod q = q_base^q_power, given psi^{nth_root} = 1 mod q_base. /// Panics if psi^{nth_root} != 1 mod q_base.
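The `hensel_lift` that follows is Newton's iteration on f(psi) = psi^nth_root - 1: each pass of the loop adds one q_base-adic digit of precision via psi <- psi - f(psi) / f'(psi), computed mod q. A toy end-to-end check of one such step with plain integers (tiny hand-picked parameters, unrelated to the crate's moduli):

// One Hensel/Newton step: lift an nth root of unity from mod p to mod p^2.
// psi_next = psi - (psi^n - 1) * (n * psi^(n-1))^{-1}  (mod p^2)
fn hensel_step(psi: u64, n: u64, p2: u64) -> u64 {
    // Square-and-multiply; values stay tiny here, so u64 never overflows.
    fn powm(mut b: u64, mut e: u64, m: u64) -> u64 {
        let mut r = 1;
        while e > 0 {
            if e & 1 == 1 { r = r * b % m; }
            b = b * b % m;
            e >>= 1;
        }
        r
    }
    // Modular inverse by brute force: fine for toy moduli.
    let inv = |a: u64| (1..p2).find(|x| a * x % p2 == 1).unwrap();
    let f = (powm(psi, n, p2) + p2 - 1) % p2; // psi^n - 1
    let df = n * powm(psi, n - 1, p2) % p2;   // n * psi^(n-1)
    (psi + p2 - f * inv(df) % p2) % p2
}

fn main() {
    // 2 is a 4th root of unity mod 5 (2^4 = 16 = 3*5 + 1).
    // Lifting to mod 25 yields 7, and indeed 7^4 = 2401 = 96*25 + 1.
    assert_eq!(hensel_step(2, 4, 25), 7);
}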
- fn hensel_lift(&self, psi: u64, nth_root: u64) -> u64{ - assert!(pow(psi, nth_root, self.q_base)==1, "invalid argument psi: psi^nth_root = {} != 1", pow(psi, nth_root, self.q_base)); + fn hensel_lift(&self, psi: u64, nth_root: u64) -> u64 { + assert!( + pow(psi, nth_root, self.q_base) == 1, + "invalid argument psi: psi^nth_root = {} != 1", + pow(psi, nth_root, self.q_base) + ); let mut psi_mont: Montgomery = self.montgomery.prepare::(psi); let nth_root_mont: Montgomery = self.montgomery.prepare::(nth_root); - for _i in 1..self.q_power{ + for _i in 1..self.q_power { + let psi_pow: Montgomery = self.montgomery.pow(psi_mont, nth_root - 1); - let psi_pow: Montgomery = self.montgomery.pow(psi_mont, nth_root-1); + let num: Montgomery = self.montgomery.one() + self.q + - self.montgomery.mul_internal::(psi_pow, psi_mont); - let num: Montgomery = self.montgomery.one() + self.q - self.montgomery.mul_internal::(psi_pow, psi_mont); + let mut den: Montgomery = + self.montgomery.mul_internal::(nth_root_mont, psi_pow); - let mut den: Montgomery = self.montgomery.mul_internal::(nth_root_mont, psi_pow); + den = self.montgomery.pow(den, self.phi - 1); - den = self.montgomery.pow(den, self.phi-1); - - psi_mont = self.montgomery.add_internal(psi_mont, self.montgomery.mul_internal::(num, den)); + psi_mont = self + .montgomery + .add_internal(psi_mont, self.montgomery.mul_internal::(num, den)); } - + self.montgomery.unprepare::(psi_mont) } } /// Returns x^exponent mod q. /// This function internally instantiates a new MontgomeryPrecomp -/// To be used when called only a few times and if there +/// To be used when called only a few times and when there /// is no Prime instantiated with q. -pub fn pow(x:u64, exponent:u64, q:u64) -> u64{ +pub fn pow(x: u64, exponent: u64, q: u64) -> u64 { let montgomery: MontgomeryPrecomp = MontgomeryPrecomp::::new(q); let mut y_mont: Montgomery = montgomery.one(); let mut x_mont: Montgomery = montgomery.prepare::(x); let mut i: u64 = exponent; - while i > 0{ - if i & 1 == 1{ + while i > 0 { + if i & 1 == 1 { montgomery.mul_internal_assign::(x_mont, &mut y_mont); } @@ -213,6 +223,6 @@ pub fn pow(x:u64, exponent:u64, q:u64) -> u64{ i >>= 1; } - + montgomery.unprepare::(y_mont) -} \ No newline at end of file +} diff --git a/math/src/modulus/montgomery.rs b/math/src/modulus/montgomery.rs index c3b15bc..27b0a6a 100644 --- a/math/src/modulus/montgomery.rs +++ b/math/src/modulus/montgomery.rs @@ -4,10 +4,10 @@ use crate::modulus::barrett::BarrettPrecomp; /// an element in the Montgomery domain. pub type Montgomery = O; -/// MontgomeryPrecomp is a generic struct storing +/// MontgomeryPrecomp is a generic struct storing /// precomputations for Montgomery arithmetic.
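For intuition about what the struct that follows precomputes: Montgomery arithmetic stores x as x*R mod q with R = 2^64, so a reduction (REDC) only needs -q^{-1} mod 2^64, and entering the domain only needs R^2 mod q. A minimal sketch under those assumptions (field names are illustrative; the real struct also carries the 2q/4q bounds used by the crate's lazy-reduction flags):

// Minimal Montgomery arithmetic for odd q < 2^62, R = 2^64. Sketch only.
struct MontToy {
    q: u64,
    q_inv_neg: u64, // -q^{-1} mod 2^64
    r2: u64,        // R^2 mod q, used to enter the Montgomery domain
}

impl MontToy {
    fn new(q: u64) -> Self {
        // Newton's method doubles the number of correct low bits per round,
        // so 6 rounds starting from 1 correct bit give q^{-1} mod 2^64.
        let mut inv = 1u64;
        for _ in 0..6 {
            inv = inv.wrapping_mul(2u64.wrapping_sub(q.wrapping_mul(inv)));
        }
        let r = (1u128 << 64) % q as u128;
        Self { q, q_inv_neg: inv.wrapping_neg(), r2: (r * r % q as u128) as u64 }
    }
    // REDC: for t < q * 2^64, returns t * R^{-1} mod q.
    fn redc(&self, t: u128) -> u64 {
        let m = (t as u64).wrapping_mul(self.q_inv_neg);
        let u = ((t + m as u128 * self.q as u128) >> 64) as u64;
        if u >= self.q { u - self.q } else { u }
    }
    fn to_mont(&self, x: u64) -> u64 { self.redc(x as u128 * self.r2 as u128) }
    fn mul(&self, a: u64, b: u64) -> u64 { self.redc(a as u128 * b as u128) }
    fn from_mont(&self, x: u64) -> u64 { self.redc(x as u128) }
}

fn main() {
    let q = 0x1fffffffffe00001u64; // the 61-bit prime used in the benches
    let m = MontToy::new(q);
    let (a, b) = (m.to_mont(123456789), m.to_mont(987654321));
    assert_eq!(m.from_mont(m.mul(a, b)), 123456789u64 * 987654321 % q);
}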
#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct MontgomeryPrecomp{ +pub struct MontgomeryPrecomp { pub q: O, pub two_q: O, pub four_q: O, diff --git a/math/src/modulus/prime.rs b/math/src/modulus/prime.rs index 0841ab0..e0868f7 100644 --- a/math/src/modulus/prime.rs +++ b/math/src/modulus/prime.rs @@ -1,23 +1,25 @@ -use crate::modulus::montgomery::MontgomeryPrecomp; use crate::modulus::barrett::BarrettPrecomp; +use crate::modulus::montgomery::MontgomeryPrecomp; #[derive(Clone, Debug, PartialEq, Eq)] pub struct Prime { - pub q: O, /// q_base^q_powers + /// q_base^q_power + pub q: O, pub two_q: O, pub four_q: O, pub q_base: O, pub q_power: usize, - pub factors: Vec, /// distinct factors of q-1 + /// distinct prime factors of q_base-1 + pub factors: Vec, pub montgomery: MontgomeryPrecomp, - pub barrett:BarrettPrecomp, + pub barrett: BarrettPrecomp, pub phi: O, } -pub struct NTTFriendlyPrimesGenerator{ - pub size: f64, - pub next_prime: O, - pub prev_prime: O, - pub check_next_prime: bool, - pub check_prev_prime: bool, +pub struct NTTFriendlyPrimesGenerator { + pub size: f64, + pub next_prime: O, + pub prev_prime: O, + pub check_next_prime: bool, + pub check_prev_prime: bool, } diff --git a/math/src/poly.rs b/math/src/poly.rs index 41daca0..c5988bc 100644 --- a/math/src/poly.rs +++ b/math/src/poly.rs @@ -4,45 +4,51 @@ use std::cmp::PartialEq; #[derive(Clone, Debug, Eq)] pub struct Poly(pub Vec); -impl Polywhere +impl Poly +where O: Default + Clone + Copy, - { - pub fn new(n: usize) -> Self{ - Self(vec![O::default();n]) +{ + pub fn new(n: usize) -> Self { + Self(vec![O::default(); n]) } - pub fn buffer_size(&self) -> usize{ - return self.0.len() + pub fn buffer_size(&self) -> usize { + return self.0.len(); } - pub fn from_buffer(&mut self, n: usize, buf: &mut [O]){ - assert!(buf.len() >= n, "invalid buffer: buf.len()={} < n={}", buf.len(), n); + pub fn from_buffer(&mut self, n: usize, buf: &mut [O]) { + assert!( + buf.len() >= n, + "invalid buffer: buf.len()={} < n={}", + buf.len(), + n + ); self.0 = Vec::from(&buf[..n]); } - pub fn log_n(&self) -> usize{ - (usize::BITS - (self.n()-1).leading_zeros()) as usize + pub fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as usize } - pub fn n(&self) -> usize{ + pub fn n(&self) -> usize { self.0.len() } - pub fn resize(&mut self, n:usize){ + pub fn resize(&mut self, n: usize) { self.0.resize(n, O::default()); } - pub fn set_all(&mut self, v: &O){ + pub fn set_all(&mut self, v: &O) { self.0.fill(*v) } - pub fn zero(&mut self){ + pub fn zero(&mut self) { self.set_all(&O::default()) } - pub fn copy_from(&mut self, other: &Poly){ - if std::ptr::eq(self, other){ - return + pub fn copy_from(&mut self, other: &Poly) { + if std::ptr::eq(self, other) { + return; } self.resize(other.n()); self.0.copy_from_slice(&other.0) @@ -64,80 +70,100 @@ impl Default for Poly { #[derive(Clone, Debug, Eq)] pub struct PolyRNS(pub Vec>); -impl PolyRNSwhere +impl PolyRNS +where O: Default + Clone + Copy, - { - - pub fn new(n: usize, level: usize) -> Self{ +{ + pub fn new(n: usize, level: usize) -> Self { let mut polyrns: PolyRNS = PolyRNS::::default(); - let mut buf: Vec = vec![O::default();polyrns.buffer_size(n, level)]; + let mut buf: Vec = vec![O::default(); polyrns.buffer_size(n, level)]; polyrns.from_buffer(n, level, &mut buf[..]); polyrns } - pub fn n(&self) -> usize{ + pub fn n(&self) -> usize { self.0[0].n() } - pub fn log_n(&self) -> usize{ + pub fn log_n(&self) -> usize { self.0[0].log_n() } - pub fn level(&self) -> usize{ - self.0.len()-1 +
pub fn level(&self) -> usize { + self.0.len() - 1 } - pub fn buffer_size(&self, n: usize, level:usize) -> usize{ - n * (level+1) + pub fn buffer_size(&self, n: usize, level: usize) -> usize { + n * (level + 1) } - pub fn from_buffer(&mut self, n: usize, level: usize, buf: &mut [O]){ - assert!(buf.len() >= n * (level+1), "invalid buffer: buf.len()={} < n * (level+1)={}", buf.len(), level+1); + pub fn from_buffer(&mut self, n: usize, level: usize, buf: &mut [O]) { + assert!( + buf.len() >= n * (level + 1), + "invalid buffer: buf.len()={} < n * (level+1)={}", + buf.len(), + n * (level + 1) + ); self.0.clear(); - for chunk in buf.chunks_mut(n).take(level+1) { + for chunk in buf.chunks_mut(n).take(level + 1) { let mut poly: Poly = Poly(Vec::new()); poly.from_buffer(n, chunk); self.0.push(poly); } } - pub fn resize(&mut self, level:usize){ - self.0.resize(level+1, Poly::::new(self.n())); + pub fn resize(&mut self, level: usize) { + self.0.resize(level + 1, Poly::::new(self.n())); } - pub fn split_at_mut(&mut self, level:usize) -> (&mut [Poly], &mut [Poly]){ + pub fn split_at_mut(&mut self, level: usize) -> (&mut [Poly], &mut [Poly]) { self.0.split_at_mut(level) } - pub fn at(&self, level:usize) -> &Poly{ - assert!(level <= self.level(), "invalid argument level: level={} > self.level()={}", level, self.level()); + pub fn at(&self, level: usize) -> &Poly { + assert!( + level <= self.level(), + "invalid argument level: level={} > self.level()={}", + level, + self.level() + ); &self.0[level] } - pub fn at_mut(&mut self, level:usize) -> &mut Poly{ + pub fn at_mut(&mut self, level: usize) -> &mut Poly { &mut self.0[level] } - pub fn set_all(&mut self, v: &O){ - (0..self.level()+1).for_each(|i| self.at_mut(i).set_all(v)) + pub fn set_all(&mut self, v: &O) { + (0..self.level() + 1).for_each(|i| self.at_mut(i).set_all(v)) } - pub fn zero(&mut self){ + pub fn zero(&mut self) { self.set_all(&O::default()) } - pub fn copy(&mut self, other: &PolyRNS){ - if std::ptr::eq(self, other){ - return + pub fn copy(&mut self, other: &PolyRNS) { + if std::ptr::eq(self, other) { + return; } self.resize(other.level()); self.copy_level(other.level(), other); } - pub fn copy_level(&mut self, level:usize, other: &PolyRNS){ - assert!(self.level() <= level, "invalid argument level: level={} > self.level()={}", level, self.level()); - assert!(other.level() <= level, "invalid argument level: level={} > other.level()={}", level, other.level()); - (0..level+1).for_each(|i| self.at_mut(i).copy_from(other.at(i))) + pub fn copy_level(&mut self, level: usize, other: &PolyRNS) { + assert!( + level <= self.level(), + "invalid argument level: level={} > self.level()={}", + level, + self.level() + ); + assert!( + level <= other.level(), + "invalid argument level: level={} > other.level()={}", + level, + other.level() + ); + (0..level + 1).for_each(|i| self.at_mut(i).copy_from(other.at(i))) } } @@ -147,8 +173,8 @@ impl PartialEq for PolyRNS { } } -impl Default for PolyRNS{ - fn default() -> Self{ - Self{0:Vec::new()} +impl Default for PolyRNS { + fn default() -> Self { + Self { 0: Vec::new() } } -} \ No newline at end of file +} diff --git a/math/src/ring.rs b/math/src/ring.rs index 6e92486..2771e19 100644 --- a/math/src/ring.rs +++ b/math/src/ring.rs @@ -1,49 +1,47 @@ pub mod impl_u64; -use num::traits::Unsigned; +use crate::dft::DFT; use crate::modulus::prime::Prime; use crate::poly::{Poly, PolyRNS}; -use crate::dft::DFT; +use num::traits::Unsigned; - -pub struct Ring{ - pub n:usize, - pub modulus:Prime, - pub dft:Box>, +pub struct Ring { -
pub n: usize, + pub modulus: Prime, + pub dft: Box>, } -impl Ring{ - pub fn n(&self) -> usize{ - return self.n +impl Ring { + pub fn n(&self) -> usize { + return self.n; } - pub fn new_poly(&self) -> Poly{ + pub fn new_poly(&self) -> Poly { Poly::::new(self.n()) } } -pub struct RingRNS<'a, O: Unsigned>(pub & 'a [Ring]); +pub struct RingRNS<'a, O: Unsigned>(pub &'a [Ring]); impl RingRNS<'_, O> { - - pub fn n(&self) -> usize{ + pub fn n(&self) -> usize { self.0[0].n() } - pub fn new_polyrns(&self) -> PolyRNS{ + pub fn new_polyrns(&self) -> PolyRNS { PolyRNS::::new(self.n(), self.level()) } - pub fn max_level(&self) -> usize{ - self.0.len()-1 + pub fn max_level(&self) -> usize { + self.0.len() - 1 } - pub fn level(&self) -> usize{ - self.0.len()-1 + pub fn level(&self) -> usize { + self.0.len() - 1 } - pub fn at_level(&self, level:usize) -> RingRNS{ + pub fn at_level(&self, level: usize) -> RingRNS { assert!(level <= self.0.len()); - RingRNS(&self.0[..level+1]) + RingRNS(&self.0[..level + 1]) } } diff --git a/math/src/ring/impl_u64/automorphism.rs b/math/src/ring/impl_u64/automorphism.rs index 18179b0..4c12f07 100644 --- a/math/src/ring/impl_u64/automorphism.rs +++ b/math/src/ring/impl_u64/automorphism.rs @@ -1,43 +1,57 @@ use crate::modulus::WordOps; -use crate::ring::Ring; use crate::poly::Poly; +use crate::ring::Ring; /// Returns a lookup table for the automorphism X^{i} -> X^{i * k mod nth_root}. /// Method will panic if n or nth_root are not power-of-two. /// Method will panic if gal_el is not coprime with nth_root. -pub fn automorphism_index_ntt(n: usize, nth_root:u64, gal_el: u64) -> Vec{ - assert!(n&(n-1) != 0, "invalid n={}: not a power-of-two", n); - assert!(nth_root&(nth_root-1) != 0, "invalid nth_root={}: not a power-of-two", n); - assert!(gal_el & 1 == 1, "invalid gal_el={}: not coprime with nth_root={}", gal_el, nth_root); +pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec { + assert!(n & (n - 1) == 0, "invalid n={}: not a power-of-two", n); + assert!( + nth_root & (nth_root - 1) == 0, + "invalid nth_root={}: not a power-of-two", + nth_root + ); + assert!( + gal_el & 1 == 1, + "invalid gal_el={}: not coprime with nth_root={}", + gal_el, + nth_root + ); - let mask = nth_root-1; + let mask = nth_root - 1; let log_nth_root: u32 = nth_root.log2() as u32; let mut index: Vec = Vec::with_capacity(n); - for i in 0..n{ - let i_rev: usize = 2*i.reverse_bits_msb(log_nth_root)+1; - let gal_el_i: u64 = (gal_el * (i_rev as u64) & mask)>>1; + for i in 0..n { + let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root) + 1; + let gal_el_i: u64 = (gal_el * (i_rev as u64) & mask) >> 1; index.push(gal_el_i.reverse_bits_msb(log_nth_root)); } index } -impl Ring{ - pub fn automorphism(&self, a:Poly, gal_el: u64, b:&mut Poly){ - debug_assert!(a.n() == b.n(), "invalid inputs: a.n() = {} != b.n() = {}", a.n(), b.n()); - debug_assert!(gal_el&1 == 1, "invalid gal_el = {}: not odd", gal_el); +impl Ring { + pub fn automorphism(&self, a: Poly, gal_el: u64, b: &mut Poly) { + debug_assert!( + a.n() == b.n(), + "invalid inputs: a.n() = {} != b.n() = {}", + a.n(), + b.n() + ); + debug_assert!(gal_el & 1 == 1, "invalid gal_el = {}: not odd", gal_el); let n: usize = a.n(); - let mask: u64 = (n-1) as u64; + let mask: u64 = (n - 1) as u64; let log_n: usize = n.log2(); let q: u64 = self.modulus.q(); let b_vec: &mut _ = &mut b.0; let a_vec: &_ = &a.0; - a_vec.iter().enumerate().for_each(|(i, ai)|{ + a_vec.iter().enumerate().for_each(|(i, ai)| { let gal_el_i: u64 = i as u64 * gal_el; - let sign: u64 =
(gal_el_i>>log_n) & 1; + let sign: u64 = (gal_el_i >> log_n) & 1; let i_out: u64 = gal_el_i & mask; - b_vec[i_out as usize] = ai * (sign^1) | (q - ai) * sign + b_vec[i_out as usize] = ai * (sign ^ 1) | (q - ai) * sign }); } -} \ No newline at end of file +} diff --git a/math/src/ring/impl_u64/mod.rs b/math/src/ring/impl_u64/mod.rs index a5125f6..8915bbd 100644 --- a/math/src/ring/impl_u64/mod.rs +++ b/math/src/ring/impl_u64/mod.rs @@ -1,5 +1,5 @@ pub mod automorphism; +pub mod rescaling_rns; pub mod ring; pub mod ring_rns; -pub mod rescaling_rns; -pub mod sampling; \ No newline at end of file +pub mod sampling; diff --git a/math/src/ring/impl_u64/rescaling_rns.rs b/math/src/ring/impl_u64/rescaling_rns.rs index 9a82171..6bc1d5f 100644 --- a/math/src/ring/impl_u64/rescaling_rns.rs +++ b/math/src/ring/impl_u64/rescaling_rns.rs @@ -1,142 +1,269 @@ +use crate::modulus::barrett::Barrett; +use crate::modulus::ONCE; +use crate::poly::PolyRNS; use crate::ring::Ring; use crate::ring::RingRNS; -use crate::poly::PolyRNS; -use crate::modulus::barrett::Barrett; use crate::scalar::ScalarRNS; -use crate::modulus::ONCE; extern crate test; -impl RingRNS<'_, u64>{ - +impl RingRNS<'_, u64> { /// Updates b to floor(a / q[b.level()]). - pub fn div_floor_by_last_modulus(&self, a: &PolyRNS, buf: &mut PolyRNS, b: &mut PolyRNS){ - debug_assert!(self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level()); - debug_assert!(b.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", b.level(), a.level()-1); + pub fn div_floor_by_last_modulus( + &self, + a: &PolyRNS, + buf: &mut PolyRNS, + b: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); + debug_assert!( + b.level() >= a.level() - 1, + "invalid input b: b.level()={} < a.level()-1={}", + b.level(), + a.level() - 1 + ); let level = self.level(); let rescaling_constants: ScalarRNS> = self.rescaling_constant(); - - if NTT{ + + if NTT { let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]); - for (i, r) in self.0[0..level].iter().enumerate(){ + for (i, r) in self.0[0..level].iter().enumerate() { r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett::(&buf_ntt_qi_scaling[0], a.at(i), &rescaling_constants.0[i], b.at_mut(i)); + r.sum_aqqmb_prod_c_scalar_barrett::( + &buf_ntt_qi_scaling[0], + a.at(i), + &rescaling_constants.0[i], + b.at_mut(i), + ); } - }else{ - for (i, r) in self.0[0..level].iter().enumerate(){ - r.sum_aqqmb_prod_c_scalar_barrett::(a.at(level), a.at(i), &rescaling_constants.0[i], b.at_mut(i)); + } else { + for (i, r) in self.0[0..level].iter().enumerate() { + r.sum_aqqmb_prod_c_scalar_barrett::( + a.at(level), + a.at(i), + &rescaling_constants.0[i], + b.at_mut(i), + ); } } } /// Updates a to floor(a / q[b.level()]). /// Expects a to be in the NTT domain. 
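The identity behind all of these rescaling routines: a - (a mod q_last) is exactly divisible by q_last, so floor(a / q_last) can be formed residue-wise as (a_i - a_last) * q_last^{-1} mod q_i; the Barrett constants prepared in rescaling_constant encode -(q_last^{-1}) mod q_i, which is the same thing up to sign. An exhaustive check of the identity on a toy RNS basis (plain integers, independent of the crate):

// floor(a / q_last) in RNS: (a_i - a_last) * q_last^{-1} mod q_i, per residue.
fn main() {
    let (q0, q_last) = (7u64, 5u64);
    // q_last^{-1} mod q0 by brute force (5 * 3 = 15 = 2*7 + 1).
    let q_last_inv = (1..q0).find(|x| q_last * x % q0 == 1).unwrap();
    for a in 0..q0 * q_last {
        let (a0, a_last) = (a % q0, a % q_last);
        // a - a_last is a multiple of q_last, so the division below is exact.
        let expected = (a / q_last) % q0;
        let got = (a0 + q0 - a_last % q0) % q0 * q_last_inv % q0;
        assert_eq!(got, expected);
    }
    println!("floor-by-last-modulus identity verified for all a < 35");
}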
- pub fn div_floor_by_last_modulus_inplace(&self, buf: &mut PolyRNS, a: &mut PolyRNS){ - debug_assert!(self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level()); + pub fn div_floor_by_last_modulus_inplace( + &self, + buf: &mut PolyRNS, + a: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); let level = self.level(); let rescaling_constants: ScalarRNS> = self.rescaling_constant(); - if NTT{ + if NTT { let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); self.0[level].intt::(a.at(level), &mut buf_ntt_q_scaling[0]); - for (i, r) in self.0[0..level].iter().enumerate(){ + for (i, r) in self.0[0..level].iter().enumerate() { r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&buf_ntt_qi_scaling[0], &rescaling_constants.0[i], a.at_mut(i)); + r.sum_aqqmb_prod_c_scalar_barrett_inplace::( + &buf_ntt_qi_scaling[0], + &rescaling_constants.0[i], + a.at_mut(i), + ); } - }else{ + } else { let (a_i, a_level) = buf.0.split_at_mut(level); - for (i, r) in self.0[0..level].iter().enumerate(){ - r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&a_level[0], &rescaling_constants.0[i], &mut a_i[i]); + for (i, r) in self.0[0..level].iter().enumerate() { + r.sum_aqqmb_prod_c_scalar_barrett_inplace::( + &a_level[0], + &rescaling_constants.0[i], + &mut a_i[i], + ); } } } /// Updates b to floor(a / prod_{level - nb_moduli}^{level} q[i]) - pub fn div_floor_by_last_moduli(&self, nb_moduli:usize, a: &PolyRNS, buf: &mut PolyRNS, c: &mut PolyRNS){ + pub fn div_floor_by_last_moduli( + &self, + nb_moduli: usize, + a: &PolyRNS, + buf: &mut PolyRNS, + c: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); + debug_assert!( + c.level() >= a.level() - 1, + "invalid input b: b.level()={} < a.level()-1={}", + c.level(), + a.level() - 1 + ); + debug_assert!( + nb_moduli <= a.level(), + "invalid input nb_moduli: nb_moduli={} > a.level()={}", + nb_moduli, + a.level() + ); - debug_assert!(self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level()); - debug_assert!(c.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", c.level(), a.level()-1); - debug_assert!(nb_moduli <= a.level(), "invalid input nb_moduli: nb_moduli={} > a.level()={}", nb_moduli, a.level()); - - if nb_moduli == 0{ - if a != c{ + if nb_moduli == 0 { + if a != c { c.copy(a); } - }else{ - if NTT{ + } else { + if NTT { self.intt::(a, buf); - (0..nb_moduli).for_each(|i|{self.at_level(self.level()-i).div_floor_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf)}); - self.at_level(self.level()-nb_moduli).ntt::(buf, c); - }else{ + (0..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_floor_by_last_modulus_inplace::( + &mut PolyRNS::::default(), + buf, + ) + }); + self.at_level(self.level() - nb_moduli).ntt::(buf, c); + } else { self.div_floor_by_last_modulus::(a, buf, c); - (1..nb_moduli).for_each(|i|{self.at_level(self.level()-i).div_floor_by_last_modulus_inplace::(buf, c)}); + (1..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_floor_by_last_modulus_inplace::(buf, c) + }); } } } /// Updates a to floor(a / prod_{level - nb_moduli}^{level} q[i]) - pub fn div_floor_by_last_moduli_inplace(&self, nb_moduli:usize, buf: &mut PolyRNS, a: &mut PolyRNS){ - 
debug_assert!(self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level()); - debug_assert!(nb_moduli <= a.level(), "invalid input nb_moduli: nb_moduli={} > a.level()={}", nb_moduli, a.level()); - if NTT{ + pub fn div_floor_by_last_moduli_inplace( + &self, + nb_moduli: usize, + buf: &mut PolyRNS, + a: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); + debug_assert!( + nb_moduli <= a.level(), + "invalid input nb_moduli: nb_moduli={} > a.level()={}", + nb_moduli, + a.level() + ); + if NTT { self.intt::(a, buf); - (0..nb_moduli).for_each(|i|{self.at_level(self.level()-i).div_floor_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf)}); - self.at_level(self.level()-nb_moduli).ntt::(buf, a); - }else{ - (0..nb_moduli).for_each(|i|{self.at_level(self.level()-i).div_floor_by_last_modulus_inplace::(buf, a)}); - } + (0..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_floor_by_last_modulus_inplace::(&mut PolyRNS::::default(), buf) + }); + self.at_level(self.level() - nb_moduli).ntt::(buf, a); + } else { + (0..nb_moduli).for_each(|i| { + self.at_level(self.level() - i) + .div_floor_by_last_modulus_inplace::(buf, a) + }); + } } /// Updates b to round(a / q[b.level()]). /// Expects b to be in the NTT domain. - pub fn div_round_by_last_modulus(&self, a: &PolyRNS, buf: &mut PolyRNS, b: &mut PolyRNS){ - debug_assert!(self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level()); - debug_assert!(b.level() >= a.level()-1, "invalid input b: b.level()={} < a.level()-1={}", b.level(), a.level()-1); + pub fn div_round_by_last_modulus( + &self, + a: &PolyRNS, + buf: &mut PolyRNS, + b: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); + debug_assert!( + b.level() >= a.level() - 1, + "invalid input b: b.level()={} < a.level()-1={}", + b.level(), + a.level() - 1 + ); let level: usize = self.level(); let r_last: &Ring = &self.0[level]; - let q_level_half: u64 = r_last.modulus.q>>1; + let q_level_half: u64 = r_last.modulus.q >> 1; let rescaling_constants: ScalarRNS> = self.rescaling_constant(); let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); - if NTT{ + if NTT { r_last.intt::(a.at(level), &mut buf_ntt_q_scaling[0]); r_last.add_scalar_inplace::(&q_level_half, &mut buf_ntt_q_scaling[0]); - for (i, r) in self.0[0..level].iter().enumerate(){ - r_last.add_scalar::(&buf_ntt_q_scaling[0], &q_level_half, &mut buf_ntt_qi_scaling[0]); + for (i, r) in self.0[0..level].iter().enumerate() { + r_last.add_scalar::( + &buf_ntt_q_scaling[0], + &q_level_half, + &mut buf_ntt_qi_scaling[0], + ); r.ntt_inplace::(&mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett::(&buf_ntt_qi_scaling[0], a.at(i), &rescaling_constants.0[i], b.at_mut(i)); + r.sum_aqqmb_prod_c_scalar_barrett::( + &buf_ntt_qi_scaling[0], + a.at(i), + &rescaling_constants.0[i], + b.at_mut(i), + ); } - }else{ - + } else { } - } /// Updates a to round(a / q[b.level()]). /// Expects a to be in the NTT domain. 
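The rounding variants above and below add half of the dropped modulus (q_level_half = q_last >> 1) before reusing the floor machinery, which is the usual round(a / q) = floor((a + q/2) / q) trick. A direct scalar check of that identity (plain integers; q odd here, so no exact half-way ties arise):

// round(a / q) = floor((a + q/2) / q): adding half the divisor before the
// floor division turns truncation into round-to-nearest.
fn main() {
    let q = 5u64;
    for a in 0..1000u64 {
        let via_floor = (a + (q >> 1)) / q;
        let via_float = ((a as f64) / (q as f64)).round() as u64;
        assert_eq!(via_floor, via_float);
    }
}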
- pub fn div_round_by_last_modulus_inplace(&self, buf: &mut PolyRNS, a: &mut PolyRNS){ - debug_assert!(self.level() <= a.level(), "invalid input a: self.level()={} > a.level()={}", self.level(), a.level()); + pub fn div_round_by_last_modulus_inplace( + &self, + buf: &mut PolyRNS, + a: &mut PolyRNS, + ) { + debug_assert!( + self.level() <= a.level(), + "invalid input a: self.level()={} > a.level()={}", + self.level(), + a.level() + ); let level = self.level(); let r_last: &Ring = &self.0[level]; - let q_level_half: u64 = r_last.modulus.q>>1; + let q_level_half: u64 = r_last.modulus.q >> 1; let rescaling_constants: ScalarRNS> = self.rescaling_constant(); let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1); - if NTT{ + if NTT { r_last.intt::(a.at(level), &mut buf_ntt_q_scaling[0]); r_last.add_scalar_inplace::(&q_level_half, &mut buf_ntt_q_scaling[0]); - for (i, r) in self.0[0..level].iter().enumerate(){ - r_last.add_scalar::(&buf_ntt_q_scaling[0], &q_level_half, &mut buf_ntt_qi_scaling[0]); + for (i, r) in self.0[0..level].iter().enumerate() { + r_last.add_scalar::( + &buf_ntt_q_scaling[0], + &q_level_half, + &mut buf_ntt_qi_scaling[0], + ); r.ntt::(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]); - r.sum_aqqmb_prod_c_scalar_barrett_inplace::(&buf_ntt_qi_scaling[0], &rescaling_constants.0[i], a.at_mut(i)); + r.sum_aqqmb_prod_c_scalar_barrett_inplace::( + &buf_ntt_qi_scaling[0], + &rescaling_constants.0[i], + a.at_mut(i), + ); } } - } } - - diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index 377d47b..2dbfc15 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -1,17 +1,17 @@ -use crate::ring::Ring; use crate::dft::ntt::Table; -use crate::modulus::prime::Prime; -use crate::modulus::montgomery::Montgomery; use crate::modulus::barrett::Barrett; -use crate::poly::Poly; -use crate::modulus::{REDUCEMOD, BARRETT}; +use crate::modulus::montgomery::Montgomery; +use crate::modulus::prime::Prime; use crate::modulus::VectorOperations; +use crate::modulus::{BARRETT, REDUCEMOD}; +use crate::poly::Poly; +use crate::ring::Ring; +use crate::CHUNK; use num_bigint::BigInt; use num_traits::ToPrimitive; -use crate::CHUNK; -impl Ring{ - pub fn new(n:usize, q_base:u64, q_power:usize) -> Self{ +impl Ring { + pub fn new(n: usize, q_base: u64, q_power: usize) -> Self { let prime: Prime = Prime::::new(q_base, q_power); Self { n: n, @@ -20,156 +20,218 @@ impl Ring{ } } - pub fn from_bigint(&self, coeffs: &[BigInt], step:usize, a: &mut Poly){ - assert!(step <= a.n(), "invalid step: step={} > a.n()={}", step, a.n()); - assert!(coeffs.len() <= a.n() / step, "invalid coeffs: coeffs.len()={} > a.n()/step={}", coeffs.len(), a.n()/step); + pub fn from_bigint(&self, coeffs: &[BigInt], step: usize, a: &mut Poly) { + assert!( + step <= a.n(), + "invalid step: step={} > a.n()={}", + step, + a.n() + ); + assert!( + coeffs.len() <= a.n() / step, + "invalid coeffs: coeffs.len()={} > a.n()/step={}", + coeffs.len(), + a.n() / step + ); let q_big: BigInt = BigInt::from(self.modulus.q); - a.0.iter_mut().step_by(step).enumerate().for_each(|(i, v)| *v = (&coeffs[i] % &q_big).to_u64().unwrap()); + a.0.iter_mut() + .step_by(step) + .enumerate() + .for_each(|(i, v)| *v = (&coeffs[i] % &q_big).to_u64().unwrap()); } } -impl Ring{ - pub fn ntt_inplace(&self, poly: &mut Poly){ - match LAZY{ +impl Ring { + pub fn ntt_inplace(&self, poly: &mut Poly) { + match LAZY { true => self.dft.forward_inplace_lazy(&mut poly.0), - false => self.dft.forward_inplace(&mut poly.0) + 
false => self.dft.forward_inplace(&mut poly.0), } } - pub fn intt_inplace(&self, poly: &mut Poly){ - match LAZY{ + pub fn intt_inplace(&self, poly: &mut Poly) { + match LAZY { true => self.dft.backward_inplace_lazy(&mut poly.0), - false => self.dft.backward_inplace(&mut poly.0) + false => self.dft.backward_inplace(&mut poly.0), } } - pub fn ntt(&self, poly_in: &Poly, poly_out: &mut Poly){ + pub fn ntt(&self, poly_in: &Poly, poly_out: &mut Poly) { poly_out.0.copy_from_slice(&poly_in.0); - match LAZY{ + match LAZY { true => self.dft.forward_inplace_lazy(&mut poly_out.0), - false => self.dft.forward_inplace(&mut poly_out.0) + false => self.dft.forward_inplace(&mut poly_out.0), } } - pub fn intt(&self, poly_in: &Poly, poly_out: &mut Poly){ + pub fn intt(&self, poly_in: &Poly, poly_out: &mut Poly) { poly_out.0.copy_from_slice(&poly_in.0); - match LAZY{ + match LAZY { true => self.dft.backward_inplace_lazy(&mut poly_out.0), - false => self.dft.backward_inplace(&mut poly_out.0) + false => self.dft.backward_inplace(&mut poly_out.0), } } } -impl Ring{ - +impl Ring { #[inline(always)] - pub fn add_inplace(&self, a: &Poly, b: &mut Poly){ + pub fn add_inplace(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.va_add_vb_into_vb::(&a.0, &mut b.0); + self.modulus + .va_add_vb_into_vb::(&a.0, &mut b.0); } #[inline(always)] - pub fn add(&self, a: &Poly, b: &Poly, c: &mut Poly){ + pub fn add(&self, a: &Poly, b: &Poly, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); - self.modulus.va_add_vb_into_vc::(&a.0, &b.0, &mut c.0); + self.modulus + .va_add_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] - pub fn add_scalar_inplace(&self, a: &u64, b: &mut Poly){ + pub fn add_scalar_inplace(&self, a: &u64, b: &mut Poly) { debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus.sa_add_vb_into_vb::(a, &mut b.0); } #[inline(always)] - pub fn add_scalar(&self, a: &Poly, b: &u64, c: &mut Poly){ + pub fn add_scalar(&self, a: &Poly, b: &u64, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); - self.modulus.va_add_sb_into_vc::(&a.0, b, &mut c.0); + self.modulus + .va_add_sb_into_vc::(&a.0, b, &mut c.0); } #[inline(always)] - pub fn sub_inplace(&self, a: &Poly, b: &mut Poly){ + pub fn sub_inplace(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.va_sub_vb_into_vb::(&a.0, &mut b.0); + self.modulus + .va_sub_vb_into_vb::(&a.0, &mut b.0); } #[inline(always)] - pub fn sub(&self, a: &Poly, b: &Poly, c: &mut Poly){ + pub fn sub(&self, a: &Poly, b: &Poly, c: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); - self.modulus.va_sub_vb_into_vc::(&a.0, &b.0, &mut c.0); + self.modulus + .va_sub_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] - pub fn neg(&self, a: &Poly, b: &mut Poly){ + pub fn neg(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), 
"a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus.va_neg_into_vb::(&a.0, &mut b.0); } #[inline(always)] - pub fn neg_inplace(&self, a: &mut Poly){ + pub fn neg_inplace(&self, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus.va_neg_into_va::(&mut a.0); } #[inline(always)] - pub fn mul_montgomery_external(&self, a:&Poly>, b:&Poly, c: &mut Poly){ + pub fn mul_montgomery_external( + &self, + a: &Poly>, + b: &Poly, + c: &mut Poly, + ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); - self.modulus.va_mont_mul_vb_into_vc::(&a.0, &b.0, &mut c.0); + self.modulus + .va_mont_mul_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] - pub fn mul_montgomery_external_inplace(&self, a:&Poly>, b:&mut Poly){ + pub fn mul_montgomery_external_inplace( + &self, + a: &Poly>, + b: &mut Poly, + ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.va_mont_mul_vb_into_vb::(&a.0, &mut b.0); + self.modulus + .va_mont_mul_vb_into_vb::(&a.0, &mut b.0); } #[inline(always)] - pub fn mul_scalar(&self, a:&Poly, b: &u64, c:&mut Poly){ + pub fn mul_scalar(&self, a: &Poly, b: &u64, c: &mut Poly) { debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); - self.modulus.sa_barrett_mul_vb_into_vc::(&self.modulus.barrett.prepare(*b), &a.0, &mut c.0); + self.modulus.sa_barrett_mul_vb_into_vc::( + &self.modulus.barrett.prepare(*b), + &a.0, + &mut c.0, + ); } #[inline(always)] - pub fn mul_scalar_inplace(&self, a:&u64, b:&mut Poly){ + pub fn mul_scalar_inplace(&self, a: &u64, b: &mut Poly) { debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.sa_barrett_mul_vb_into_vb::(&self.modulus.barrett.prepare(self.modulus.barrett.reduce::(a)), &mut b.0); + self.modulus.sa_barrett_mul_vb_into_vb::( + &self + .modulus + .barrett + .prepare(self.modulus.barrett.reduce::(a)), + &mut b.0, + ); } #[inline(always)] - pub fn mul_scalar_barrett_inplace(&self, a:&Barrett, b:&mut Poly){ + pub fn mul_scalar_barrett_inplace( + &self, + a: &Barrett, + b: &mut Poly, + ) { debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.sa_barrett_mul_vb_into_vb::(a, &mut b.0); + self.modulus + .sa_barrett_mul_vb_into_vb::(a, &mut b.0); } #[inline(always)] - pub fn mul_scalar_barrett(&self, a:&Barrett, b: &Poly, c:&mut Poly){ + pub fn mul_scalar_barrett( + &self, + a: &Barrett, + b: &Poly, + c: &mut Poly, + ) { debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.sa_barrett_mul_vb_into_vc::(a, &b.0, &mut c.0); + self.modulus + .sa_barrett_mul_vb_into_vc::(a, &b.0, &mut c.0); } #[inline(always)] - pub fn sum_aqqmb_prod_c_scalar_barrett(&self, a: &Poly, b: &Poly, c: &Barrett, d: &mut Poly){ + pub fn sum_aqqmb_prod_c_scalar_barrett( + &self, + a: &Poly, + b: &Poly, + c: &Barrett, + d: &mut Poly, + ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n()); - self.modulus.va_sub_vb_mul_sc_into_vd::(&a.0, &b.0, c, &mut d.0); + self.modulus + 
.va_sub_vb_mul_sc_into_vd::(&a.0, &b.0, c, &mut d.0); } #[inline(always)] - pub fn sum_aqqmb_prod_c_scalar_barrett_inplace(&self, a: &Poly, c: &Barrett, b: &mut Poly){ + pub fn sum_aqqmb_prod_c_scalar_barrett_inplace( + &self, + a: &Poly, + c: &Barrett, + b: &mut Poly, + ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); - self.modulus.va_sub_vb_mul_sc_into_vb::(&a.0, c, &mut b.0); + self.modulus + .va_sub_vb_mul_sc_into_vb::(&a.0, c, &mut b.0); } -} \ No newline at end of file +} diff --git a/math/src/ring/impl_u64/ring_rns.rs b/math/src/ring/impl_u64/ring_rns.rs index 25a6b01..e8cbd69 100644 --- a/math/src/ring/impl_u64/ring_rns.rs +++ b/math/src/ring/impl_u64/ring_rns.rs @@ -1,158 +1,353 @@ -use crate::ring::{Ring, RingRNS}; -use crate::poly::PolyRNS; -use crate::modulus::montgomery::Montgomery; use crate::modulus::barrett::Barrett; -use crate::scalar::ScalarRNS; +use crate::modulus::montgomery::Montgomery; use crate::modulus::REDUCEMOD; +use crate::poly::PolyRNS; +use crate::ring::{Ring, RingRNS}; +use crate::scalar::ScalarRNS; use num_bigint::BigInt; -pub fn new_rings(n: usize, moduli: Vec) -> Vec>{ +pub fn new_rings(n: usize, moduli: Vec) -> Vec> { assert!(!moduli.is_empty(), "moduli cannot be empty"); let rings: Vec> = moduli .into_iter() - .map(|prime| Ring::new(n, prime, 1)) + .map(|prime| Ring::new(n, prime, 1)) .collect(); - return rings + return rings; } -impl<'a> RingRNS<'a, u64>{ - pub fn new(rings:&'a [Ring]) -> Self{ +impl<'a> RingRNS<'a, u64> { + pub fn new(rings: &'a [Ring]) -> Self { RingRNS(rings) } - pub fn modulus(&self) -> BigInt{ + pub fn modulus(&self) -> BigInt { let mut modulus = BigInt::from(1); - self.0.iter().enumerate().for_each(|(_, r)|modulus *= BigInt::from(r.modulus.q)); + self.0 + .iter() + .enumerate() + .for_each(|(_, r)| modulus *= BigInt::from(r.modulus.q)); modulus } pub fn rescaling_constant(&self) -> ScalarRNS> { let level = self.level(); let q_scale: u64 = self.0[level].modulus.q; - ScalarRNS((0..level).map(|i| {self.0[i].modulus.barrett.prepare(self.0[i].modulus.q - self.0[i].modulus.inv(q_scale))}).collect()) + ScalarRNS( + (0..level) + .map(|i| { + self.0[i] + .modulus + .barrett + .prepare(self.0[i].modulus.q - self.0[i].modulus.inv(q_scale)) + }) + .collect(), + ) } - pub fn from_bigint_inplace(&self, coeffs: &[BigInt], step:usize, a: &mut PolyRNS){ + pub fn from_bigint_inplace(&self, coeffs: &[BigInt], step: usize, a: &mut PolyRNS) { let level = self.level(); - assert!(level <= a.level(), "invalid level: level={} > a.level()={}", level, a.level()); - (0..level).for_each(|i|{self.0[i].from_bigint(coeffs, step, a.at_mut(i))}); + assert!( + level <= a.level(), + "invalid level: level={} > a.level()={}", + level, + a.level() + ); + (0..level).for_each(|i| self.0[i].from_bigint(coeffs, step, a.at_mut(i))); } - pub fn to_bigint_inplace(&self, a: &PolyRNS, step: usize, coeffs: &mut [BigInt]){ - assert!(step <= a.n(), "invalid step: step={} > a.n()={}", step, a.n()); - assert!(coeffs.len() <= a.n() / step, "invalid coeffs: coeffs.len()={} > a.n()/step={}", coeffs.len(), a.n()/step); + pub fn to_bigint_inplace(&self, a: &PolyRNS, step: usize, coeffs: &mut [BigInt]) { + assert!( + step <= a.n(), + "invalid step: step={} > a.n()={}", + step, + a.n() + ); + assert!( + coeffs.len() <= a.n() / step, + "invalid coeffs: coeffs.len()={} > a.n()/step={}", + coeffs.len(), + a.n() / step + ); - let mut inv_crt: Vec = vec![BigInt::default(); self.level()+1]; + let mut 
inv_crt: Vec = vec![BigInt::default(); self.level() + 1]; let q_big: BigInt = self.modulus(); - let q_big_half: BigInt = &q_big>>1; + let q_big_half: BigInt = &q_big >> 1; - inv_crt.iter_mut().enumerate().for_each(|(i, a)|{ + inv_crt.iter_mut().enumerate().for_each(|(i, a)| { let qi_big = BigInt::from(self.0[i].modulus.q); *a = &q_big / &qi_big; *a *= a.modinv(&qi_big).unwrap(); }); - (0..self.n()).step_by(step).enumerate().for_each(|(i, j)|{ + (0..self.n()).step_by(step).enumerate().for_each(|(i, j)| { coeffs[j] = BigInt::from(a.at(0).0[i]) * &inv_crt[0]; - (1..self.level()+1).for_each(|k|{ + (1..self.level() + 1).for_each(|k| { coeffs[j] += BigInt::from(a.at(k).0[i] * &inv_crt[k]); }); coeffs[j] %= &q_big; - if &coeffs[j] >= &q_big_half{ + if &coeffs[j] >= &q_big_half { coeffs[j] -= &q_big; } }); } } -impl RingRNS<'_, u64>{ - pub fn ntt_inplace(&self, a: &mut PolyRNS){ - self.0.iter().enumerate().for_each(|(i, ring)| ring.ntt_inplace::(&mut a.0[i])); +impl RingRNS<'_, u64> { + pub fn ntt_inplace(&self, a: &mut PolyRNS) { + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.ntt_inplace::(&mut a.0[i])); } - pub fn intt_inplace(&self, a: &mut PolyRNS){ - self.0.iter().enumerate().for_each(|(i, ring)| ring.intt_inplace::(&mut a.0[i])); + pub fn intt_inplace(&self, a: &mut PolyRNS) { + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.intt_inplace::(&mut a.0[i])); } - pub fn ntt(&self, a: &PolyRNS, b: &mut PolyRNS){ - self.0.iter().enumerate().for_each(|(i, ring)| ring.ntt::(&a.0[i], &mut b.0[i])); + pub fn ntt(&self, a: &PolyRNS, b: &mut PolyRNS) { + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.ntt::(&a.0[i], &mut b.0[i])); } - pub fn intt(&self, a: &PolyRNS, b: &mut PolyRNS){ - self.0.iter().enumerate().for_each(|(i, ring)| ring.intt::(&a.0[i], &mut b.0[i])); + pub fn intt(&self, a: &PolyRNS, b: &mut PolyRNS) { + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.intt::(&a.0[i], &mut b.0[i])); } } -impl RingRNS<'_, u64>{ - +impl RingRNS<'_, u64> { #[inline(always)] - pub fn add(&self, a: &PolyRNS, b: &PolyRNS, c: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - debug_assert!(c.level() >= self.level(), "c.level()={} < self.level()={}", c.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.add::(&a.0[i], &b.0[i], &mut c.0[i])); + pub fn add( + &self, + a: &PolyRNS, + b: &PolyRNS, + c: &mut PolyRNS, + ) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + debug_assert!( + c.level() >= self.level(), + "c.level()={} < self.level()={}", + c.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.add::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] - pub fn add_inplace(&self, a: &PolyRNS, b: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.add_inplace::(&a.0[i], &mut b.0[i])); + pub fn add_inplace(&self, a: &PolyRNS, b: &mut PolyRNS) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < 
self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.add_inplace::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn sub(&self, a: &PolyRNS, b: &PolyRNS, c: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - debug_assert!(c.level() >= self.level(), "c.level()={} < self.level()={}", c.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); + pub fn sub( + &self, + a: &PolyRNS, + b: &PolyRNS, + c: &mut PolyRNS, + ) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + debug_assert!( + c.level() >= self.level(), + "c.level()={} < self.level()={}", + c.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.sub::(&a.0[i], &b.0[i], &mut c.0[i])); } #[inline(always)] - pub fn sub_inplace(&self, a: &PolyRNS, b: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); + pub fn sub_inplace(&self, a: &PolyRNS, b: &mut PolyRNS) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.sub_inplace::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn neg(&self, a: &PolyRNS, b: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); + pub fn neg(&self, a: &PolyRNS, b: &mut PolyRNS) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.neg::(&a.0[i], &mut b.0[i])); } #[inline(always)] - pub fn neg_inplace(&self, a: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); + pub fn neg_inplace(&self, a: &mut PolyRNS) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.neg_inplace::(&mut a.0[i])); } #[inline(always)] - pub fn mul_montgomery_external(&self, a:&PolyRNS>, b:&PolyRNS, c: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < 
self.level()={}", b.level(), self.level()); - debug_assert!(c.level() >= self.level(), "c.level()={} < self.level()={}", c.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_montgomery_external::(&a.0[i], &b.0[i], &mut c.0[i])); + pub fn mul_montgomery_external( + &self, + a: &PolyRNS>, + b: &PolyRNS, + c: &mut PolyRNS, + ) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + debug_assert!( + c.level() >= self.level(), + "c.level()={} < self.level()={}", + c.level(), + self.level() + ); + self.0.iter().enumerate().for_each(|(i, ring)| { + ring.mul_montgomery_external::(&a.0[i], &b.0[i], &mut c.0[i]) + }); } #[inline(always)] - pub fn mul_montgomery_external_inplace(&self, a:&PolyRNS>, b:&mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_montgomery_external_inplace::(&a.0[i], &mut b.0[i])); + pub fn mul_montgomery_external_inplace( + &self, + a: &PolyRNS>, + b: &mut PolyRNS, + ) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + self.0.iter().enumerate().for_each(|(i, ring)| { + ring.mul_montgomery_external_inplace::(&a.0[i], &mut b.0[i]) + }); } #[inline(always)] - pub fn mul_scalar(&self, a: &PolyRNS, b: &u64, c: &mut PolyRNS){ - debug_assert!(a.level() >= self.level(), "a.level()={} < self.level()={}", a.level(), self.level()); - debug_assert!(c.level() >= self.level(), "b.level()={} < self.level()={}", c.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_scalar::(&a.0[i], b, &mut c.0[i])); + pub fn mul_scalar( + &self, + a: &PolyRNS, + b: &u64, + c: &mut PolyRNS, + ) { + debug_assert!( + a.level() >= self.level(), + "a.level()={} < self.level()={}", + a.level(), + self.level() + ); + debug_assert!( + c.level() >= self.level(), + "b.level()={} < self.level()={}", + c.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.mul_scalar::(&a.0[i], b, &mut c.0[i])); } #[inline(always)] - pub fn mul_scalar_inplace(&self, a: &u64, b: &mut PolyRNS){ - debug_assert!(b.level() >= self.level(), "b.level()={} < self.level()={}", b.level(), self.level()); - self.0.iter().enumerate().for_each(|(i, ring)| ring.mul_scalar_inplace::(a, &mut b.0[i])); + pub fn mul_scalar_inplace(&self, a: &u64, b: &mut PolyRNS) { + debug_assert!( + b.level() >= self.level(), + "b.level()={} < self.level()={}", + b.level(), + self.level() + ); + self.0 + .iter() + .enumerate() + .for_each(|(i, ring)| ring.mul_scalar_inplace::(a, &mut b.0[i])); } -} \ No newline at end of file +} diff --git a/math/src/ring/impl_u64/sampling.rs b/math/src/ring/impl_u64/sampling.rs index c788440..8a2ab27 100644 --- a/math/src/ring/impl_u64/sampling.rs +++ b/math/src/ring/impl_u64/sampling.rs @@ -1,18 +1,22 @@ -use sampling::source::Source; use crate::modulus::WordOps; -use crate::ring::{Ring, RingRNS}; use crate::poly::{Poly, PolyRNS}; +use crate::ring::{Ring, RingRNS}; +use sampling::source::Source; -impl Ring{ - pub fn fill_uniform(&self, source: &mut Source, a: 
&mut Poly){ - let max:u64 = self.modulus.q; +impl Ring { + pub fn fill_uniform(&self, source: &mut Source, a: &mut Poly) { + let max: u64 = self.modulus.q; let mask: u64 = max.mask(); - a.0.iter_mut().for_each(|a|{*a = source.next_u64n(max, mask)}); + a.0.iter_mut() + .for_each(|a| *a = source.next_u64n(max, mask)); } } -impl RingRNS<'_, u64>{ - pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS){ - self.0.iter().enumerate().for_each(|(i, r)|{r.fill_uniform(source, a.at_mut(i))}); +impl RingRNS<'_, u64> { + pub fn fill_uniform(&self, source: &mut Source, a: &mut PolyRNS) { + self.0 + .iter() + .enumerate() + .for_each(|(i, r)| r.fill_uniform(source, a.at_mut(i))); } -} \ No newline at end of file +} diff --git a/math/src/scalar.rs b/math/src/scalar.rs index 1f13e60..3ce014f 100644 --- a/math/src/scalar.rs +++ b/math/src/scalar.rs @@ -1,2 +1,2 @@ #[derive(Clone, Debug, PartialEq, Eq)] -pub struct ScalarRNS(pub Vec); \ No newline at end of file +pub struct ScalarRNS(pub Vec); diff --git a/math/tests/rescaling_rns.rs b/math/tests/rescaling_rns.rs index 84d8e29..b5c7dba 100644 --- a/math/tests/rescaling_rns.rs +++ b/math/tests/rescaling_rns.rs @@ -1,55 +1,58 @@ -use num_bigint::BigInt; -use num_bigint::Sign; -use math::ring::{Ring, RingRNS}; use math::poly::PolyRNS; use math::ring::impl_u64::ring_rns::new_rings; -use sampling::source::Source; +use math::ring::{Ring, RingRNS}; +use num_bigint::BigInt; +use num_bigint::Sign; +use sampling::source::Source; #[test] -fn rescaling_rns_u64(){ - let n = 1<<10; +fn rescaling_rns_u64() { + let n = 1 << 10; let moduli: Vec = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64]; let rings: Vec> = new_rings(n, moduli); let ring_rns: RingRNS<'_, u64> = RingRNS::new(&rings); - + test_div_floor_by_last_modulus::(&ring_rns); test_div_floor_by_last_modulus::(&ring_rns); } -fn test_div_floor_by_last_modulus(ring_rns: &RingRNS) { - - let seed: [u8; 32] = [0;32]; +fn test_div_floor_by_last_modulus(ring_rns: &RingRNS) { + let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); let mut a: PolyRNS = ring_rns.new_polyrns(); let mut b: PolyRNS = ring_rns.new_polyrns(); - let mut c: PolyRNS = ring_rns.at_level(ring_rns.level()-1).new_polyrns(); + let mut c: PolyRNS = ring_rns.at_level(ring_rns.level() - 1).new_polyrns(); // Allocates a random PolyRNS ring_rns.fill_uniform(&mut source, &mut a); // Maps PolyRNS to [BigInt] - let mut coeffs_a: Vec = (0..a.n()).map(|i|{BigInt::from(i)}).collect(); - ring_rns.at_level(a.level()).to_bigint_inplace(&a, 1, &mut coeffs_a); + let mut coeffs_a: Vec = (0..a.n()).map(|i| BigInt::from(i)).collect(); + ring_rns + .at_level(a.level()) + .to_bigint_inplace(&a, 1, &mut coeffs_a); // Performs c = intt(ntt(a) / q_level) - if NTT{ + if NTT { ring_rns.ntt_inplace::(&mut a); } - + ring_rns.div_floor_by_last_modulus::(&a, &mut b, &mut c); - if NTT{ + if NTT { ring_rns.at_level(c.level()).intt_inplace::(&mut c); } - + // Exports c to coeffs_c - let mut coeffs_c = vec![BigInt::from(0);c.n()]; - ring_rns.at_level(c.level()).to_bigint_inplace(&c, 1, &mut coeffs_c); + let mut coeffs_c = vec![BigInt::from(0); c.n()]; + ring_rns + .at_level(c.level()) + .to_bigint_inplace(&c, 1, &mut coeffs_c); // Performs floor division on a let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q); - coeffs_a.iter_mut().for_each(|a|{ + coeffs_a.iter_mut().for_each(|a| { // Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1] *a /= &scalar_big; if a.sign() == Sign::Minus { @@ -58,4 +61,4 @@ fn 
test_div_floor_by_last_modulus(ring_rns: &RingRNS) { }); assert!(coeffs_a == coeffs_c); -} \ No newline at end of file +} diff --git a/sampling/src/lib.rs b/sampling/src/lib.rs index b5cb700..1779b0b 100644 --- a/sampling/src/lib.rs +++ b/sampling/src/lib.rs @@ -1 +1 @@ -pub mod source; \ No newline at end of file +pub mod source; diff --git a/sampling/src/source.rs b/sampling/src/source.rs index 55f651c..edd3cde 100644 --- a/sampling/src/source.rs +++ b/sampling/src/source.rs @@ -1,45 +1,47 @@ use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha8Rng; use rand_core::RngCore; -use rand_chacha::{ChaCha8Rng}; const MAXF64: f64 = 9007199254740992.0; -pub struct Source{ - source:ChaCha8Rng, +pub struct Source { + source: ChaCha8Rng, } -impl Source{ - pub fn new(seed: [u8;32]) -> Source{ - Source{source:ChaCha8Rng::from_seed(seed)} +impl Source { + pub fn new(seed: [u8; 32]) -> Source { + Source { + source: ChaCha8Rng::from_seed(seed), + } } - pub fn new_seed(&mut self) -> [u8;32]{ - let mut seed: [u8; 32] = [0u8;32]; + pub fn new_seed(&mut self) -> [u8; 32] { + let mut seed: [u8; 32] = [0u8; 32]; self.source.fill_bytes(&mut seed); seed } #[inline(always)] - pub fn next_u64(&mut self) -> u64{ + pub fn next_u64(&mut self) -> u64 { self.source.next_u64() } #[inline(always)] - pub fn next_u64n(&mut self, max: u64, mask: u64) -> u64{ + pub fn next_u64n(&mut self, max: u64, mask: u64) -> u64 { let mut x: u64 = self.next_u64() & mask; - while x >= max{ + while x >= max { x = self.next_u64() & mask; } x } #[inline(always)] - pub fn next_f64(&mut self, min: f64, max: f64) -> f64{ - min + ((self.next_u64()<<11>>11) as f64)/MAXF64 * (max-min) + pub fn next_f64(&mut self, min: f64, max: f64) -> f64 { + min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min) } #[inline(always)] - pub fn fill_bytes(&mut self, bytes: &mut [u8]){ + pub fn fill_bytes(&mut self, bytes: &mut [u8]) { self.source.fill_bytes(bytes) } -} \ No newline at end of file +}
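A closing note on `Source::next_u64n` above: masking the raw 64-bit draw down to the modulus bit-width and rejecting values >= max is what keeps the output unbiased; reducing with `% max` instead would over-weight small residues. The same loop, self-contained (the raw generator is a deterministic stand-in closure, not the crate's ChaCha8 source):

// Unbiased uniform draw in [0, max): mask to max's bit-width, reject overshoot.
// The mask keeps at least half of all masked draws below max, so the expected
// number of attempts is at most 2.
fn next_u64n(raw: &mut impl FnMut() -> u64, max: u64, mask: u64) -> u64 {
    loop {
        let x = raw() & mask;
        if x < max {
            return x;
        }
    }
}

fn main() {
    // Cheap deterministic stand-in for a real RNG (splitmix-style counter).
    let mut state = 0u64;
    let mut raw = || {
        state = state.wrapping_add(0x9e37_79b9_7f4a_7c15);
        state ^ (state >> 31)
    };
    let max = 0x1fffffffffe00001u64; // 61-bit NTT prime from the tests
    let mask = u64::MAX >> max.leading_zeros(); // 2^61 - 1
    for _ in 0..1_000 {
        assert!(next_u64n(&mut raw, max, mask) < max);
    }
}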