diff --git a/examples/main.rs b/examples/main.rs index 634257e..0a3a833 100644 --- a/examples/main.rs +++ b/examples/main.rs @@ -17,6 +17,6 @@ fn main() { let n: u64 = 1024; let nth_root: u64 = n<<1; - let ntt_table: Table<'_, u64> = Table::::new(&mut prime_instance, n, nth_root); + let ntt_table: Table<'_, u64> = Table::::new(&mut prime_instance, nth_root); } \ No newline at end of file diff --git a/src/dft/ntt.rs b/src/dft/ntt.rs index ae35c27..0bd5886 100644 --- a/src/dft/ntt.rs +++ b/src/dft/ntt.rs @@ -5,15 +5,12 @@ pub struct Table<'a, O>{ prime:&'a Prime, pub psi_forward_rev:Vec>, psi_backward_rev: Vec>, - n_inv: Montgomery, } impl<'a> Table<'a, u64> { - pub fn new(prime: &'a mut Prime, n: u64, nth_root: u64)->Self{ + pub fn new(prime: &'a mut Prime, nth_root: u64)->Self{ - assert!(n&(n-1) == 0, "invalid argument: n = {} is not a power of two", n); - assert!(n&(n-1) == 0, "invalid argument: nth_root = {} is not a power of two", nth_root); - assert!(n < nth_root, "invalid argument: n = {} cannot be greater or equal to nth_root = {}", n, nth_root); + assert!(nth_root&(nth_root-1) == 0, "invalid argument: nth_root = {} is not a power of two", nth_root); let psi: u64 = prime.primitive_nth_root(nth_root); @@ -26,24 +23,21 @@ impl<'a> Table<'a, u64> { psi_forward_rev[0] = prime.montgomery.one(); psi_backward_rev[0] = prime.montgomery.one(); - let log_nth_root_half: usize = (usize::MAX - ((nth_root>>1 as usize)-1).leading_zeros() as usize) as usize; + let log_nth_root_half: u32 = usize::BITS - ((nth_root>>1 as usize)-1).leading_zeros(); for i in 1..(nth_root>>1) as usize{ - let i_rev_prev: usize = (i-1).reverse_bits() >> (usize::MAX - log_nth_root_half) as usize; - let i_rev_next: usize = i.reverse_bits() >> (usize::MAX - log_nth_root_half) as usize; + let i_rev_prev: usize = (i-1).reverse_bits() >> (usize::BITS - log_nth_root_half); + let i_rev_next: usize = i.reverse_bits() >> (usize::BITS - log_nth_root_half); psi_forward_rev[i_rev_next] = prime.montgomery.mul_internal(psi_forward_rev[i_rev_prev], psi_mont); psi_backward_rev[i_rev_next] = prime.montgomery.mul_internal(psi_backward_rev[i_rev_prev], psi_inv_mont); } - let n_inv: Montgomery = prime.montgomery.pow(prime.montgomery.prepare(nth_root>>1), prime.phi-1); - Self{ prime: prime, psi_forward_rev: psi_forward_rev, psi_backward_rev: psi_backward_rev, - n_inv: n_inv, } } } \ No newline at end of file diff --git a/src/modulus/barrett.rs b/src/modulus/barrett.rs index 2a625f3..6b1bfec 100644 --- a/src/modulus/barrett.rs +++ b/src/modulus/barrett.rs @@ -25,10 +25,9 @@ impl BarrettPrecomp{ impl BarrettPrecomp{ pub fn new(q: u64) -> BarrettPrecomp { - let mut big_r = BigUint::parse_bytes(b"100000000000000000000000000000000", 16).unwrap(); - big_r = big_r / BigUint::from(q); - let lo = (&big_r & BigUint::from(u64::MAX)).to_u64().unwrap(); - let hi = (big_r >> 64u64).to_u64().unwrap(); + let big_r: BigUint = (BigUint::from(1 as usize)<<((u64::BITS<<1) as usize)) / BigUint::from(q); + let lo: u64 = (&big_r & BigUint::from(u64::MAX)).to_u64().unwrap(); + let hi: u64 = (big_r >> u64::BITS).to_u64().unwrap(); Self{q, lo, hi} }