fixed automorphism for ring and added test

This commit is contained in:
Jean-Philippe Bossuat
2025-01-08 15:24:21 +01:00
parent e4c19a163e
commit c1ed2e38fa
8 changed files with 163 additions and 37 deletions

View File

@@ -3,7 +3,7 @@ use math::dft::DFT;
use math::{dft::ntt::Table, modulus::prime::Prime}; use math::{dft::ntt::Table, modulus::prime::Prime};
fn forward_inplace(c: &mut Criterion) { fn forward_inplace(c: &mut Criterion) {
fn runner(prime_instance: Prime<u64>, nth_root: u64) -> Box<dyn FnMut()> { fn runner(prime_instance: Prime<u64>, nth_root: usize) -> Box<dyn FnMut()> {
let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root); let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root);
let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize]; let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize];
for i in 0..a.len() { for i in 0..a.len() {
@@ -26,7 +26,7 @@ fn forward_inplace(c: &mut Criterion) {
} }
fn forward_inplace_lazy(c: &mut Criterion) { fn forward_inplace_lazy(c: &mut Criterion) {
fn runner(prime_instance: Prime<u64>, nth_root: u64) -> Box<dyn FnMut()> { fn runner(prime_instance: Prime<u64>, nth_root: usize) -> Box<dyn FnMut()> {
let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root); let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root);
let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize]; let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize];
for i in 0..a.len() { for i in 0..a.len() {
@@ -49,7 +49,7 @@ fn forward_inplace_lazy(c: &mut Criterion) {
} }
fn backward_inplace(c: &mut Criterion) { fn backward_inplace(c: &mut Criterion) {
fn runner(prime_instance: Prime<u64>, nth_root: u64) -> Box<dyn FnMut()> { fn runner(prime_instance: Prime<u64>, nth_root: usize) -> Box<dyn FnMut()> {
let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root); let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root);
let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize]; let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize];
for i in 0..a.len() { for i in 0..a.len() {
@@ -72,7 +72,7 @@ fn backward_inplace(c: &mut Criterion) {
} }
fn backward_inplace_lazy(c: &mut Criterion) { fn backward_inplace_lazy(c: &mut Criterion) {
fn runner(prime_instance: Prime<u64>, nth_root: u64) -> Box<dyn FnMut()> { fn runner(prime_instance: Prime<u64>, nth_root: usize) -> Box<dyn FnMut()> {
let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root); let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root);
let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize]; let mut a: Vec<u64> = vec![0; (nth_root >> 1) as usize];
for i in 0..a.len() { for i in 0..a.len() {

View File

@@ -5,7 +5,7 @@ use math::ring::RingRNS;
fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> { fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> {
let a: PolyRNS<u64> = r.new_polyrns(); let a: PolyRNS<u64> = r.new_polyrns();
let mut b: PolyRNS<u64> = r.new_polyrns(); let mut b: [math::poly::Poly<u64>; 2] = [r.new_poly(), r.new_poly()];
let mut c: PolyRNS<u64> = r.new_polyrns(); let mut c: PolyRNS<u64> = r.new_polyrns();
Box::new(move || r.div_by_last_modulus::<false, true>(&a, &mut b, &mut c)) Box::new(move || r.div_by_last_modulus::<false, true>(&a, &mut b, &mut c))

View File

@@ -14,8 +14,8 @@ fn main() {
println!("q_base: {}", prime_instance.q_base()); println!("q_base: {}", prime_instance.q_base());
println!("q_power: {}", prime_instance.q_power()); println!("q_power: {}", prime_instance.q_power());
let n: u64 = 32; let n: usize = 32;
let nth_root: u64 = n << 1; let nth_root: usize = n << 1;
let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root); let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, nth_root);
@@ -44,7 +44,7 @@ fn main() {
p0.0[i] = i as u64 p0.0[i] = i as u64
} }
r.automorphism(p0, (2 * r.n - 1) as u64, &mut p1); r.automorphism::<false>(&p0, 2 * r.n - 1, nth_root, &mut p1);
println!("{:?}", p1); println!("{:?}", p1);
} }

View File

@@ -19,14 +19,14 @@ pub struct Table<O> {
} }
impl Table<u64> { impl Table<u64> {
pub fn new(prime: Prime<u64>, nth_root: u64) -> Table<u64> { pub fn new(prime: Prime<u64>, nth_root: usize) -> Table<u64> {
assert!( assert!(
nth_root & (nth_root - 1) == 0, nth_root & (nth_root - 1) == 0,
"invalid argument: nth_root = {} is not a power of two", "invalid argument: nth_root = {} is not a power of two",
nth_root nth_root
); );
let psi: u64 = prime.primitive_nth_root(nth_root); let psi: u64 = prime.primitive_nth_root(nth_root as u64);
let psi_mont: Montgomery<u64> = prime.montgomery.prepare::<ONCE>(psi); let psi_mont: Montgomery<u64> = prime.montgomery.prepare::<ONCE>(psi);
let psi_inv_mont: Montgomery<u64> = prime.montgomery.pow(psi_mont, prime.phi - 1); let psi_inv_mont: Montgomery<u64> = prime.montgomery.pow(psi_mont, prime.phi - 1);
@@ -311,8 +311,8 @@ mod tests {
let q_base: u64 = 0x800000000004001; let q_base: u64 = 0x800000000004001;
let q_power: usize = 1; let q_power: usize = 1;
let prime_instance: Prime<u64> = Prime::<u64>::new(q_base, q_power); let prime_instance: Prime<u64> = Prime::<u64>::new(q_base, q_power);
let n: u64 = 32; let n: usize = 32;
let two_nth_root: u64 = n << 1; let two_nth_root: usize = n << 1;
let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, two_nth_root); let ntt_table: Table<u64> = Table::<u64>::new(prime_instance, two_nth_root);
let mut a: Vec<u64> = vec![0; n as usize]; let mut a: Vec<u64> = vec![0; n as usize];
for i in 0..a.len() { for i in 0..a.len() {

View File

@@ -5,10 +5,10 @@ use crate::ring::Ring;
/// Returns a lookup table for the automorphism X^{i} -> X^{i * k mod nth_root}. /// Returns a lookup table for the automorphism X^{i} -> X^{i * k mod nth_root}.
/// Method will panic if n or nth_root are not power-of-two. /// Method will panic if n or nth_root are not power-of-two.
/// Method will panic if gal_el is not coprime with nth_root. /// Method will panic if gal_el is not coprime with nth_root.
pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec<u64> { pub fn automorphism_index<const NTT: bool>(n: usize, nth_root: usize, gal_el: usize) -> Vec<usize> {
assert!(n & (n - 1) != 0, "invalid n={}: not a power-of-two", n); assert!(n & (n - 1) != 0, "invalid n={}: not a power-of-two", n);
assert!( assert!(
nth_root & (nth_root - 1) != 0, nth_root & (nth_root - 1) == 0,
"invalid nth_root={}: not a power-of-two", "invalid nth_root={}: not a power-of-two",
n n
); );
@@ -19,39 +19,112 @@ pub fn automorphism_index_ntt(n: usize, nth_root: u64, gal_el: u64) -> Vec<u64>
nth_root nth_root
); );
let mut index: Vec<usize> = Vec::with_capacity(n);
if NTT {
let mask = nth_root - 1; let mask = nth_root - 1;
let log_nth_root: u32 = nth_root.log2() as u32; let log_nth_root_half: u32 = nth_root.log2() as u32 - 1;
let mut index: Vec<u64> = Vec::with_capacity(n);
for i in 0..n { for i in 0..n {
let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root) + 1; let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root_half) + 1;
let gal_el_i: u64 = (gal_el * (i_rev as u64) & mask) >> 1; let gal_el_i: usize = ((gal_el * i_rev) & mask) >> 1;
index.push(gal_el_i.reverse_bits_msb(log_nth_root)); index.push(gal_el_i.reverse_bits_msb(log_nth_root_half));
} }
} else {
let log_n: usize = n.log2();
let mask: usize = (n - 1) as usize;
for i in 0..n {
let gal_el_i: usize = i as usize * gal_el;
let sign: usize = (gal_el_i >> log_n) & 1;
let i_out: usize = (gal_el_i & mask) | (sign << (usize::BITS - 1));
index.push(i_out)
}
}
index index
} }
impl Ring<u64> { impl Ring<u64> {
pub fn automorphism(&self, a: Poly<u64>, gal_el: u64, b: &mut Poly<u64>) { pub fn automorphism<const NTT: bool>(
&self,
a: &Poly<u64>,
gal_el: usize,
nth_root: usize,
b: &mut Poly<u64>,
) {
debug_assert!( debug_assert!(
a.n() == b.n(), a.n() == b.n(),
"invalid inputs: a.n() = {} != b.n() = {}", "invalid inputs: a.n() = {} != b.n() = {}",
a.n(), a.n(),
b.n() b.n()
); );
debug_assert!(gal_el & 1 == 1, "invalid gal_el = {}: not odd", gal_el);
assert!(
gal_el & 1 == 1,
"invalid gal_el={}: not coprime with nth_root={}",
gal_el,
nth_root
);
assert!(
nth_root & (nth_root - 1) == 0,
"invalid nth_root={}: not a power-of-two",
nth_root
);
let b_vec: &mut Vec<u64> = &mut b.0;
let a_vec: &Vec<u64> = &a.0;
if NTT {
let mask: usize = nth_root - 1;
let log_nth_root_half: u32 = nth_root.log2() as u32 - 1;
a_vec.iter().enumerate().for_each(|(i, ai)| {
let i_rev: usize = 2 * i.reverse_bits_msb(log_nth_root_half) + 1;
let gal_el_i: usize = (((gal_el * i_rev) & mask) - 1) >> 1;
let idx: usize = gal_el_i.reverse_bits_msb(log_nth_root_half);
b_vec[idx] = *ai;
});
} else {
let n: usize = a.n(); let n: usize = a.n();
let mask: u64 = (n - 1) as u64; let mask: usize = n - 1;
let log_n: usize = n.log2(); let log_n: usize = n.log2();
let q: u64 = self.modulus.q(); let q: u64 = self.modulus.q();
let b_vec: &mut _ = &mut b.0;
let a_vec: &_ = &a.0;
a_vec.iter().enumerate().for_each(|(i, ai)| { a_vec.iter().enumerate().for_each(|(i, ai)| {
let gal_el_i: u64 = i as u64 * gal_el; let gal_el_i: usize = i * gal_el;
let sign: u64 = (gal_el_i >> log_n) & 1; let sign: u64 = ((gal_el_i >> log_n) & 1) as u64;
let i_out: u64 = gal_el_i & mask; let i_out: usize = gal_el_i & mask;
b_vec[i_out as usize] = ai * (sign ^ 1) | (q - ai) * sign b_vec[i_out] = ai * (sign ^ 1) | (q - ai) * sign
}); });
} }
} }
pub fn automorphism_from_index<const NTT: bool>(
&self,
a: &Poly<u64>,
idx: &[usize],
b: &mut Poly<u64>,
) {
debug_assert!(
a.n() == b.n(),
"invalid inputs: a.n() = {} != b.n() = {}",
a.n(),
b.n()
);
let b_vec: &mut Vec<u64> = &mut b.0;
let a_vec: &Vec<u64> = &a.0;
if NTT {
a_vec.iter().enumerate().for_each(|(i, ai)| {
b_vec[idx[i]] = *ai;
});
} else {
let n: usize = a.n();
let mask: usize = n - 1;
let q: u64 = self.modulus.q();
a_vec.iter().enumerate().for_each(|(i, ai)| {
let sign: u64 = (idx[i] >> usize::BITS - 1) as u64;
b_vec[idx[i] & mask] = ai * (sign ^ 1) | (q - ai) * sign;
});
}
}
}

View File

@@ -16,7 +16,7 @@ impl Ring<u64> {
Self { Self {
n: n, n: n,
modulus: prime.clone(), modulus: prime.clone(),
dft: Box::new(Table::<u64>::new(prime, (2 * n) as u64)), dft: Box::new(Table::<u64>::new(prime, n << 1)),
} }
} }

View File

@@ -2,7 +2,7 @@ use crate::modulus::WordOps;
use crate::poly::{Poly, PolyRNS}; use crate::poly::{Poly, PolyRNS};
use crate::ring::{Ring, RingRNS}; use crate::ring::{Ring, RingRNS};
use num::ToPrimitive; use num::ToPrimitive;
use rand_distr::{Distribution, Normal}; use rand_distr::Distribution;
use sampling::source::Source; use sampling::source::Source;
impl Ring<u64> { impl Ring<u64> {

View File

@@ -0,0 +1,53 @@
use itertools::izip;
use math::poly::Poly;
use math::ring::Ring;
#[test]
fn automorphism_u64() {
let n: usize = 1 << 4;
let nth_root: usize = n << 1;
let q_base: u64 = 65537u64;
let q_power: usize = 1usize;
let ring: Ring<u64> = Ring::new(n, q_base, q_power);
sub_test("test_automorphism_u64::<NTT:false>", || {
test_automorphism_u64::<false>(&ring, nth_root)
});
sub_test("test_automorphism_u64::<NTT:true>", || {
test_automorphism_u64::<true>(&ring, nth_root)
});
}
fn sub_test<F: FnOnce()>(name: &str, f: F) {
println!("Running {}", name);
f();
}
fn test_automorphism_u64<const NTT: bool>(ring: &Ring<u64>, nth_root: usize) {
let n: usize = ring.n();
let q: u64 = ring.modulus.q;
let mut p0: Poly<u64> = ring.new_poly();
let mut p1: Poly<u64> = ring.new_poly();
for i in 0..p0.n() {
p0.0[i] = i as u64
}
if NTT {
ring.ntt_inplace::<false>(&mut p0);
}
ring.automorphism::<NTT>(&p0, 2 * n - 1, nth_root, &mut p1);
if NTT {
ring.intt_inplace::<false>(&mut p1);
}
p0.0[0] = 0;
for i in 1..p0.n() {
p0.0[i] = q - (n - i) as u64
}
izip!(p0.0, p1.0).for_each(|(a, b)| assert_eq!(a, b));
}