From c1ed2e38fa3742873e3764561c03d2463fcd87ae Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 8 Jan 2025 15:24:21 +0100 Subject: [PATCH] fixed automorphism for ring and added test --- math/benches/ntt.rs | 8 +- math/benches/ring_rns.rs | 2 +- math/examples/main.rs | 6 +- math/src/dft/ntt.rs | 8 +- math/src/ring/impl_u64/automorphism.rs | 119 ++++++++++++++++++++----- math/src/ring/impl_u64/ring.rs | 2 +- math/src/ring/impl_u64/sampling.rs | 2 +- math/tests/automorphism.rs | 53 +++++++++++ 8 files changed, 163 insertions(+), 37 deletions(-) create mode 100644 math/tests/automorphism.rs diff --git a/math/benches/ntt.rs b/math/benches/ntt.rs index bb530d5..c287389 100644 --- a/math/benches/ntt.rs +++ b/math/benches/ntt.rs @@ -3,7 +3,7 @@ use math::dft::DFT; use math::{dft::ntt::Table, modulus::prime::Prime}; fn forward_inplace(c: &mut Criterion) { - fn runner(prime_instance: Prime, nth_root: u64) -> Box { + fn runner(prime_instance: Prime, nth_root: usize) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; for i in 0..a.len() { @@ -26,7 +26,7 @@ fn forward_inplace(c: &mut Criterion) { } fn forward_inplace_lazy(c: &mut Criterion) { - fn runner(prime_instance: Prime, nth_root: u64) -> Box { + fn runner(prime_instance: Prime, nth_root: usize) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; for i in 0..a.len() { @@ -49,7 +49,7 @@ fn forward_inplace_lazy(c: &mut Criterion) { } fn backward_inplace(c: &mut Criterion) { - fn runner(prime_instance: Prime, nth_root: u64) -> Box { + fn runner(prime_instance: Prime, nth_root: usize) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; for i in 0..a.len() { @@ -72,7 +72,7 @@ fn backward_inplace(c: &mut Criterion) { } fn backward_inplace_lazy(c: &mut Criterion) { - fn runner(prime_instance: Prime, nth_root: u64) -> Box { + fn runner(prime_instance: Prime, nth_root: usize) -> Box { let ntt_table: Table = Table::::new(prime_instance, nth_root); let mut a: Vec = vec![0; (nth_root >> 1) as usize]; for i in 0..a.len() { diff --git a/math/benches/ring_rns.rs b/math/benches/ring_rns.rs index 2ffb1c6..6beb239 100644 --- a/math/benches/ring_rns.rs +++ b/math/benches/ring_rns.rs @@ -5,7 +5,7 @@ use math::ring::RingRNS; fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { fn runner(r: RingRNS) -> Box { let a: PolyRNS = r.new_polyrns(); - let mut b: PolyRNS = r.new_polyrns(); + let mut b: [math::poly::Poly; 2] = [r.new_poly(), r.new_poly()]; let mut c: PolyRNS = r.new_polyrns(); Box::new(move || r.div_by_last_modulus::(&a, &mut b, &mut c)) diff --git a/math/examples/main.rs b/math/examples/main.rs index bdd9b74..96242bd 100644 --- a/math/examples/main.rs +++ b/math/examples/main.rs @@ -14,8 +14,8 @@ fn main() { println!("q_base: {}", prime_instance.q_base()); println!("q_power: {}", prime_instance.q_power()); - let n: u64 = 32; - let nth_root: u64 = n << 1; + let n: usize = 32; + let nth_root: usize = n << 1; let ntt_table: Table = Table::::new(prime_instance, nth_root); @@ -44,7 +44,7 @@ fn main() { p0.0[i] = i as u64 } - r.automorphism(p0, (2 * r.n - 1) as u64, &mut p1); + r.automorphism::(&p0, 2 * r.n - 1, nth_root, &mut p1); println!("{:?}", p1); } diff --git a/math/src/dft/ntt.rs b/math/src/dft/ntt.rs index 5612c1e..a2fdfc2 100644 --- a/math/src/dft/ntt.rs +++ b/math/src/dft/ntt.rs @@ -19,14 +19,14 @@ pub struct Table { } impl Table { - pub fn new(prime: Prime, nth_root: u64) -> Table { + pub fn new(prime: Prime, nth_root: usize) -> Table { assert!( nth_root & (nth_root - 1) == 0, "invalid argument: nth_root = {} is not a power of two", nth_root ); - let psi: u64 = prime.primitive_nth_root(nth_root); + let psi: u64 = prime.primitive_nth_root(nth_root as u64); let psi_mont: Montgomery = prime.montgomery.prepare::(psi); let psi_inv_mont: Montgomery = prime.montgomery.pow(psi_mont, prime.phi - 1); @@ -311,8 +311,8 @@ mod tests { let q_base: u64 = 0x800000000004001; let q_power: usize = 1; let prime_instance: Prime = Prime::::new(q_base, q_power); - let n: u64 = 32; - let two_nth_root: u64 = n << 1; + let n: usize = 32; + let two_nth_root: usize = n << 1; let ntt_table: Table = Table::::new(prime_instance, two_nth_root); let mut a: Vec = vec![0; n as usize]; for i in 0..a.len() { diff --git a/math/src/ring/impl_u64/automorphism.rs b/math/src/ring/impl_u64/automorphism.rs index 4c12f07..a366d00 100644 --- a/math/src/ring/impl_u64/automorphism.rs +++ b/math/src/ring/impl_u64/automorphism.rs @@ -5,10 +5,10 @@ use crate::ring::Ring; /// Returns a lookup table for the automorphism X^{i} -> X^{i * k mod nth_root}. /// Method will panic if n or nth_root are not power-of-two. /// Method will panic if gal_el is not coprime with nth_root. -pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec { +pub fn automorphism_index(n: usize, nth_root: usize, gal_el: usize) -> Vec { assert!(n & (n - 1) != 0, "invalid n={}: not a power-of-two", n); assert!( - nth_root & (nth_root - 1) != 0, + nth_root & (nth_root - 1) == 0, "invalid nth_root={}: not a power-of-two", n ); @@ -19,39 +19,112 @@ pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec nth_root ); - let mask = nth_root - 1; - let log_nth_root: u32 = nth_root.log2() as u32; - let mut index: Vec = Vec::with_capacity(n); - for i in 0..n { - let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root) + 1; - let gal_el_i: u64 = (gal_el * (i_rev as u64) & mask) >> 1; - index.push(gal_el_i.reverse_bits_msb(log_nth_root)); + let mut index: Vec = Vec::with_capacity(n); + + if NTT { + let mask = nth_root - 1; + let log_nth_root_half: u32 = nth_root.log2() as u32 - 1; + for i in 0..n { + let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root_half) + 1; + let gal_el_i: usize = ((gal_el * i_rev) & mask) >> 1; + index.push(gal_el_i.reverse_bits_msb(log_nth_root_half)); + } + } else { + let log_n: usize = n.log2(); + let mask: usize = (n - 1) as usize; + for i in 0..n { + let gal_el_i: usize = i as usize * gal_el; + let sign: usize = (gal_el_i >> log_n) & 1; + let i_out: usize = (gal_el_i & mask) | (sign << (usize::BITS - 1)); + index.push(i_out) + } } + index } impl Ring { - pub fn automorphism(&self, a: Poly, gal_el: u64, b: &mut Poly) { + pub fn automorphism( + &self, + a: &Poly, + gal_el: usize, + nth_root: usize, + b: &mut Poly, + ) { debug_assert!( a.n() == b.n(), "invalid inputs: a.n() = {} != b.n() = {}", a.n(), b.n() ); - debug_assert!(gal_el & 1 == 1, "invalid gal_el = {}: not odd", gal_el); - let n: usize = a.n(); - let mask: u64 = (n - 1) as u64; - let log_n: usize = n.log2(); - let q: u64 = self.modulus.q(); - let b_vec: &mut _ = &mut b.0; - let a_vec: &_ = &a.0; + assert!( + gal_el & 1 == 1, + "invalid gal_el={}: not coprime with nth_root={}", + gal_el, + nth_root + ); - a_vec.iter().enumerate().for_each(|(i, ai)| { - let gal_el_i: u64 = i as u64 * gal_el; - let sign: u64 = (gal_el_i >> log_n) & 1; - let i_out: u64 = gal_el_i & mask; - b_vec[i_out as usize] = ai * (sign ^ 1) | (q - ai) * sign - }); + assert!( + nth_root & (nth_root - 1) == 0, + "invalid nth_root={}: not a power-of-two", + nth_root + ); + + let b_vec: &mut Vec = &mut b.0; + let a_vec: &Vec = &a.0; + + if NTT { + let mask: usize = nth_root - 1; + let log_nth_root_half: u32 = nth_root.log2() as u32 - 1; + a_vec.iter().enumerate().for_each(|(i, ai)| { + let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root_half) + 1; + let gal_el_i: usize = (((gal_el * i_rev) & mask) - 1) >> 1; + let idx: usize = gal_el_i.reverse_bits_msb(log_nth_root_half); + b_vec[idx] = *ai; + }); + } else { + let n: usize = a.n(); + let mask: usize = n - 1; + let log_n: usize = n.log2(); + let q: u64 = self.modulus.q(); + a_vec.iter().enumerate().for_each(|(i, ai)| { + let gal_el_i: usize = i * gal_el; + let sign: u64 = ((gal_el_i >> log_n) & 1) as u64; + let i_out: usize = gal_el_i & mask; + b_vec[i_out] = ai * (sign ^ 1) | (q - ai) * sign + }); + } + } + + pub fn automorphism_from_index( + &self, + a: &Poly, + idx: &[usize], + b: &mut Poly, + ) { + debug_assert!( + a.n() == b.n(), + "invalid inputs: a.n() = {} != b.n() = {}", + a.n(), + b.n() + ); + + let b_vec: &mut Vec = &mut b.0; + let a_vec: &Vec = &a.0; + + if NTT { + a_vec.iter().enumerate().for_each(|(i, ai)| { + b_vec[idx[i]] = *ai; + }); + } else { + let n: usize = a.n(); + let mask: usize = n - 1; + let q: u64 = self.modulus.q(); + a_vec.iter().enumerate().for_each(|(i, ai)| { + let sign: u64 = (idx[i] >> usize::BITS - 1) as u64; + b_vec[idx[i] & mask] = ai * (sign ^ 1) | (q - ai) * sign; + }); + } } } diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index 29bf97a..5c22b7b 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -16,7 +16,7 @@ impl Ring { Self { n: n, modulus: prime.clone(), - dft: Box::new(Table::::new(prime, (2 * n) as u64)), + dft: Box::new(Table::::new(prime, n << 1)), } } diff --git a/math/src/ring/impl_u64/sampling.rs b/math/src/ring/impl_u64/sampling.rs index 4955a00..3af27bf 100644 --- a/math/src/ring/impl_u64/sampling.rs +++ b/math/src/ring/impl_u64/sampling.rs @@ -2,7 +2,7 @@ use crate::modulus::WordOps; use crate::poly::{Poly, PolyRNS}; use crate::ring::{Ring, RingRNS}; use num::ToPrimitive; -use rand_distr::{Distribution, Normal}; +use rand_distr::Distribution; use sampling::source::Source; impl Ring { diff --git a/math/tests/automorphism.rs b/math/tests/automorphism.rs new file mode 100644 index 0000000..93b95e4 --- /dev/null +++ b/math/tests/automorphism.rs @@ -0,0 +1,53 @@ +use itertools::izip; +use math::poly::Poly; +use math::ring::Ring; + +#[test] +fn automorphism_u64() { + let n: usize = 1 << 4; + let nth_root: usize = n << 1; + let q_base: u64 = 65537u64; + let q_power: usize = 1usize; + let ring: Ring = Ring::new(n, q_base, q_power); + + sub_test("test_automorphism_u64::", || { + test_automorphism_u64::(&ring, nth_root) + }); + sub_test("test_automorphism_u64::", || { + test_automorphism_u64::(&ring, nth_root) + }); +} + +fn sub_test(name: &str, f: F) { + println!("Running {}", name); + f(); +} + +fn test_automorphism_u64(ring: &Ring, nth_root: usize) { + let n: usize = ring.n(); + let q: u64 = ring.modulus.q; + + let mut p0: Poly = ring.new_poly(); + let mut p1: Poly = ring.new_poly(); + + for i in 0..p0.n() { + p0.0[i] = i as u64 + } + + if NTT { + ring.ntt_inplace::(&mut p0); + } + + ring.automorphism::(&p0, 2 * n - 1, nth_root, &mut p1); + + if NTT { + ring.intt_inplace::(&mut p1); + } + + p0.0[0] = 0; + for i in 1..p0.n() { + p0.0[i] = q - (n - i) as u64 + } + + izip!(p0.0, p1.0).for_each(|(a, b)| assert_eq!(a, b)); +}