From f95f0389cf172d157e145e8df924d16193c2e1b2 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sun, 24 Aug 2025 17:03:38 +0000 Subject: [PATCH] [composite-ntt] reconstruct u32 works --- arith/src/lib.rs | 1 + arith/src/ntt_u62.rs | 269 +++++++++++++++++++++++++++++++++++++++++++ arith/src/ntt_u64.rs | 262 ++++++++++++++++++++++++++++++----------- 3 files changed, 467 insertions(+), 65 deletions(-) create mode 100644 arith/src/ntt_u62.rs diff --git a/arith/src/lib.rs b/arith/src/lib.rs index f95512c..9a64292 100644 --- a/arith/src/lib.rs +++ b/arith/src/lib.rs @@ -17,6 +17,7 @@ pub mod tuple_ring; // mod naive_ntt; // note: for dev only pub mod ntt; +pub mod ntt_u62; pub mod ntt_u64; // expose objects diff --git a/arith/src/ntt_u62.rs b/arith/src/ntt_u62.rs new file mode 100644 index 0000000..0e6d8f0 --- /dev/null +++ b/arith/src/ntt_u62.rs @@ -0,0 +1,269 @@ +//! This file implements the wrapper on top of the ntt.rs to be able to compute +//! the NTT for non-prime modulus, specifically for modulus 2^64 (for u64). + +use crate::ntt::NTT as NTT_p; + +// const P0: u64 = 17293822569241362433; + +// const P0: u64 = 4611686018427387905; +// const P1: u64 = 4611686018326724609; + +const P0: u64 = 8070449433331580929; // max use 1<<62 +const P1: u64 = 8070450532384645121; + +// const P0: u64 = 0x80000000080001; // max use 1<<55 +// const P1: u64 = 0x80000000130001; + +// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; +// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; +// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; +// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; +// const P2: u64 = (1 << 60) - (1 << 26) + 1; + +#[derive(Debug)] +pub struct NTT {} + +impl NTT { + pub fn ntt( + n: usize, + a: &Vec, + ) -> ( + Vec, + Vec, + // Vec, + // Vec, + // Vec, + // Vec, + // Vec, + ) { + // TODO ensure that: a_i = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); + let a_1: Vec = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); + // let a_2: Vec = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect(); + // let a_3: Vec = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect(); + // let a_4: Vec = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect(); + // let a_5: Vec = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect(); + // let a_6: Vec = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect(); + + let r_0 = NTT_p::ntt(P0, n, &a_0); + let r_1 = NTT_p::ntt(P1, n, &a_1); + // let r_2 = NTT_p::ntt(P2, n, &a_2); + // let r_3 = NTT_p::ntt(P3, n, &a_3); + // let r_4 = NTT_p::ntt(P4, n, &a_4); + // let r_5 = NTT_p::ntt(P5, n, &a_5); + // let r_6 = NTT_p::ntt(P6, n, &a_6); + + (r_0, r_1) //, r_2) //, r_3, r_4) //, r_5, r_6) + } + + pub fn intt( + n: usize, + r: &( + Vec, + Vec, + // Vec, + // Vec, + // Vec, + // Vec, + // Vec, + ), + ) -> Vec { + let a_0 = NTT_p::intt(P0, n, &r.0); + let a_1 = NTT_p::intt(P1, n, &r.1); + // let a_2 = NTT_p::intt(P2, n, &r.2); + // let a_3 = NTT_p::intt(P3, n, &r.3); + // let a_4 = NTT_p::intt(P4, n, &r.4); + // let a_5 = NTT_p::intt(P5, n, &r.5); + // let a_6 = NTT_p::intt(P6, n, &r.6); + + // Garner CRT for two moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2) + // let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128; + // const INV_P1_MOD_P2: u128 = 4895217125691974194; + + reconstruct(a_0, a_1) //, a_2) // , a_3, a_4) //, a_5, a_6) + } +} + +fn reconstruct( + a0: Vec, + a1: Vec, + // a2: Vec, + // a_3: Vec, + // a_4: Vec, + // a_5: Vec, + // a_6: Vec, +) -> Vec { + // let Q = P0 as u128 * P1 as u128; + + // y_i = q/q_i + // let y0 = ((u64::MAX as u128 + 1) / P0 as u128); + // let y1 = ((u64::MAX as u128 + 1) / P1 as u128); + // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; + let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2 + let y1: u128 = P0 as u128; + // let y1: u128 = P0 as u128 * P2 as u128; + // let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2 + // let y1: u128 = P0 as u128 * P2 as u128; + // let y2: u128 = P0 as u128 * P1 as u128; + // let y0 = (Q / P0 as u128) as u64; + // let y1 = (Q / P1 as u128) as u64; + // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; + // let y3 = ((u64::MAX as u128 + 1) / P3 as u128) as u64; + // let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64; + + // y_i^-1 mod q_i = z_i + let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i + let z1: u128 = inv_mod(P1 as u128, y1); + // let z2: u128 = inv_mod(P2 as u128, y2); + // let y2_inv = inv_mod(P2 as u128, y2); + // let y3_inv = inv_mod(P3 as u128, y3); + // let y4_inv = inv_mod(P4 as u128, y4); + + // m1 = q1^-1 mod q2 + // aux = (a2 - a1) * m1 mod q2 + // a = a1 + (q1 * m1) * aux + + /* + let m1 = inv_mod(P1 as u128, P0 as u128) as u64; // P0^-1 mod P1 + let aux: Vec = itertools::zip_eq(a0.clone(), a1.clone()) + .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1) + .collect(); + let a: Vec = itertools::zip_eq(a0, aux) + // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i) + // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i) + .map(|(a0_i, aux_i)| a0_i + ((P0 * m1) % P1) * aux_i) + .collect(); + a + */ + let p0: u128 = P0 as u128; + let p1: u128 = P1 as u128; + let a: Vec = itertools::zip_eq(a0, a1) + .map(|(a0_i, a1_i)| a0_i as u128 + ((p0 * z1) % p1) * (((a1_i - a0_i) as u128 * z1) % p1)) + .map(|v| v as u64) + .collect(); + a + // dbg!(a0[0] as u128); + // dbg!(a0[0] as u128 * y0); + // dbg!(a0[0] as u128 * z0); + // dbg!(a0[0] as u128 * y0 * z0); + // let a: Vec = itertools::multizip((a0, a1, a2)) + // .map(|(a0_i, a1_i, a2_i)| { + // a0_i as u128 * y0 * z0 + a1_i as u128 * y1 * z1 + a2_i as u128 * y2 * z2 + // }) + // .collect(); + // dbg!(&a); + // let Q = y2 * P2 as u128; + // let a: Vec = a.iter().map(|a_i| a_i % Q).collect(); + // dbg!(&a); + // let q64 = 1_u128 << 64; + // let a: Vec = a.iter().map(|a_i| (a_i % q64) as u64).collect(); + // a + + /* + // x_i*z_i mod q_i + let r0: Vec = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect(); + let r1: Vec = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect(); + // let r0: Vec = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect(); + // let r1: Vec = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect(); + // let r2: Vec = a_2.iter().map(|a_i| ((a_i * y2_inv) % P2) * y2).collect(); + // let r3: Vec = a_3.iter().map(|a_i| ((a_i * y3_inv) % P3) * y3).collect(); + // let r4: Vec = a_4.iter().map(|a_i| ((a_i * y4_inv) % P4) * y4).collect(); + + let r: Vec = itertools::multizip((r0.iter(), r1.iter())) + .map(|(a, b)| a + b) + .collect(); + // let r = r0; + // + dbg!(&r); + + let p1p2: u128 = (P0 as u128) * (P1 as u128); + // let p1p2_inv: u128 = inv_mod((P0 % P1) as u128, P1) as u128; + let p1p2_inv: u128 = inv_mod((P0) as u128, P1) as u128; + dbg!(&p1p2); + dbg!(&p1p2_inv); + // let p1p2: u128 = P0 as u128 / 2; // PIHALF + let r = r + .iter() + .map(|c_i_u64| { + let c_i = *c_i_u64 as u128; + if c_i * 2 >= p1p2 { + // if c_i >= p1p2 { + c_i.wrapping_sub(p1p2) as u64 + } else { + c_i as u64 + } + }) + .collect(); + // let r: Vec = itertools::multizip((r0.iter(), r1.iter(), r2.iter(), r3.iter(), r4.iter())) + // .map(|(a, b, c, d, e)| a + b + c + d + e) + // .collect(); + // let mut r = a_0 + y0_inv + a_1 * y1_inv + a_2 * y2_inv + a_3 * y3_inv + a_4 * y4_inv; + + r + */ +} + +fn exp_mod(q: u128, x: u128, k: u128) -> u128 { + // work on u128 to avoid overflow + let mut r = 1u128; + let mut x = x.clone(); + let mut k = k.clone(); + x = x % q; + // exponentiation by square strategy + while k > 0 { + if k % 2 == 1 { + r = (r * x) % q; + } + x = (x * x) % q; + k /= 2; + } + r +} +/// returns x^-1 mod Q +fn inv_mod(q: u128, x: u128) -> u128 { + // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q + exp_mod(q, x, q - 2) +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_distr::Distribution; + + use anyhow::Result; + + #[test] + fn test_dbg() -> Result<()> { + println!("{}", 1u128 << 64); + let n: usize = 16; + + println!("{}", P0); + println!("{}", P1); + // let q = 1u128 << 64; + // assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2); + + // let a: Vec = vec![1u64, 2, 3, 4]; + // let a: Vec = vec![1u64, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + // let a: Vec = vec![9u64, 8, 7, 6, 0, 9999, 0, 0, 0, 0, 0, 0, 6, 7, 8, 9]; + use rand::Rng; + let mut rng = rand::thread_rng(); + let a: Vec = (0..n) + // .map(|_| rng.gen_range(0..=(1u64 << 57) - 1) - (1u64 << 56)) + .map(|_| rng.gen_range(0..(1 << 62))) + .collect(); + + dbg!(a.len()); + + let a_ntt = NTT::ntt(n, &a); + dbg!(&a_ntt); + + let a_intt = NTT::intt(n, &a_ntt); + + dbg!(&a_intt); + assert_eq!(a_intt, a); + + Ok(()) + } +} diff --git a/arith/src/ntt_u64.rs b/arith/src/ntt_u64.rs index d58235b..a5b7445 100644 --- a/arith/src/ntt_u64.rs +++ b/arith/src/ntt_u64.rs @@ -7,10 +7,21 @@ use crate::ntt::NTT as NTT_p; // const P0: u64 = 4611686018427387905; // const P1: u64 = 4611686018326724609; -const P0: u64 = 8070449433331580929; -const P1: u64 = 8070450532384645121; -// const P0: u64 = (1 << 60) - (1 << 28) + 1; -// const P1: u64 = (1 << 60) - (1 << 27) + 1; + +// const P0: u64 = 8070449433331580929; // max use 1<<62 -1 +// const P1: u64 = 8070450532384645121; + +// const P0: u64 = 0x80000000080001; // max use 1<<55 -1 +// const P1: u64 = 0x80000000130001; + +// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; +// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; +// const P2: u64 = ((1u128 << 64) - (1u128 << 26) + 1u128) as u64; + +const P0: u64 = ((1u128 << 32) - (1u128 << 18) + 1u128) as u64; +const P1: u64 = ((1u128 << 32) - (1u128 << 17) + 1u128) as u64; +const P2: u64 = ((1u128 << 32) - (1u128 << 16) + 1u128) as u64; +// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; // const P2: u64 = (1 << 60) - (1 << 26) + 1; #[derive(Debug)] @@ -23,7 +34,7 @@ impl NTT { ) -> ( Vec, Vec, - // Vec, + Vec, // Vec, // Vec, // Vec, @@ -32,23 +43,27 @@ impl NTT { // TODO ensure that: a_i = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); - let a_1: Vec = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); + let a_0: Vec = a.iter().map(|a_i| a_i % P0).collect(); + let a_1: Vec = a.iter().map(|a_i| a_i % P1).collect(); + let a_2: Vec = a.iter().map(|a_i| a_i % P2).collect(); + // let a_0: Vec = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); + // let a_1: Vec = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); // let a_2: Vec = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect(); // let a_3: Vec = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect(); // let a_4: Vec = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect(); // let a_5: Vec = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect(); // let a_6: Vec = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect(); + dbg!(&a_0, &a_1, &a_2); let r_0 = NTT_p::ntt(P0, n, &a_0); let r_1 = NTT_p::ntt(P1, n, &a_1); - // let r_2 = NTT_p::ntt(P2, n, &a_2); + let r_2 = NTT_p::ntt(P2, n, &a_2); // let r_3 = NTT_p::ntt(P3, n, &a_3); // let r_4 = NTT_p::ntt(P4, n, &a_4); // let r_5 = NTT_p::ntt(P5, n, &a_5); // let r_6 = NTT_p::ntt(P6, n, &a_6); - (r_0, r_1) //, r_2) //, r_3, r_4) //, r_5, r_6) + (r_0, r_1, r_2) //, r_3, r_4) //, r_5, r_6) } pub fn intt( @@ -56,7 +71,7 @@ impl NTT { r: &( Vec, Vec, - // Vec, + Vec, // Vec, // Vec, // Vec, @@ -65,7 +80,7 @@ impl NTT { ) -> Vec { let a_0 = NTT_p::intt(P0, n, &r.0); let a_1 = NTT_p::intt(P1, n, &r.1); - // let a_2 = NTT_p::intt(P2, n, &r.2); + let a_2 = NTT_p::intt(P2, n, &r.2); // let a_3 = NTT_p::intt(P3, n, &r.3); // let a_4 = NTT_p::intt(P4, n, &r.4); // let a_5 = NTT_p::intt(P5, n, &r.5); @@ -75,28 +90,28 @@ impl NTT { // let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128; // const INV_P1_MOD_P2: u128 = 4895217125691974194; - reconstruct(a_0, a_1) //, a_2) // , a_3, a_4) //, a_5, a_6) + reconstruct(a_0, a_1, a_2) // , a_3, a_4) //, a_5, a_6) } } fn reconstruct( a0: Vec, a1: Vec, - // a2: Vec, + a2: Vec, // a_3: Vec, // a_4: Vec, // a_5: Vec, // a_6: Vec, ) -> Vec { - // let Q = P0 as u128 * P1 as u128; - // y_i = q/q_i - let y0 = ((u64::MAX as u128 + 1) / P0 as u128); - let y1 = ((u64::MAX as u128 + 1) / P1 as u128); + // let y0 = ((u64::MAX as u128 + 1) / P0 as u128); + // let y1 = ((u64::MAX as u128 + 1) / P1 as u128); // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; - // let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2 - // let y1: u128 = P0 as u128 * P2 as u128; - // let y2: u128 = P0 as u128 * P1 as u128; + // let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2 + // let y1: u128 = P0 as u128; + let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2 + let y1: u128 = P0 as u128 * P2 as u128; + let y2: u128 = P0 as u128 * P1 as u128; // let y0 = (Q / P0 as u128) as u64; // let y1 = (Q / P1 as u128) as u64; // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; @@ -104,9 +119,10 @@ fn reconstruct( // let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64; // y_i^-1 mod q_i = z_i + dbg!(P0, y0); let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i let z1: u128 = inv_mod(P1 as u128, y1); - // let z2: u128 = inv_mod(P2 as u128, y2); + let z2: u128 = inv_mod(P2 as u128, y2); // let y2_inv = inv_mod(P2 as u128, y2); // let y3_inv = inv_mod(P3 as u128, y3); // let y4_inv = inv_mod(P4 as u128, y4); @@ -115,25 +131,45 @@ fn reconstruct( // aux = (a2 - a1) * m1 mod q2 // a = a1 + (q1 * m1) * aux - let m1 = inv_mod(P1 as u128, P0 as u128) as u64; // P0^-1 mod P1 - let aux: Vec = itertools::zip_eq(a0.clone(), a1.clone()) - .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1) - .collect(); - let a: Vec = itertools::zip_eq(a0, aux) - // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i) - // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i) - .map(|(a0_i, aux_i)| a0_i + ((P0 * m1) % P1) * aux_i) - .collect(); - a + // m1 == z1 + // let m1 = inv_mod(P1 as u128, P0 as u128); // P0^-1 mod P1 + // let aux: Vec = itertools::zip_eq(a0.clone(), a1.clone()) + // .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1) + // .collect(); + let p0: u128 = P0 as u128; + let p1: u128 = P1 as u128; + let p2: u128 = P2 as u128; + // let a: Vec = itertools::zip_eq(a0, a1) + // // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i) + // // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i) + // .map(|(a0_i, a1_i)| a0_i as u128 + ((p0 * z1) % p1) * (((a1_i - a0_i) as u128 * z1) % p1)) + // .map(|v| v as u64) + // // let a: Vec = itertools::zip_eq(a0, a1) + // // .map(|(a0_i, a1_i)| { + // // ((((a0_i as u128 * z0) % P0 as u128) * y0) + (a1_i as u128 * z1 % P1 as u128) * y1) + // // as u64 + // // }) + // .collect(); + // a + // dbg!(a0[0] as u128); // dbg!(a0[0] as u128 * y0); // dbg!(a0[0] as u128 * z0); // dbg!(a0[0] as u128 * y0 * z0); - // let a: Vec = itertools::multizip((a0, a1, a2)) - // .map(|(a0_i, a1_i, a2_i)| { - // a0_i as u128 * y0 * z0 + a1_i as u128 * y1 * z1 + a2_i as u128 * y2 * z2 - // }) - // .collect(); + dbg!(&y0, &y1, &y2); + dbg!(&z0, &z1, &z2); + let Q = P0 as u128 * P1 as u128 * P2 as u128; + let a: Vec = itertools::multizip((a0, a1, a2)) + .map(|(a0_i, a1_i, a2_i)| { + (a0_i as u128 * y0 * z0)// % Q + + (a1_i as u128 * y1 * z1)// % Q + + (a2_i as u128 * y2 * z2) // % Q + }) + .map(|v| v % Q) + .map(|v| v as u32) + .map(|v| v as u64) + .collect(); + a // dbg!(&a); // let Q = y2 * P2 as u128; // let a: Vec = a.iter().map(|a_i| a_i % Q).collect(); @@ -202,11 +238,75 @@ fn exp_mod(q: u128, x: u128, k: u128) -> u128 { } r } -/// returns x^-1 mod Q -fn inv_mod(q: u128, x: u128) -> u128 { - // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q - exp_mod(q, x, q - 2) +/// returns x^-1 mod Q, assuming x and Q are coprime, generally Q is prime +// fn inv_mod(q: u128, x: u128) -> u128 { +// // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q +// // exp_mod(q, x, q - 2) +// exp_mod(q, x, q - 2) +// } + +fn inv_mod(m: u128, a: u128) -> u128 { + // if m == 1 { + // return Some(0); + // } + + let mut m = m.clone(); + let mut a = a.clone(); + let m0 = m.clone(); + let mut x0: i128 = 0; + let mut x1: i128 = 1; + + while a > 1 { + let q = a / m; + let t = m.clone(); + + m = a % m; + a = t.clone(); + + let t = x0; + x0 = x1 - (q as i128) * x0; + x1 = t; + } + + if x1 < 0 { + x1 += m0 as i128; + } + + x1 as u128 } +// fn inv_mod(m: u128, a: u128) -> u128 { +// let mut m = m.clone(); +// let mut a = a.clone(); +// let m0 = m.clone(); +// let mut x0 = 0; +// let mut x1 = 1; +// +// if m == 1 { +// return 0; +// } +// +// while a > 1 { +// let q = a / m; +// +// let mut t = m.clone(); +// +// m = a % m; +// a = t.clone(); +// +// t = x0.clone(); +// +// x0 = x1 - q * x0; +// +// x1 = t.clone(); +// } +// +// // if (x1 < 0) { +// // x1 = x1 + m0 +// // } +// +// // return x1 % m0; +// return x1 % m0; +// } #[cfg(test)] mod tests { @@ -215,10 +315,50 @@ mod tests { use anyhow::Result; + #[test] + fn test_inv_mod() -> Result<()> { + let x = 3; + let x_inv = inv_mod(P0 as u128, x); + dbg!(&P0); + dbg!(&x_inv); + + // let r = x_inv * x; + // dbg!(&r); + + Ok(()) + } + + #[test] + fn test_reconstruct() -> Result<()> { + let n: usize = 16; + + use rand::Rng; + let mut rng = rand::thread_rng(); + let a: Vec = (0..n) + // .map(|_| rng.gen_range(0..(1 << 64))) + .map(|_| rng.gen_range(0..(1 << 32))) + // .map(|_| rng.gen_range(0..16)) + // .map(|_| rng.sample(rand::distributions::Standard)) + .collect(); + + dbg!(a.len()); + + let a_0: Vec = a.iter().map(|a_i| a_i % P0).collect(); + let a_1: Vec = a.iter().map(|a_i| a_i % P1).collect(); + let a_2: Vec = a.iter().map(|a_i| a_i % P2).collect(); + dbg!(&a_0, &a_1, &a_2); + let a_reconstructed = reconstruct(a_0, a_1, a_2); + + dbg!(&a_reconstructed); + assert_eq!(a_reconstructed, a); + + Ok(()) + } + #[test] fn test_dbg() -> Result<()> { println!("{}", 1u128 << 64); - let n: usize = 16; + let n: usize = 2; println!("{}", P0); println!("{}", P1); @@ -231,38 +371,30 @@ mod tests { use rand::Rng; let mut rng = rand::thread_rng(); let a: Vec = (0..n) - // .map(|_| rng.gen_range(0..=(1u64 << 57) - 1) - (1u64 << 56)) - .map(|_| rng.gen_range(0..(1 << 61))) + // .map(|_| rng.gen_range(0..(1 << 64))) + // .map(|_| rng.gen_range(0..(1 << 32))) + .map(|_| rng.gen_range(0..16)) + // .map(|_| rng.sample(rand::distributions::Standard)) .collect(); dbg!(a.len()); - let a_ntt = NTT::ntt(n, &a); - dbg!(&a_ntt); - - let a_intt = NTT::intt(n, &a_ntt); + let a_0: Vec = a.iter().map(|a_i| a_i % P0).collect(); + let a_1: Vec = a.iter().map(|a_i| a_i % P1).collect(); + let a_2: Vec = a.iter().map(|a_i| a_i % P2).collect(); + dbg!(&a_0, &a_1, &a_2); + // let a_0 = vec![3]; + // let a_1 = vec![3]; + // let a_2 = vec![3]; + let a_intt = reconstruct(a_0, a_1, a_2); + // let a_ntt = NTT::ntt(n, &a); + // dbg!(&a_ntt); + // + // let a_intt = NTT::intt(n, &a_ntt); dbg!(&a_intt); assert_eq!(a_intt, a); - // unnecessary: - // let a: Vec = vec![2u64, 4, 6, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - // let a_ntt = NTT_p::ntt(P0, n, &a); - // dbg!(&a_ntt); - // let a_intt = NTT_p::intt(P0, n, &a_ntt); - // dbg!(&a_intt); - - // NOTE: *n_inv is already done in the intt method. - - // Multiplies the values by the inverse of the polynomial modulo the NTT modulus - // let n_inv = inv_mod(P0 as u128, n as u64); // n^-1 mod p0 - // let a_new: Vec = a_0_intt - // .iter() - // .map(|a_i| ((*a_i as u128 * n_inv as u128) % P0 as u128) as u64) - // .collect(); - - // assert_eq!(a_intt, a); - Ok(()) } }