|
|
@ -7,10 +7,21 @@ use crate::ntt::NTT as NTT_p; |
|
|
|
|
|
|
|
// const P0: u64 = 4611686018427387905;
|
|
|
|
// const P1: u64 = 4611686018326724609;
|
|
|
|
const P0: u64 = 8070449433331580929;
|
|
|
|
const P1: u64 = 8070450532384645121;
|
|
|
|
// const P0: u64 = (1 << 60) - (1 << 28) + 1;
|
|
|
|
// const P1: u64 = (1 << 60) - (1 << 27) + 1;
|
|
|
|
|
|
|
|
// const P0: u64 = 8070449433331580929; // max use 1<<62 -1
|
|
|
|
// const P1: u64 = 8070450532384645121;
|
|
|
|
|
|
|
|
// const P0: u64 = 0x80000000080001; // max use 1<<55 -1
|
|
|
|
// const P1: u64 = 0x80000000130001;
|
|
|
|
|
|
|
|
// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64;
|
|
|
|
// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64;
|
|
|
|
// const P2: u64 = ((1u128 << 64) - (1u128 << 26) + 1u128) as u64;
|
|
|
|
|
|
|
|
const P0: u64 = ((1u128 << 32) - (1u128 << 18) + 1u128) as u64;
|
|
|
|
const P1: u64 = ((1u128 << 32) - (1u128 << 17) + 1u128) as u64;
|
|
|
|
const P2: u64 = ((1u128 << 32) - (1u128 << 16) + 1u128) as u64;
|
|
|
|
// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64;
|
|
|
|
// const P2: u64 = (1 << 60) - (1 << 26) + 1;
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
@ -23,7 +34,7 @@ impl NTT { |
|
|
|
) -> (
|
|
|
|
Vec<u64>,
|
|
|
|
Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
|
Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
@ -32,23 +43,27 @@ impl NTT { |
|
|
|
// TODO ensure that: a_i <P0
|
|
|
|
|
|
|
|
// apply modulus p_i
|
|
|
|
let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect();
|
|
|
|
let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect();
|
|
|
|
let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect();
|
|
|
|
let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect();
|
|
|
|
let a_2: Vec<u64> = a.iter().map(|a_i| a_i % P2).collect();
|
|
|
|
// let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect();
|
|
|
|
// let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect();
|
|
|
|
// let a_2: Vec<u64> = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect();
|
|
|
|
// let a_3: Vec<u64> = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect();
|
|
|
|
// let a_4: Vec<u64> = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect();
|
|
|
|
// let a_5: Vec<u64> = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect();
|
|
|
|
// let a_6: Vec<u64> = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect();
|
|
|
|
|
|
|
|
dbg!(&a_0, &a_1, &a_2);
|
|
|
|
let r_0 = NTT_p::ntt(P0, n, &a_0);
|
|
|
|
let r_1 = NTT_p::ntt(P1, n, &a_1);
|
|
|
|
// let r_2 = NTT_p::ntt(P2, n, &a_2);
|
|
|
|
let r_2 = NTT_p::ntt(P2, n, &a_2);
|
|
|
|
// let r_3 = NTT_p::ntt(P3, n, &a_3);
|
|
|
|
// let r_4 = NTT_p::ntt(P4, n, &a_4);
|
|
|
|
// let r_5 = NTT_p::ntt(P5, n, &a_5);
|
|
|
|
// let r_6 = NTT_p::ntt(P6, n, &a_6);
|
|
|
|
|
|
|
|
(r_0, r_1) //, r_2) //, r_3, r_4) //, r_5, r_6)
|
|
|
|
(r_0, r_1, r_2) //, r_3, r_4) //, r_5, r_6)
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn intt(
|
|
|
@ -56,7 +71,7 @@ impl NTT { |
|
|
|
r: &(
|
|
|
|
Vec<u64>,
|
|
|
|
Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
|
Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
|
// Vec<u64>,
|
|
|
@ -65,7 +80,7 @@ impl NTT { |
|
|
|
) -> Vec<u64> {
|
|
|
|
let a_0 = NTT_p::intt(P0, n, &r.0);
|
|
|
|
let a_1 = NTT_p::intt(P1, n, &r.1);
|
|
|
|
// let a_2 = NTT_p::intt(P2, n, &r.2);
|
|
|
|
let a_2 = NTT_p::intt(P2, n, &r.2);
|
|
|
|
// let a_3 = NTT_p::intt(P3, n, &r.3);
|
|
|
|
// let a_4 = NTT_p::intt(P4, n, &r.4);
|
|
|
|
// let a_5 = NTT_p::intt(P5, n, &r.5);
|
|
|
@ -75,28 +90,28 @@ impl NTT { |
|
|
|
// let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128;
|
|
|
|
// const INV_P1_MOD_P2: u128 = 4895217125691974194;
|
|
|
|
|
|
|
|
reconstruct(a_0, a_1) //, a_2) // , a_3, a_4) //, a_5, a_6)
|
|
|
|
reconstruct(a_0, a_1, a_2) // , a_3, a_4) //, a_5, a_6)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn reconstruct(
|
|
|
|
a0: Vec<u64>,
|
|
|
|
a1: Vec<u64>,
|
|
|
|
// a2: Vec<u64>,
|
|
|
|
a2: Vec<u64>,
|
|
|
|
// a_3: Vec<u64>,
|
|
|
|
// a_4: Vec<u64>,
|
|
|
|
// a_5: Vec<u64>,
|
|
|
|
// a_6: Vec<u64>,
|
|
|
|
) -> Vec<u64> {
|
|
|
|
// let Q = P0 as u128 * P1 as u128;
|
|
|
|
|
|
|
|
// y_i = q/q_i
|
|
|
|
let y0 = ((u64::MAX as u128 + 1) / P0 as u128);
|
|
|
|
let y1 = ((u64::MAX as u128 + 1) / P1 as u128);
|
|
|
|
// let y0 = ((u64::MAX as u128 + 1) / P0 as u128);
|
|
|
|
// let y1 = ((u64::MAX as u128 + 1) / P1 as u128);
|
|
|
|
// let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64;
|
|
|
|
// let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2
|
|
|
|
// let y1: u128 = P0 as u128 * P2 as u128;
|
|
|
|
// let y2: u128 = P0 as u128 * P1 as u128;
|
|
|
|
// let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2
|
|
|
|
// let y1: u128 = P0 as u128;
|
|
|
|
let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2
|
|
|
|
let y1: u128 = P0 as u128 * P2 as u128;
|
|
|
|
let y2: u128 = P0 as u128 * P1 as u128;
|
|
|
|
// let y0 = (Q / P0 as u128) as u64;
|
|
|
|
// let y1 = (Q / P1 as u128) as u64;
|
|
|
|
// let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64;
|
|
|
@ -104,9 +119,10 @@ fn reconstruct( |
|
|
|
// let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64;
|
|
|
|
|
|
|
|
// y_i^-1 mod q_i = z_i
|
|
|
|
dbg!(P0, y0);
|
|
|
|
let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i
|
|
|
|
let z1: u128 = inv_mod(P1 as u128, y1);
|
|
|
|
// let z2: u128 = inv_mod(P2 as u128, y2);
|
|
|
|
let z2: u128 = inv_mod(P2 as u128, y2);
|
|
|
|
// let y2_inv = inv_mod(P2 as u128, y2);
|
|
|
|
// let y3_inv = inv_mod(P3 as u128, y3);
|
|
|
|
// let y4_inv = inv_mod(P4 as u128, y4);
|
|
|
@ -115,25 +131,45 @@ fn reconstruct( |
|
|
|
// aux = (a2 - a1) * m1 mod q2
|
|
|
|
// a = a1 + (q1 * m1) * aux
|
|
|
|
|
|
|
|
let m1 = inv_mod(P1 as u128, P0 as u128) as u64; // P0^-1 mod P1
|
|
|
|
let aux: Vec<u64> = itertools::zip_eq(a0.clone(), a1.clone())
|
|
|
|
.map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1)
|
|
|
|
.collect();
|
|
|
|
let a: Vec<u64> = itertools::zip_eq(a0, aux)
|
|
|
|
// .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i)
|
|
|
|
// .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i)
|
|
|
|
.map(|(a0_i, aux_i)| a0_i + ((P0 * m1) % P1) * aux_i)
|
|
|
|
.collect();
|
|
|
|
a
|
|
|
|
// m1 == z1
|
|
|
|
// let m1 = inv_mod(P1 as u128, P0 as u128); // P0^-1 mod P1
|
|
|
|
// let aux: Vec<u64> = itertools::zip_eq(a0.clone(), a1.clone())
|
|
|
|
// .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1)
|
|
|
|
// .collect();
|
|
|
|
let p0: u128 = P0 as u128;
|
|
|
|
let p1: u128 = P1 as u128;
|
|
|
|
let p2: u128 = P2 as u128;
|
|
|
|
// let a: Vec<u64> = itertools::zip_eq(a0, a1)
|
|
|
|
// // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i)
|
|
|
|
// // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i)
|
|
|
|
// .map(|(a0_i, a1_i)| a0_i as u128 + ((p0 * z1) % p1) * (((a1_i - a0_i) as u128 * z1) % p1))
|
|
|
|
// .map(|v| v as u64)
|
|
|
|
// // let a: Vec<u64> = itertools::zip_eq(a0, a1)
|
|
|
|
// // .map(|(a0_i, a1_i)| {
|
|
|
|
// // ((((a0_i as u128 * z0) % P0 as u128) * y0) + (a1_i as u128 * z1 % P1 as u128) * y1)
|
|
|
|
// // as u64
|
|
|
|
// // })
|
|
|
|
// .collect();
|
|
|
|
// a
|
|
|
|
|
|
|
|
// dbg!(a0[0] as u128);
|
|
|
|
// dbg!(a0[0] as u128 * y0);
|
|
|
|
// dbg!(a0[0] as u128 * z0);
|
|
|
|
// dbg!(a0[0] as u128 * y0 * z0);
|
|
|
|
// let a: Vec<u128> = itertools::multizip((a0, a1, a2))
|
|
|
|
// .map(|(a0_i, a1_i, a2_i)| {
|
|
|
|
// a0_i as u128 * y0 * z0 + a1_i as u128 * y1 * z1 + a2_i as u128 * y2 * z2
|
|
|
|
// })
|
|
|
|
// .collect();
|
|
|
|
dbg!(&y0, &y1, &y2);
|
|
|
|
dbg!(&z0, &z1, &z2);
|
|
|
|
let Q = P0 as u128 * P1 as u128 * P2 as u128;
|
|
|
|
let a: Vec<u64> = itertools::multizip((a0, a1, a2))
|
|
|
|
.map(|(a0_i, a1_i, a2_i)| {
|
|
|
|
(a0_i as u128 * y0 * z0)// % Q
|
|
|
|
+ (a1_i as u128 * y1 * z1)// % Q
|
|
|
|
+ (a2_i as u128 * y2 * z2) // % Q
|
|
|
|
})
|
|
|
|
.map(|v| v % Q)
|
|
|
|
.map(|v| v as u32)
|
|
|
|
.map(|v| v as u64)
|
|
|
|
.collect();
|
|
|
|
a
|
|
|
|
// dbg!(&a);
|
|
|
|
// let Q = y2 * P2 as u128;
|
|
|
|
// let a: Vec<u128> = a.iter().map(|a_i| a_i % Q).collect();
|
|
|
@ -202,11 +238,75 @@ fn exp_mod(q: u128, x: u128, k: u128) -> u128 { |
|
|
|
}
|
|
|
|
r
|
|
|
|
}
|
|
|
|
/// returns x^-1 mod Q
|
|
|
|
fn inv_mod(q: u128, x: u128) -> u128 {
|
|
|
|
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
|
|
|
|
exp_mod(q, x, q - 2)
|
|
|
|
/// returns x^-1 mod Q, assuming x and Q are coprime, generally Q is prime
|
|
|
|
// fn inv_mod(q: u128, x: u128) -> u128 {
|
|
|
|
// // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
|
|
|
|
// // exp_mod(q, x, q - 2)
|
|
|
|
// exp_mod(q, x, q - 2)
|
|
|
|
// }
|
|
|
|
|
|
|
|
fn inv_mod(m: u128, a: u128) -> u128 {
|
|
|
|
// if m == 1 {
|
|
|
|
// return Some(0);
|
|
|
|
// }
|
|
|
|
|
|
|
|
let mut m = m.clone();
|
|
|
|
let mut a = a.clone();
|
|
|
|
let m0 = m.clone();
|
|
|
|
let mut x0: i128 = 0;
|
|
|
|
let mut x1: i128 = 1;
|
|
|
|
|
|
|
|
while a > 1 {
|
|
|
|
let q = a / m;
|
|
|
|
let t = m.clone();
|
|
|
|
|
|
|
|
m = a % m;
|
|
|
|
a = t.clone();
|
|
|
|
|
|
|
|
let t = x0;
|
|
|
|
x0 = x1 - (q as i128) * x0;
|
|
|
|
x1 = t;
|
|
|
|
}
|
|
|
|
|
|
|
|
if x1 < 0 {
|
|
|
|
x1 += m0 as i128;
|
|
|
|
}
|
|
|
|
|
|
|
|
x1 as u128
|
|
|
|
}
|
|
|
|
// fn inv_mod(m: u128, a: u128) -> u128 {
|
|
|
|
// let mut m = m.clone();
|
|
|
|
// let mut a = a.clone();
|
|
|
|
// let m0 = m.clone();
|
|
|
|
// let mut x0 = 0;
|
|
|
|
// let mut x1 = 1;
|
|
|
|
//
|
|
|
|
// if m == 1 {
|
|
|
|
// return 0;
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// while a > 1 {
|
|
|
|
// let q = a / m;
|
|
|
|
//
|
|
|
|
// let mut t = m.clone();
|
|
|
|
//
|
|
|
|
// m = a % m;
|
|
|
|
// a = t.clone();
|
|
|
|
//
|
|
|
|
// t = x0.clone();
|
|
|
|
//
|
|
|
|
// x0 = x1 - q * x0;
|
|
|
|
//
|
|
|
|
// x1 = t.clone();
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// // if (x1 < 0) {
|
|
|
|
// // x1 = x1 + m0
|
|
|
|
// // }
|
|
|
|
//
|
|
|
|
// // return x1 % m0;
|
|
|
|
// return x1 % m0;
|
|
|
|
// }
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
mod tests {
|
|
|
@ -215,10 +315,50 @@ mod tests { |
|
|
|
|
|
|
|
use anyhow::Result;
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
fn test_inv_mod() -> Result<()> {
|
|
|
|
let x = 3;
|
|
|
|
let x_inv = inv_mod(P0 as u128, x);
|
|
|
|
dbg!(&P0);
|
|
|
|
dbg!(&x_inv);
|
|
|
|
|
|
|
|
// let r = x_inv * x;
|
|
|
|
// dbg!(&r);
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
fn test_reconstruct() -> Result<()> {
|
|
|
|
let n: usize = 16;
|
|
|
|
|
|
|
|
use rand::Rng;
|
|
|
|
let mut rng = rand::thread_rng();
|
|
|
|
let a: Vec<u64> = (0..n)
|
|
|
|
// .map(|_| rng.gen_range(0..(1 << 64)))
|
|
|
|
.map(|_| rng.gen_range(0..(1 << 32)))
|
|
|
|
// .map(|_| rng.gen_range(0..16))
|
|
|
|
// .map(|_| rng.sample(rand::distributions::Standard))
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
dbg!(a.len());
|
|
|
|
|
|
|
|
let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect();
|
|
|
|
let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect();
|
|
|
|
let a_2: Vec<u64> = a.iter().map(|a_i| a_i % P2).collect();
|
|
|
|
dbg!(&a_0, &a_1, &a_2);
|
|
|
|
let a_reconstructed = reconstruct(a_0, a_1, a_2);
|
|
|
|
|
|
|
|
dbg!(&a_reconstructed);
|
|
|
|
assert_eq!(a_reconstructed, a);
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
fn test_dbg() -> Result<()> {
|
|
|
|
println!("{}", 1u128 << 64);
|
|
|
|
let n: usize = 16;
|
|
|
|
let n: usize = 2;
|
|
|
|
|
|
|
|
println!("{}", P0);
|
|
|
|
println!("{}", P1);
|
|
|
@ -231,38 +371,30 @@ mod tests { |
|
|
|
use rand::Rng;
|
|
|
|
let mut rng = rand::thread_rng();
|
|
|
|
let a: Vec<u64> = (0..n)
|
|
|
|
// .map(|_| rng.gen_range(0..=(1u64 << 57) - 1) - (1u64 << 56))
|
|
|
|
.map(|_| rng.gen_range(0..(1 << 61)))
|
|
|
|
// .map(|_| rng.gen_range(0..(1 << 64)))
|
|
|
|
// .map(|_| rng.gen_range(0..(1 << 32)))
|
|
|
|
.map(|_| rng.gen_range(0..16))
|
|
|
|
// .map(|_| rng.sample(rand::distributions::Standard))
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
dbg!(a.len());
|
|
|
|
|
|
|
|
let a_ntt = NTT::ntt(n, &a);
|
|
|
|
dbg!(&a_ntt);
|
|
|
|
|
|
|
|
let a_intt = NTT::intt(n, &a_ntt);
|
|
|
|
let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect();
|
|
|
|
let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect();
|
|
|
|
let a_2: Vec<u64> = a.iter().map(|a_i| a_i % P2).collect();
|
|
|
|
dbg!(&a_0, &a_1, &a_2);
|
|
|
|
// let a_0 = vec![3];
|
|
|
|
// let a_1 = vec![3];
|
|
|
|
// let a_2 = vec![3];
|
|
|
|
let a_intt = reconstruct(a_0, a_1, a_2);
|
|
|
|
// let a_ntt = NTT::ntt(n, &a);
|
|
|
|
// dbg!(&a_ntt);
|
|
|
|
//
|
|
|
|
// let a_intt = NTT::intt(n, &a_ntt);
|
|
|
|
|
|
|
|
dbg!(&a_intt);
|
|
|
|
assert_eq!(a_intt, a);
|
|
|
|
|
|
|
|
// unnecessary:
|
|
|
|
// let a: Vec<u64> = vec![2u64, 4, 6, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
|
|
|
|
// let a_ntt = NTT_p::ntt(P0, n, &a);
|
|
|
|
// dbg!(&a_ntt);
|
|
|
|
// let a_intt = NTT_p::intt(P0, n, &a_ntt);
|
|
|
|
// dbg!(&a_intt);
|
|
|
|
|
|
|
|
// NOTE: *n_inv is already done in the intt method.
|
|
|
|
|
|
|
|
// Multiplies the values by the inverse of the polynomial modulo the NTT modulus
|
|
|
|
// let n_inv = inv_mod(P0 as u128, n as u64); // n^-1 mod p0
|
|
|
|
// let a_new: Vec<u64> = a_0_intt
|
|
|
|
// .iter()
|
|
|
|
// .map(|a_i| ((*a_i as u128 * n_inv as u128) % P0 as u128) as u64)
|
|
|
|
// .collect();
|
|
|
|
|
|
|
|
// assert_eq!(a_intt, a);
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|