From 900c42ac814835adc69c4f13c37ed6f5fc2d0a45 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Fri, 29 Aug 2025 17:12:06 +0000 Subject: [PATCH] [composite-ntt] reconstruct u64 works (CRT with 3 eq'ns) --- arith/Cargo.toml | 3 + arith/src/ntt_u62.rs | 186 ++--------------------- arith/src/ntt_u64.rs | 340 +++++++++++++++++-------------------------- 3 files changed, 148 insertions(+), 381 deletions(-) diff --git a/arith/Cargo.toml b/arith/Cargo.toml index 692757b..1a27132 100644 --- a/arith/Cargo.toml +++ b/arith/Cargo.toml @@ -17,3 +17,6 @@ num = "0.4.3" num-complex = "0.4.6" ndarray = "0.16.1" ndarray-linalg = { version = "0.17.0", features = ["intel-mkl"] } + +num-bigint = "0.4.6" +num-traits = "0.2.19" diff --git a/arith/src/ntt_u62.rs b/arith/src/ntt_u62.rs index 0e6d8f0..ec1f191 100644 --- a/arith/src/ntt_u62.rs +++ b/arith/src/ntt_u62.rs @@ -3,140 +3,49 @@ use crate::ntt::NTT as NTT_p; -// const P0: u64 = 17293822569241362433; - -// const P0: u64 = 4611686018427387905; -// const P1: u64 = 4611686018326724609; - -const P0: u64 = 8070449433331580929; // max use 1<<62 +const P0: u64 = 8070449433331580929; // max use: 1<<62 const P1: u64 = 8070450532384645121; -// const P0: u64 = 0x80000000080001; // max use 1<<55 +// const P0: u64 = 0x80000000080001; // max use: 1<<55 // const P1: u64 = 0x80000000130001; -// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; -// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; -// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; -// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; -// const P2: u64 = (1 << 60) - (1 << 26) + 1; - #[derive(Debug)] pub struct NTT {} impl NTT { - pub fn ntt( - n: usize, - a: &Vec, - ) -> ( - Vec, - Vec, - // Vec, - // Vec, - // Vec, - // Vec, - // Vec, - ) { - // TODO ensure that: a_i ) -> (Vec, Vec) { // apply modulus p_i let a_0: Vec = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); let a_1: Vec = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); - // let a_2: Vec = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect(); - // let a_3: Vec = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect(); - // let a_4: Vec = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect(); - // let a_5: Vec = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect(); - // let a_6: Vec = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect(); let r_0 = NTT_p::ntt(P0, n, &a_0); let r_1 = NTT_p::ntt(P1, n, &a_1); - // let r_2 = NTT_p::ntt(P2, n, &a_2); - // let r_3 = NTT_p::ntt(P3, n, &a_3); - // let r_4 = NTT_p::ntt(P4, n, &a_4); - // let r_5 = NTT_p::ntt(P5, n, &a_5); - // let r_6 = NTT_p::ntt(P6, n, &a_6); - (r_0, r_1) //, r_2) //, r_3, r_4) //, r_5, r_6) + (r_0, r_1) } - pub fn intt( - n: usize, - r: &( - Vec, - Vec, - // Vec, - // Vec, - // Vec, - // Vec, - // Vec, - ), - ) -> Vec { + pub fn intt(n: usize, r: &(Vec, Vec)) -> Vec { let a_0 = NTT_p::intt(P0, n, &r.0); let a_1 = NTT_p::intt(P1, n, &r.1); - // let a_2 = NTT_p::intt(P2, n, &r.2); - // let a_3 = NTT_p::intt(P3, n, &r.3); - // let a_4 = NTT_p::intt(P4, n, &r.4); - // let a_5 = NTT_p::intt(P5, n, &r.5); - // let a_6 = NTT_p::intt(P6, n, &r.6); - - // Garner CRT for two moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2) - // let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128; - // const INV_P1_MOD_P2: u128 = 4895217125691974194; - reconstruct(a_0, a_1) //, a_2) // , a_3, a_4) //, a_5, a_6) + reconstruct(a_0, a_1) } } -fn reconstruct( - a0: Vec, - a1: Vec, - // a2: Vec, - // a_3: Vec, - // a_4: Vec, - // a_5: Vec, - // a_6: Vec, -) -> Vec { - // let Q = P0 as u128 * P1 as u128; - +/// applies CRT to reconstruct the composite original value. Uses Garner's CRT algorithm for two +/// moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2) +fn reconstruct(a0: Vec, a1: Vec) -> Vec { // y_i = q/q_i - // let y0 = ((u64::MAX as u128 + 1) / P0 as u128); - // let y1 = ((u64::MAX as u128 + 1) / P1 as u128); - // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; - let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2 - let y1: u128 = P0 as u128; - // let y1: u128 = P0 as u128 * P2 as u128; - // let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2 - // let y1: u128 = P0 as u128 * P2 as u128; - // let y2: u128 = P0 as u128 * P1 as u128; - // let y0 = (Q / P0 as u128) as u64; - // let y1 = (Q / P1 as u128) as u64; - // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; - // let y3 = ((u64::MAX as u128 + 1) / P3 as u128) as u64; - // let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64; + // let y0: u128 = P1 as u128; + let y1: u128 = P0 as u128; // Q/P1 = P0*P1/P1 = P0 // y_i^-1 mod q_i = z_i - let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i + // let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i let z1: u128 = inv_mod(P1 as u128, y1); - // let z2: u128 = inv_mod(P2 as u128, y2); - // let y2_inv = inv_mod(P2 as u128, y2); - // let y3_inv = inv_mod(P3 as u128, y3); - // let y4_inv = inv_mod(P4 as u128, y4); // m1 = q1^-1 mod q2 // aux = (a2 - a1) * m1 mod q2 // a = a1 + (q1 * m1) * aux - - /* - let m1 = inv_mod(P1 as u128, P0 as u128) as u64; // P0^-1 mod P1 - let aux: Vec = itertools::zip_eq(a0.clone(), a1.clone()) - .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1) - .collect(); - let a: Vec = itertools::zip_eq(a0, aux) - // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i) - // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i) - .map(|(a0_i, aux_i)| a0_i + ((P0 * m1) % P1) * aux_i) - .collect(); - a - */ let p0: u128 = P0 as u128; let p1: u128 = P1 as u128; let a: Vec = itertools::zip_eq(a0, a1) @@ -144,65 +53,6 @@ fn reconstruct( .map(|v| v as u64) .collect(); a - // dbg!(a0[0] as u128); - // dbg!(a0[0] as u128 * y0); - // dbg!(a0[0] as u128 * z0); - // dbg!(a0[0] as u128 * y0 * z0); - // let a: Vec = itertools::multizip((a0, a1, a2)) - // .map(|(a0_i, a1_i, a2_i)| { - // a0_i as u128 * y0 * z0 + a1_i as u128 * y1 * z1 + a2_i as u128 * y2 * z2 - // }) - // .collect(); - // dbg!(&a); - // let Q = y2 * P2 as u128; - // let a: Vec = a.iter().map(|a_i| a_i % Q).collect(); - // dbg!(&a); - // let q64 = 1_u128 << 64; - // let a: Vec = a.iter().map(|a_i| (a_i % q64) as u64).collect(); - // a - - /* - // x_i*z_i mod q_i - let r0: Vec = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect(); - let r1: Vec = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect(); - // let r0: Vec = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect(); - // let r1: Vec = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect(); - // let r2: Vec = a_2.iter().map(|a_i| ((a_i * y2_inv) % P2) * y2).collect(); - // let r3: Vec = a_3.iter().map(|a_i| ((a_i * y3_inv) % P3) * y3).collect(); - // let r4: Vec = a_4.iter().map(|a_i| ((a_i * y4_inv) % P4) * y4).collect(); - - let r: Vec = itertools::multizip((r0.iter(), r1.iter())) - .map(|(a, b)| a + b) - .collect(); - // let r = r0; - // - dbg!(&r); - - let p1p2: u128 = (P0 as u128) * (P1 as u128); - // let p1p2_inv: u128 = inv_mod((P0 % P1) as u128, P1) as u128; - let p1p2_inv: u128 = inv_mod((P0) as u128, P1) as u128; - dbg!(&p1p2); - dbg!(&p1p2_inv); - // let p1p2: u128 = P0 as u128 / 2; // PIHALF - let r = r - .iter() - .map(|c_i_u64| { - let c_i = *c_i_u64 as u128; - if c_i * 2 >= p1p2 { - // if c_i >= p1p2 { - c_i.wrapping_sub(p1p2) as u64 - } else { - c_i as u64 - } - }) - .collect(); - // let r: Vec = itertools::multizip((r0.iter(), r1.iter(), r2.iter(), r3.iter(), r4.iter())) - // .map(|(a, b, c, d, e)| a + b + c + d + e) - // .collect(); - // let mut r = a_0 + y0_inv + a_1 * y1_inv + a_2 * y2_inv + a_3 * y3_inv + a_4 * y4_inv; - - r - */ } fn exp_mod(q: u128, x: u128, k: u128) -> u128 { @@ -230,29 +80,19 @@ fn inv_mod(q: u128, x: u128) -> u128 { #[cfg(test)] mod tests { use super::*; - use rand_distr::Distribution; use anyhow::Result; #[test] fn test_dbg() -> Result<()> { - println!("{}", 1u128 << 64); let n: usize = 16; - println!("{}", P0); - println!("{}", P1); // let q = 1u128 << 64; // assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2); - // let a: Vec = vec![1u64, 2, 3, 4]; - // let a: Vec = vec![1u64, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - // let a: Vec = vec![9u64, 8, 7, 6, 0, 9999, 0, 0, 0, 0, 0, 0, 6, 7, 8, 9]; use rand::Rng; let mut rng = rand::thread_rng(); - let a: Vec = (0..n) - // .map(|_| rng.gen_range(0..=(1u64 << 57) - 1) - (1u64 << 56)) - .map(|_| rng.gen_range(0..(1 << 62))) - .collect(); + let a: Vec = (0..n).map(|_| rng.gen_range(0..(1 << 62))).collect(); dbg!(a.len()); diff --git a/arith/src/ntt_u64.rs b/arith/src/ntt_u64.rs index a5b7445..2a60104 100644 --- a/arith/src/ntt_u64.rs +++ b/arith/src/ntt_u64.rs @@ -14,34 +14,22 @@ use crate::ntt::NTT as NTT_p; // const P0: u64 = 0x80000000080001; // max use 1<<55 -1 // const P1: u64 = 0x80000000130001; -// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; -// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; -// const P2: u64 = ((1u128 << 64) - (1u128 << 26) + 1u128) as u64; - -const P0: u64 = ((1u128 << 32) - (1u128 << 18) + 1u128) as u64; -const P1: u64 = ((1u128 << 32) - (1u128 << 17) + 1u128) as u64; -const P2: u64 = ((1u128 << 32) - (1u128 << 16) + 1u128) as u64; -// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64; -// const P2: u64 = (1 << 60) - (1 << 26) + 1; +const BITS: u128 = 64; +const P0: u64 = ((1u128 << BITS) - (1u128 << 28) + 1u128) as u64; +const P1: u64 = ((1u128 << BITS) - (1u128 << 27) + 1u128) as u64; +const P2: u64 = ((1u128 << BITS) - (1u128 << 26) + 1u128) as u64; +// const P0: u64 = ((1u128 << BITS) - (1u128 << 18) + 1u128) as u64; +// const P1: u64 = ((1u128 << BITS) - (1u128 << 17) + 1u128) as u64; +// const P2: u64 = ((1u128 << BITS) - (1u128 << 16) + 1u128) as u64; +// const P0: u64 = 0x0FFFFFFF0000001; // 56 bits each P_i +// const P1: u64 = 0x0FFFFFFE8000001; +// const P2: u64 = 0x0FFFFFFE4000001; #[derive(Debug)] pub struct NTT {} impl NTT { - pub fn ntt( - n: usize, - a: &Vec, - ) -> ( - Vec, - Vec, - Vec, - // Vec, - // Vec, - // Vec, - // Vec, - ) { - // TODO ensure that: a_i ) -> (Vec, Vec, Vec) { // apply modulus p_i let a_0: Vec = a.iter().map(|a_i| a_i % P0).collect(); let a_1: Vec = a.iter().map(|a_i| a_i % P1).collect(); @@ -49,83 +37,45 @@ impl NTT { // let a_0: Vec = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); // let a_1: Vec = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); // let a_2: Vec = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect(); - // let a_3: Vec = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect(); - // let a_4: Vec = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect(); - // let a_5: Vec = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect(); - // let a_6: Vec = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect(); dbg!(&a_0, &a_1, &a_2); let r_0 = NTT_p::ntt(P0, n, &a_0); let r_1 = NTT_p::ntt(P1, n, &a_1); let r_2 = NTT_p::ntt(P2, n, &a_2); - // let r_3 = NTT_p::ntt(P3, n, &a_3); - // let r_4 = NTT_p::ntt(P4, n, &a_4); - // let r_5 = NTT_p::ntt(P5, n, &a_5); - // let r_6 = NTT_p::ntt(P6, n, &a_6); + dbg!(&r_0, &r_1, &r_2); - (r_0, r_1, r_2) //, r_3, r_4) //, r_5, r_6) + (r_0, r_1, r_2) } - pub fn intt( - n: usize, - r: &( - Vec, - Vec, - Vec, - // Vec, - // Vec, - // Vec, - // Vec, - ), - ) -> Vec { + pub fn intt(n: usize, r: &(Vec, Vec, Vec)) -> Vec { let a_0 = NTT_p::intt(P0, n, &r.0); let a_1 = NTT_p::intt(P1, n, &r.1); let a_2 = NTT_p::intt(P2, n, &r.2); - // let a_3 = NTT_p::intt(P3, n, &r.3); - // let a_4 = NTT_p::intt(P4, n, &r.4); - // let a_5 = NTT_p::intt(P5, n, &r.5); - // let a_6 = NTT_p::intt(P6, n, &r.6); - - // Garner CRT for two moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2) - // let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128; - // const INV_P1_MOD_P2: u128 = 4895217125691974194; + dbg!(&a_0, &a_1, &a_2); + // TODO WIP + // let a_0: Vec = a_0.iter().map(|a_i| a_i % P0).collect(); + // let a_1: Vec = a_1.iter().map(|a_i| a_i % P1).collect(); + // let a_2: Vec = a_2.iter().map(|a_i| a_i % P2).collect(); - reconstruct(a_0, a_1, a_2) // , a_3, a_4) //, a_5, a_6) + reconstruct(a_0, a_1, a_2) } } -fn reconstruct( - a0: Vec, - a1: Vec, - a2: Vec, - // a_3: Vec, - // a_4: Vec, - // a_5: Vec, - // a_6: Vec, -) -> Vec { +/// applies CRT to reconstruct the composite original value +fn reconstruct(a0: Vec, a1: Vec, a2: Vec) -> Vec { + let p0: u128 = P0 as u128; + let p1: u128 = P1 as u128; + let p2: u128 = P2 as u128; // y_i = q/q_i - // let y0 = ((u64::MAX as u128 + 1) / P0 as u128); - // let y1 = ((u64::MAX as u128 + 1) / P1 as u128); - // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; - // let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2 - // let y1: u128 = P0 as u128; - let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2 - let y1: u128 = P0 as u128 * P2 as u128; - let y2: u128 = P0 as u128 * P1 as u128; - // let y0 = (Q / P0 as u128) as u64; - // let y1 = (Q / P1 as u128) as u64; - // let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64; - // let y3 = ((u64::MAX as u128 + 1) / P3 as u128) as u64; - // let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64; + let y0: u128 = p1 * p2; // N_i =Q/P0 = P1*P2 + let y1: u128 = p0 * p2; + let y2: u128 = p0 * p1; // y_i^-1 mod q_i = z_i dbg!(P0, y0); - let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i - let z1: u128 = inv_mod(P1 as u128, y1); - let z2: u128 = inv_mod(P2 as u128, y2); - // let y2_inv = inv_mod(P2 as u128, y2); - // let y3_inv = inv_mod(P3 as u128, y3); - // let y4_inv = inv_mod(P4 as u128, y4); + let z0: u128 = inv_mod(p0, y0); // M_i = N_i^-1 mod q_i + let z1: u128 = inv_mod(p1, y1); + let z2: u128 = inv_mod(p2, y2); // m1 = q1^-1 mod q2 // aux = (a2 - a1) * m1 mod q2 @@ -136,9 +86,6 @@ fn reconstruct( // let aux: Vec = itertools::zip_eq(a0.clone(), a1.clone()) // .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1) // .collect(); - let p0: u128 = P0 as u128; - let p1: u128 = P1 as u128; - let p2: u128 = P2 as u128; // let a: Vec = itertools::zip_eq(a0, a1) // // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i) // // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i) @@ -158,16 +105,31 @@ fn reconstruct( // dbg!(a0[0] as u128 * y0 * z0); dbg!(&y0, &y1, &y2); dbg!(&z0, &z1, &z2); - let Q = P0 as u128 * P1 as u128 * P2 as u128; + // WIP, using BigUint to use Q with 192 bits (product of 3 u64) + let Q = BigUint::from_u64(P0).unwrap() + * BigUint::from_u64(P1).unwrap() + * BigUint::from_u64(P2).unwrap(); + let max_u64 = BigUint::from_u128(1_u128 << 64).unwrap(); + // let Q = P0 as u128 * P1 as u128 * P2 as u128; let a: Vec = itertools::multizip((a0, a1, a2)) .map(|(a0_i, a1_i, a2_i)| { - (a0_i as u128 * y0 * z0)// % Q - + (a1_i as u128 * y1 * z1)// % Q - + (a2_i as u128 * y2 * z2) // % Q + // (a0_i as u128 * y0 * z0)// % Q + // + (a1_i as u128 * y1 * z1)// % Q + // + (a2_i as u128 * y2 * z2) // % Q + mul_3_big(a0_i as u128, y0, z0, &Q) // TODO rm %Q, since it returns a BigUint, and %Q + // is done later at the end of the additions + + mul_3_big(a1_i as u128, y1, z1, &Q) + + mul_3_big(a2_i as u128, y2, z2, &Q) + }) + .map(|v| { + // WIP, using BigUint to use Q with 192 bits (product of 3 u64) + // ((BigUint::from_u128(v).unwrap() % Q.clone()) % max_u64.clone()) + ((v % Q.clone()) % max_u64.clone()).to_u64().unwrap() }) - .map(|v| v % Q) - .map(|v| v as u32) - .map(|v| v as u64) + // .map(|v| v % (1 << 63)) + // .map(|v| v % Q) + // .map(|v| v as u32) + // .map(|v| v as u64) .collect(); a // dbg!(&a); @@ -177,49 +139,6 @@ fn reconstruct( // let q64 = 1_u128 << 64; // let a: Vec = a.iter().map(|a_i| (a_i % q64) as u64).collect(); // a - - /* - // x_i*z_i mod q_i - let r0: Vec = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect(); - let r1: Vec = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect(); - // let r0: Vec = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect(); - // let r1: Vec = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect(); - // let r2: Vec = a_2.iter().map(|a_i| ((a_i * y2_inv) % P2) * y2).collect(); - // let r3: Vec = a_3.iter().map(|a_i| ((a_i * y3_inv) % P3) * y3).collect(); - // let r4: Vec = a_4.iter().map(|a_i| ((a_i * y4_inv) % P4) * y4).collect(); - - let r: Vec = itertools::multizip((r0.iter(), r1.iter())) - .map(|(a, b)| a + b) - .collect(); - // let r = r0; - // - dbg!(&r); - - let p1p2: u128 = (P0 as u128) * (P1 as u128); - // let p1p2_inv: u128 = inv_mod((P0 % P1) as u128, P1) as u128; - let p1p2_inv: u128 = inv_mod((P0) as u128, P1) as u128; - dbg!(&p1p2); - dbg!(&p1p2_inv); - // let p1p2: u128 = P0 as u128 / 2; // PIHALF - let r = r - .iter() - .map(|c_i_u64| { - let c_i = *c_i_u64 as u128; - if c_i * 2 >= p1p2 { - // if c_i >= p1p2 { - c_i.wrapping_sub(p1p2) as u64 - } else { - c_i as u64 - } - }) - .collect(); - // let r: Vec = itertools::multizip((r0.iter(), r1.iter(), r2.iter(), r3.iter(), r4.iter())) - // .map(|(a, b, c, d, e)| a + b + c + d + e) - // .collect(); - // let mut r = a_0 + y0_inv + a_1 * y1_inv + a_2 * y2_inv + a_3 * y3_inv + a_4 * y4_inv; - - r - */ } fn exp_mod(q: u128, x: u128, k: u128) -> u128 { @@ -239,16 +158,16 @@ fn exp_mod(q: u128, x: u128, k: u128) -> u128 { r } /// returns x^-1 mod Q, assuming x and Q are coprime, generally Q is prime -// fn inv_mod(q: u128, x: u128) -> u128 { -// // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q -// // exp_mod(q, x, q - 2) -// exp_mod(q, x, q - 2) -// } +fn inv_mod_new(q: u128, x: u128) -> u128 { + // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q + // exp_mod(q, x, q - 2) + exp_mod(q, x, q - 2) +} fn inv_mod(m: u128, a: u128) -> u128 { - // if m == 1 { - // return Some(0); - // } + if m == 1 { + panic!("m==1"); + } let mut m = m.clone(); let mut a = a.clone(); @@ -274,39 +193,22 @@ fn inv_mod(m: u128, a: u128) -> u128 { x1 as u128 } -// fn inv_mod(m: u128, a: u128) -> u128 { -// let mut m = m.clone(); -// let mut a = a.clone(); -// let m0 = m.clone(); -// let mut x0 = 0; -// let mut x1 = 1; -// -// if m == 1 { -// return 0; -// } -// -// while a > 1 { -// let q = a / m; -// -// let mut t = m.clone(); -// -// m = a % m; -// a = t.clone(); -// -// t = x0.clone(); -// -// x0 = x1 - q * x0; -// -// x1 = t.clone(); -// } -// -// // if (x1 < 0) { -// // x1 = x1 + m0 -// // } -// -// // return x1 % m0; -// return x1 % m0; -// } + +use num_bigint::BigUint; +use num_traits::{FromPrimitive, ToPrimitive}; +fn mul_3_big(a: u128, b: u128, c: u128, Q: &BigUint) -> BigUint { + let r = (BigUint::from_u128(a).unwrap() + * BigUint::from_u128(b).unwrap() + * BigUint::from_u128(c).unwrap()); + // % Q; + dbg!(&r); + let r = r % Q; + dbg!(&r); + r + // let max_u64 = BigUint::from_u128(1_u128 << 64).unwrap(); + // dbg!(&r % &max_u64); + // (r % max_u64).to_u128().unwrap() +} #[cfg(test)] mod tests { @@ -317,13 +219,36 @@ mod tests { #[test] fn test_inv_mod() -> Result<()> { + // test vectors used in this test generated in SageMath + + let p: u64 = 0x0FFFFFFF0000001; + + let x = 3; + let x_inv = inv_mod(p as u128, x); + assert_eq!(x_inv, 48038395846328321); + + let x = (1_u128 << 64) - 1; + let x_inv = inv_mod(p as u128, x); + assert_eq!(x_inv, 25433122709808565); + + let x = 1_u128 << 120; + let x_inv = inv_mod(p as u128, x); + assert_eq!(x_inv, 281474976710656); + + // other prime + let p: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64; + let x = 3; - let x_inv = inv_mod(P0 as u128, x); - dbg!(&P0); - dbg!(&x_inv); + let x_inv = inv_mod(p as u128, x); + assert_eq!(x_inv, 12297829382294077441); - // let r = x_inv * x; - // dbg!(&r); + let x = (1_u128 << 64) - 1; + let x_inv = inv_mod(p as u128, x); + assert_eq!(x_inv, 10530692608818076599); + + let x = 1_u128 << 120; + let x_inv = inv_mod(p as u128, x); + assert_eq!(x_inv, 18374686616574755070); Ok(()) } @@ -332,20 +257,30 @@ mod tests { fn test_reconstruct() -> Result<()> { let n: usize = 16; + use num_bigint::BigUint; + use num_traits::{FromPrimitive, ToPrimitive}; + let Q = BigUint::from_u64(P0).unwrap() + * BigUint::from_u64(P1).unwrap() + * BigUint::from_u64(P2).unwrap(); + let q = BigUint::from_u128(1_u128 << BITS).unwrap(); + let N = BigUint::from_usize(n).unwrap(); + let big2 = BigUint::from_u64(2).unwrap(); + assert!(Q > (N * (&q * &q)) / big2); // sanity check + use rand::Rng; let mut rng = rand::thread_rng(); let a: Vec = (0..n) - // .map(|_| rng.gen_range(0..(1 << 64))) - .map(|_| rng.gen_range(0..(1 << 32))) - // .map(|_| rng.gen_range(0..16)) - // .map(|_| rng.sample(rand::distributions::Standard)) + .map(|_| rng.gen_range(0..((1u128 << 64) - 1) as u64)) .collect(); - - dbg!(a.len()); + // let a = vec![14713100818624219214]; + dbg!(&a); let a_0: Vec = a.iter().map(|a_i| a_i % P0).collect(); let a_1: Vec = a.iter().map(|a_i| a_i % P1).collect(); let a_2: Vec = a.iter().map(|a_i| a_i % P2).collect(); + // let a_0: Vec = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); + // let a_1: Vec = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); + // let a_2: Vec = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect(); dbg!(&a_0, &a_1, &a_2); let a_reconstructed = reconstruct(a_0, a_1, a_2); @@ -356,12 +291,12 @@ mod tests { } #[test] - fn test_dbg() -> Result<()> { - println!("{}", 1u128 << 64); - let n: usize = 2; + fn test_ntt() -> Result<()> { + // println!("{}", 1u128 << 64); + let n: usize = 1; - println!("{}", P0); - println!("{}", P1); + // println!("{}", P0); + // println!("{}", P1); // let q = 1u128 << 64; // assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2); @@ -371,26 +306,15 @@ mod tests { use rand::Rng; let mut rng = rand::thread_rng(); let a: Vec = (0..n) - // .map(|_| rng.gen_range(0..(1 << 64))) - // .map(|_| rng.gen_range(0..(1 << 32))) - .map(|_| rng.gen_range(0..16)) - // .map(|_| rng.sample(rand::distributions::Standard)) + .map(|_| rng.gen_range(0..((1u128 << 64) - 1) as u64)) .collect(); - dbg!(a.len()); + dbg!(&a); - let a_0: Vec = a.iter().map(|a_i| a_i % P0).collect(); - let a_1: Vec = a.iter().map(|a_i| a_i % P1).collect(); - let a_2: Vec = a.iter().map(|a_i| a_i % P2).collect(); - dbg!(&a_0, &a_1, &a_2); - // let a_0 = vec![3]; - // let a_1 = vec![3]; - // let a_2 = vec![3]; - let a_intt = reconstruct(a_0, a_1, a_2); - // let a_ntt = NTT::ntt(n, &a); - // dbg!(&a_ntt); - // - // let a_intt = NTT::intt(n, &a_ntt); + let a_ntt = NTT::ntt(n, &a); + dbg!(&a_ntt); + + let a_intt = NTT::intt(n, &a_ntt); dbg!(&a_intt); assert_eq!(a_intt, a);