Browse Source

[composite-ntt] reconstruct u64 works (CRT with 3 eq'ns)

composite-ntt
arnaucube 1 month ago
parent
commit
900c42ac81
3 changed files with 148 additions and 381 deletions
  1. +3
    -0
      arith/Cargo.toml
  2. +13
    -173
      arith/src/ntt_u62.rs
  3. +132
    -208
      arith/src/ntt_u64.rs

+ 3
- 0
arith/Cargo.toml

@ -17,3 +17,6 @@ num = "0.4.3"
num-complex = "0.4.6" num-complex = "0.4.6"
ndarray = "0.16.1" ndarray = "0.16.1"
ndarray-linalg = { version = "0.17.0", features = ["intel-mkl"] } ndarray-linalg = { version = "0.17.0", features = ["intel-mkl"] }
num-bigint = "0.4.6"
num-traits = "0.2.19"

+ 13
- 173
arith/src/ntt_u62.rs

@ -3,140 +3,49 @@
use crate::ntt::NTT as NTT_p; use crate::ntt::NTT as NTT_p;
// const P0: u64 = 17293822569241362433;
// const P0: u64 = 4611686018427387905;
// const P1: u64 = 4611686018326724609;
const P0: u64 = 8070449433331580929; // max use 1<<62
const P0: u64 = 8070449433331580929; // max use: 1<<62
const P1: u64 = 8070450532384645121; const P1: u64 = 8070450532384645121;
// const P0: u64 = 0x80000000080001; // max use 1<<55
// const P0: u64 = 0x80000000080001; // max use: 1<<55
// const P1: u64 = 0x80000000130001; // const P1: u64 = 0x80000000130001;
// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64;
// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64;
// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64;
// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64;
// const P2: u64 = (1 << 60) - (1 << 26) + 1;
#[derive(Debug)] #[derive(Debug)]
pub struct NTT {} pub struct NTT {}
impl NTT { impl NTT {
pub fn ntt(
n: usize,
a: &Vec<u64>,
) -> (
Vec<u64>,
Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
) {
// TODO ensure that: a_i <P0
pub fn ntt(n: usize, a: &Vec<u64>) -> (Vec<u64>, Vec<u64>) {
// apply modulus p_i // apply modulus p_i
let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect();
let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect();
// let a_2: Vec<u64> = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect();
// let a_3: Vec<u64> = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect();
// let a_4: Vec<u64> = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect();
// let a_5: Vec<u64> = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect();
// let a_6: Vec<u64> = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect();
let r_0 = NTT_p::ntt(P0, n, &a_0); let r_0 = NTT_p::ntt(P0, n, &a_0);
let r_1 = NTT_p::ntt(P1, n, &a_1); let r_1 = NTT_p::ntt(P1, n, &a_1);
// let r_2 = NTT_p::ntt(P2, n, &a_2);
// let r_3 = NTT_p::ntt(P3, n, &a_3);
// let r_4 = NTT_p::ntt(P4, n, &a_4);
// let r_5 = NTT_p::ntt(P5, n, &a_5);
// let r_6 = NTT_p::ntt(P6, n, &a_6);
(r_0, r_1) //, r_2) //, r_3, r_4) //, r_5, r_6)
(r_0, r_1)
} }
pub fn intt(
n: usize,
r: &(
Vec<u64>,
Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
),
) -> Vec<u64> {
pub fn intt(n: usize, r: &(Vec<u64>, Vec<u64>)) -> Vec<u64> {
let a_0 = NTT_p::intt(P0, n, &r.0); let a_0 = NTT_p::intt(P0, n, &r.0);
let a_1 = NTT_p::intt(P1, n, &r.1); let a_1 = NTT_p::intt(P1, n, &r.1);
// let a_2 = NTT_p::intt(P2, n, &r.2);
// let a_3 = NTT_p::intt(P3, n, &r.3);
// let a_4 = NTT_p::intt(P4, n, &r.4);
// let a_5 = NTT_p::intt(P5, n, &r.5);
// let a_6 = NTT_p::intt(P6, n, &r.6);
// Garner CRT for two moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2)
// let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128;
// const INV_P1_MOD_P2: u128 = 4895217125691974194;
reconstruct(a_0, a_1) //, a_2) // , a_3, a_4) //, a_5, a_6)
reconstruct(a_0, a_1)
} }
} }
fn reconstruct(
a0: Vec<u64>,
a1: Vec<u64>,
// a2: Vec<u64>,
// a_3: Vec<u64>,
// a_4: Vec<u64>,
// a_5: Vec<u64>,
// a_6: Vec<u64>,
) -> Vec<u64> {
// let Q = P0 as u128 * P1 as u128;
/// applies CRT to reconstruct the composite original value. Uses Garner's CRT algorithm for two
/// moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2)
fn reconstruct(a0: Vec<u64>, a1: Vec<u64>) -> Vec<u64> {
// y_i = q/q_i // y_i = q/q_i
// let y0 = ((u64::MAX as u128 + 1) / P0 as u128);
// let y1 = ((u64::MAX as u128 + 1) / P1 as u128);
// let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64;
let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2
let y1: u128 = P0 as u128;
// let y1: u128 = P0 as u128 * P2 as u128;
// let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2
// let y1: u128 = P0 as u128 * P2 as u128;
// let y2: u128 = P0 as u128 * P1 as u128;
// let y0 = (Q / P0 as u128) as u64;
// let y1 = (Q / P1 as u128) as u64;
// let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64;
// let y3 = ((u64::MAX as u128 + 1) / P3 as u128) as u64;
// let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64;
// let y0: u128 = P1 as u128;
let y1: u128 = P0 as u128; // Q/P1 = P0*P1/P1 = P0
// y_i^-1 mod q_i = z_i // y_i^-1 mod q_i = z_i
let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i
// let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i
let z1: u128 = inv_mod(P1 as u128, y1); let z1: u128 = inv_mod(P1 as u128, y1);
// let z2: u128 = inv_mod(P2 as u128, y2);
// let y2_inv = inv_mod(P2 as u128, y2);
// let y3_inv = inv_mod(P3 as u128, y3);
// let y4_inv = inv_mod(P4 as u128, y4);
// m1 = q1^-1 mod q2 // m1 = q1^-1 mod q2
// aux = (a2 - a1) * m1 mod q2 // aux = (a2 - a1) * m1 mod q2
// a = a1 + (q1 * m1) * aux // a = a1 + (q1 * m1) * aux
/*
let m1 = inv_mod(P1 as u128, P0 as u128) as u64; // P0^-1 mod P1
let aux: Vec<u64> = itertools::zip_eq(a0.clone(), a1.clone())
.map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1)
.collect();
let a: Vec<u64> = itertools::zip_eq(a0, aux)
// .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i)
// .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i)
.map(|(a0_i, aux_i)| a0_i + ((P0 * m1) % P1) * aux_i)
.collect();
a
*/
let p0: u128 = P0 as u128; let p0: u128 = P0 as u128;
let p1: u128 = P1 as u128; let p1: u128 = P1 as u128;
let a: Vec<u64> = itertools::zip_eq(a0, a1) let a: Vec<u64> = itertools::zip_eq(a0, a1)
@ -144,65 +53,6 @@ fn reconstruct(
.map(|v| v as u64) .map(|v| v as u64)
.collect(); .collect();
a a
// dbg!(a0[0] as u128);
// dbg!(a0[0] as u128 * y0);
// dbg!(a0[0] as u128 * z0);
// dbg!(a0[0] as u128 * y0 * z0);
// let a: Vec<u128> = itertools::multizip((a0, a1, a2))
// .map(|(a0_i, a1_i, a2_i)| {
// a0_i as u128 * y0 * z0 + a1_i as u128 * y1 * z1 + a2_i as u128 * y2 * z2
// })
// .collect();
// dbg!(&a);
// let Q = y2 * P2 as u128;
// let a: Vec<u128> = a.iter().map(|a_i| a_i % Q).collect();
// dbg!(&a);
// let q64 = 1_u128 << 64;
// let a: Vec<u64> = a.iter().map(|a_i| (a_i % q64) as u64).collect();
// a
/*
// x_i*z_i mod q_i
let r0: Vec<u64> = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect();
let r1: Vec<u64> = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect();
// let r0: Vec<u64> = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect();
// let r1: Vec<u64> = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect();
// let r2: Vec<u64> = a_2.iter().map(|a_i| ((a_i * y2_inv) % P2) * y2).collect();
// let r3: Vec<u64> = a_3.iter().map(|a_i| ((a_i * y3_inv) % P3) * y3).collect();
// let r4: Vec<u64> = a_4.iter().map(|a_i| ((a_i * y4_inv) % P4) * y4).collect();
let r: Vec<u64> = itertools::multizip((r0.iter(), r1.iter()))
.map(|(a, b)| a + b)
.collect();
// let r = r0;
//
dbg!(&r);
let p1p2: u128 = (P0 as u128) * (P1 as u128);
// let p1p2_inv: u128 = inv_mod((P0 % P1) as u128, P1) as u128;
let p1p2_inv: u128 = inv_mod((P0) as u128, P1) as u128;
dbg!(&p1p2);
dbg!(&p1p2_inv);
// let p1p2: u128 = P0 as u128 / 2; // PIHALF
let r = r
.iter()
.map(|c_i_u64| {
let c_i = *c_i_u64 as u128;
if c_i * 2 >= p1p2 {
// if c_i >= p1p2 {
c_i.wrapping_sub(p1p2) as u64
} else {
c_i as u64
}
})
.collect();
// let r: Vec<u64> = itertools::multizip((r0.iter(), r1.iter(), r2.iter(), r3.iter(), r4.iter()))
// .map(|(a, b, c, d, e)| a + b + c + d + e)
// .collect();
// let mut r = a_0 + y0_inv + a_1 * y1_inv + a_2 * y2_inv + a_3 * y3_inv + a_4 * y4_inv;
r
*/
} }
fn exp_mod(q: u128, x: u128, k: u128) -> u128 { fn exp_mod(q: u128, x: u128, k: u128) -> u128 {
@ -230,29 +80,19 @@ fn inv_mod(q: u128, x: u128) -> u128 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use rand_distr::Distribution;
use anyhow::Result; use anyhow::Result;
#[test] #[test]
fn test_dbg() -> Result<()> { fn test_dbg() -> Result<()> {
println!("{}", 1u128 << 64);
let n: usize = 16; let n: usize = 16;
println!("{}", P0);
println!("{}", P1);
// let q = 1u128 << 64; // let q = 1u128 << 64;
// assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2); // assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2);
// let a: Vec<u64> = vec![1u64, 2, 3, 4];
// let a: Vec<u64> = vec![1u64, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
// let a: Vec<u64> = vec![9u64, 8, 7, 6, 0, 9999, 0, 0, 0, 0, 0, 0, 6, 7, 8, 9];
use rand::Rng; use rand::Rng;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let a: Vec<u64> = (0..n)
// .map(|_| rng.gen_range(0..=(1u64 << 57) - 1) - (1u64 << 56))
.map(|_| rng.gen_range(0..(1 << 62)))
.collect();
let a: Vec<u64> = (0..n).map(|_| rng.gen_range(0..(1 << 62))).collect();
dbg!(a.len()); dbg!(a.len());

+ 132
- 208
arith/src/ntt_u64.rs

@ -14,34 +14,22 @@ use crate::ntt::NTT as NTT_p;
// const P0: u64 = 0x80000000080001; // max use 1<<55 -1 // const P0: u64 = 0x80000000080001; // max use 1<<55 -1
// const P1: u64 = 0x80000000130001; // const P1: u64 = 0x80000000130001;
// const P0: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64;
// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64;
// const P2: u64 = ((1u128 << 64) - (1u128 << 26) + 1u128) as u64;
const P0: u64 = ((1u128 << 32) - (1u128 << 18) + 1u128) as u64;
const P1: u64 = ((1u128 << 32) - (1u128 << 17) + 1u128) as u64;
const P2: u64 = ((1u128 << 32) - (1u128 << 16) + 1u128) as u64;
// const P1: u64 = ((1u128 << 64) - (1u128 << 27) + 1u128) as u64;
// const P2: u64 = (1 << 60) - (1 << 26) + 1;
const BITS: u128 = 64;
const P0: u64 = ((1u128 << BITS) - (1u128 << 28) + 1u128) as u64;
const P1: u64 = ((1u128 << BITS) - (1u128 << 27) + 1u128) as u64;
const P2: u64 = ((1u128 << BITS) - (1u128 << 26) + 1u128) as u64;
// const P0: u64 = ((1u128 << BITS) - (1u128 << 18) + 1u128) as u64;
// const P1: u64 = ((1u128 << BITS) - (1u128 << 17) + 1u128) as u64;
// const P2: u64 = ((1u128 << BITS) - (1u128 << 16) + 1u128) as u64;
// const P0: u64 = 0x0FFFFFFF0000001; // 56 bits each P_i
// const P1: u64 = 0x0FFFFFFE8000001;
// const P2: u64 = 0x0FFFFFFE4000001;
#[derive(Debug)] #[derive(Debug)]
pub struct NTT {} pub struct NTT {}
impl NTT { impl NTT {
pub fn ntt(
n: usize,
a: &Vec<u64>,
) -> (
Vec<u64>,
Vec<u64>,
Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
) {
// TODO ensure that: a_i <P0
pub fn ntt(n: usize, a: &Vec<u64>) -> (Vec<u64>, Vec<u64>, Vec<u64>) {
// apply modulus p_i // apply modulus p_i
let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect(); let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect();
let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect(); let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect();
@ -49,83 +37,45 @@ impl NTT {
// let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect(); // let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect();
// let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect(); // let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect();
// let a_2: Vec<u64> = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect(); // let a_2: Vec<u64> = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect();
// let a_3: Vec<u64> = a.iter().map(|a_i| (a_i % P3 + P3) % P3).collect();
// let a_4: Vec<u64> = a.iter().map(|a_i| (a_i % P4 + P4) % P4).collect();
// let a_5: Vec<u64> = a.iter().map(|a_i| (a_i % P5 + P5) % P5).collect();
// let a_6: Vec<u64> = a.iter().map(|a_i| (a_i % P6 + P6) % P6).collect();
dbg!(&a_0, &a_1, &a_2); dbg!(&a_0, &a_1, &a_2);
let r_0 = NTT_p::ntt(P0, n, &a_0); let r_0 = NTT_p::ntt(P0, n, &a_0);
let r_1 = NTT_p::ntt(P1, n, &a_1); let r_1 = NTT_p::ntt(P1, n, &a_1);
let r_2 = NTT_p::ntt(P2, n, &a_2); let r_2 = NTT_p::ntt(P2, n, &a_2);
// let r_3 = NTT_p::ntt(P3, n, &a_3);
// let r_4 = NTT_p::ntt(P4, n, &a_4);
// let r_5 = NTT_p::ntt(P5, n, &a_5);
// let r_6 = NTT_p::ntt(P6, n, &a_6);
dbg!(&r_0, &r_1, &r_2);
(r_0, r_1, r_2) //, r_3, r_4) //, r_5, r_6)
(r_0, r_1, r_2)
} }
pub fn intt(
n: usize,
r: &(
Vec<u64>,
Vec<u64>,
Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
// Vec<u64>,
),
) -> Vec<u64> {
pub fn intt(n: usize, r: &(Vec<u64>, Vec<u64>, Vec<u64>)) -> Vec<u64> {
let a_0 = NTT_p::intt(P0, n, &r.0); let a_0 = NTT_p::intt(P0, n, &r.0);
let a_1 = NTT_p::intt(P1, n, &r.1); let a_1 = NTT_p::intt(P1, n, &r.1);
let a_2 = NTT_p::intt(P2, n, &r.2); let a_2 = NTT_p::intt(P2, n, &r.2);
// let a_3 = NTT_p::intt(P3, n, &r.3);
// let a_4 = NTT_p::intt(P4, n, &r.4);
// let a_5 = NTT_p::intt(P5, n, &r.5);
// let a_6 = NTT_p::intt(P6, n, &r.6);
// Garner CRT for two moduli: combine (r1 mod p1, r2 mod p2) -> Z/(p1*p2)
// let inv_p1_mod_p2: u128 = inv_mod_u64(p1 % p2, p2) as u128;
// const INV_P1_MOD_P2: u128 = 4895217125691974194;
dbg!(&a_0, &a_1, &a_2);
// TODO WIP
// let a_0: Vec<u64> = a_0.iter().map(|a_i| a_i % P0).collect();
// let a_1: Vec<u64> = a_1.iter().map(|a_i| a_i % P1).collect();
// let a_2: Vec<u64> = a_2.iter().map(|a_i| a_i % P2).collect();
reconstruct(a_0, a_1, a_2) // , a_3, a_4) //, a_5, a_6)
reconstruct(a_0, a_1, a_2)
} }
} }
fn reconstruct(
a0: Vec<u64>,
a1: Vec<u64>,
a2: Vec<u64>,
// a_3: Vec<u64>,
// a_4: Vec<u64>,
// a_5: Vec<u64>,
// a_6: Vec<u64>,
) -> Vec<u64> {
/// applies CRT to reconstruct the composite original value
fn reconstruct(a0: Vec<u64>, a1: Vec<u64>, a2: Vec<u64>) -> Vec<u64> {
let p0: u128 = P0 as u128;
let p1: u128 = P1 as u128;
let p2: u128 = P2 as u128;
// y_i = q/q_i // y_i = q/q_i
// let y0 = ((u64::MAX as u128 + 1) / P0 as u128);
// let y1 = ((u64::MAX as u128 + 1) / P1 as u128);
// let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64;
// let y0: u128 = P1 as u128; // N_i =Q/P0 = P1*P2
// let y1: u128 = P0 as u128;
let y0: u128 = P1 as u128 * P2 as u128; // N_i =Q/P0 = P1*P2
let y1: u128 = P0 as u128 * P2 as u128;
let y2: u128 = P0 as u128 * P1 as u128;
// let y0 = (Q / P0 as u128) as u64;
// let y1 = (Q / P1 as u128) as u64;
// let y2 = ((u64::MAX as u128 + 1) / P2 as u128) as u64;
// let y3 = ((u64::MAX as u128 + 1) / P3 as u128) as u64;
// let y4 = ((u64::MAX as u128 + 1) / P4 as u128) as u64;
let y0: u128 = p1 * p2; // N_i =Q/P0 = P1*P2
let y1: u128 = p0 * p2;
let y2: u128 = p0 * p1;
// y_i^-1 mod q_i = z_i // y_i^-1 mod q_i = z_i
dbg!(P0, y0); dbg!(P0, y0);
let z0: u128 = inv_mod(P0 as u128, y0); // M_i = N_i^-1 mod q_i
let z1: u128 = inv_mod(P1 as u128, y1);
let z2: u128 = inv_mod(P2 as u128, y2);
// let y2_inv = inv_mod(P2 as u128, y2);
// let y3_inv = inv_mod(P3 as u128, y3);
// let y4_inv = inv_mod(P4 as u128, y4);
let z0: u128 = inv_mod(p0, y0); // M_i = N_i^-1 mod q_i
let z1: u128 = inv_mod(p1, y1);
let z2: u128 = inv_mod(p2, y2);
// m1 = q1^-1 mod q2 // m1 = q1^-1 mod q2
// aux = (a2 - a1) * m1 mod q2 // aux = (a2 - a1) * m1 mod q2
@ -136,9 +86,6 @@ fn reconstruct(
// let aux: Vec<u64> = itertools::zip_eq(a0.clone(), a1.clone()) // let aux: Vec<u64> = itertools::zip_eq(a0.clone(), a1.clone())
// .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1) // .map(|(a0_i, a1_i)| ((a1_i - a0_i) * m1) % P1)
// .collect(); // .collect();
let p0: u128 = P0 as u128;
let p1: u128 = P1 as u128;
let p2: u128 = P2 as u128;
// let a: Vec<u64> = itertools::zip_eq(a0, a1) // let a: Vec<u64> = itertools::zip_eq(a0, a1)
// // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i) // // .map(|(a1_i, aux_i)| a1_i + (P1 * m1) * aux_i)
// // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i) // // .map(|(a0_i, aux_i)| a0_i + (P0 * m1) * aux_i)
@ -158,16 +105,31 @@ fn reconstruct(
// dbg!(a0[0] as u128 * y0 * z0); // dbg!(a0[0] as u128 * y0 * z0);
dbg!(&y0, &y1, &y2); dbg!(&y0, &y1, &y2);
dbg!(&z0, &z1, &z2); dbg!(&z0, &z1, &z2);
let Q = P0 as u128 * P1 as u128 * P2 as u128;
// WIP, using BigUint to use Q with 192 bits (product of 3 u64)
let Q = BigUint::from_u64(P0).unwrap()
* BigUint::from_u64(P1).unwrap()
* BigUint::from_u64(P2).unwrap();
let max_u64 = BigUint::from_u128(1_u128 << 64).unwrap();
// let Q = P0 as u128 * P1 as u128 * P2 as u128;
let a: Vec<u64> = itertools::multizip((a0, a1, a2)) let a: Vec<u64> = itertools::multizip((a0, a1, a2))
.map(|(a0_i, a1_i, a2_i)| { .map(|(a0_i, a1_i, a2_i)| {
(a0_i as u128 * y0 * z0)// % Q
+ (a1_i as u128 * y1 * z1)// % Q
+ (a2_i as u128 * y2 * z2) // % Q
// (a0_i as u128 * y0 * z0)// % Q
// + (a1_i as u128 * y1 * z1)// % Q
// + (a2_i as u128 * y2 * z2) // % Q
mul_3_big(a0_i as u128, y0, z0, &Q) // TODO rm %Q, since it returns a BigUint, and %Q
// is done later at the end of the additions
+ mul_3_big(a1_i as u128, y1, z1, &Q)
+ mul_3_big(a2_i as u128, y2, z2, &Q)
})
.map(|v| {
// WIP, using BigUint to use Q with 192 bits (product of 3 u64)
// ((BigUint::from_u128(v).unwrap() % Q.clone()) % max_u64.clone())
((v % Q.clone()) % max_u64.clone()).to_u64().unwrap()
}) })
.map(|v| v % Q)
.map(|v| v as u32)
.map(|v| v as u64)
// .map(|v| v % (1 << 63))
// .map(|v| v % Q)
// .map(|v| v as u32)
// .map(|v| v as u64)
.collect(); .collect();
a a
// dbg!(&a); // dbg!(&a);
@ -177,49 +139,6 @@ fn reconstruct(
// let q64 = 1_u128 << 64; // let q64 = 1_u128 << 64;
// let a: Vec<u64> = a.iter().map(|a_i| (a_i % q64) as u64).collect(); // let a: Vec<u64> = a.iter().map(|a_i| (a_i % q64) as u64).collect();
// a // a
/*
// x_i*z_i mod q_i
let r0: Vec<u64> = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect();
let r1: Vec<u64> = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect();
// let r0: Vec<u64> = a_0.iter().map(|a_i| ((a_i * z0) % P0) * y0).collect();
// let r1: Vec<u64> = a_1.iter().map(|a_i| ((a_i * z1) % P1) * y1).collect();
// let r2: Vec<u64> = a_2.iter().map(|a_i| ((a_i * y2_inv) % P2) * y2).collect();
// let r3: Vec<u64> = a_3.iter().map(|a_i| ((a_i * y3_inv) % P3) * y3).collect();
// let r4: Vec<u64> = a_4.iter().map(|a_i| ((a_i * y4_inv) % P4) * y4).collect();
let r: Vec<u64> = itertools::multizip((r0.iter(), r1.iter()))
.map(|(a, b)| a + b)
.collect();
// let r = r0;
//
dbg!(&r);
let p1p2: u128 = (P0 as u128) * (P1 as u128);
// let p1p2_inv: u128 = inv_mod((P0 % P1) as u128, P1) as u128;
let p1p2_inv: u128 = inv_mod((P0) as u128, P1) as u128;
dbg!(&p1p2);
dbg!(&p1p2_inv);
// let p1p2: u128 = P0 as u128 / 2; // PIHALF
let r = r
.iter()
.map(|c_i_u64| {
let c_i = *c_i_u64 as u128;
if c_i * 2 >= p1p2 {
// if c_i >= p1p2 {
c_i.wrapping_sub(p1p2) as u64
} else {
c_i as u64
}
})
.collect();
// let r: Vec<u64> = itertools::multizip((r0.iter(), r1.iter(), r2.iter(), r3.iter(), r4.iter()))
// .map(|(a, b, c, d, e)| a + b + c + d + e)
// .collect();
// let mut r = a_0 + y0_inv + a_1 * y1_inv + a_2 * y2_inv + a_3 * y3_inv + a_4 * y4_inv;
r
*/
} }
fn exp_mod(q: u128, x: u128, k: u128) -> u128 { fn exp_mod(q: u128, x: u128, k: u128) -> u128 {
@ -239,16 +158,16 @@ fn exp_mod(q: u128, x: u128, k: u128) -> u128 {
r r
} }
/// returns x^-1 mod Q, assuming x and Q are coprime, generally Q is prime /// returns x^-1 mod Q, assuming x and Q are coprime, generally Q is prime
// fn inv_mod(q: u128, x: u128) -> u128 {
// // by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
// // exp_mod(q, x, q - 2)
// exp_mod(q, x, q - 2)
// }
fn inv_mod_new(q: u128, x: u128) -> u128 {
// by Fermat's Little Theorem, x^-1 mod q \equiv x^{q-2} mod q
// exp_mod(q, x, q - 2)
exp_mod(q, x, q - 2)
}
fn inv_mod(m: u128, a: u128) -> u128 { fn inv_mod(m: u128, a: u128) -> u128 {
// if m == 1 {
// return Some(0);
// }
if m == 1 {
panic!("m==1");
}
let mut m = m.clone(); let mut m = m.clone();
let mut a = a.clone(); let mut a = a.clone();
@ -274,39 +193,22 @@ fn inv_mod(m: u128, a: u128) -> u128 {
x1 as u128 x1 as u128
} }
// fn inv_mod(m: u128, a: u128) -> u128 {
// let mut m = m.clone();
// let mut a = a.clone();
// let m0 = m.clone();
// let mut x0 = 0;
// let mut x1 = 1;
//
// if m == 1 {
// return 0;
// }
//
// while a > 1 {
// let q = a / m;
//
// let mut t = m.clone();
//
// m = a % m;
// a = t.clone();
//
// t = x0.clone();
//
// x0 = x1 - q * x0;
//
// x1 = t.clone();
// }
//
// // if (x1 < 0) {
// // x1 = x1 + m0
// // }
//
// // return x1 % m0;
// return x1 % m0;
// }
use num_bigint::BigUint;
use num_traits::{FromPrimitive, ToPrimitive};
fn mul_3_big(a: u128, b: u128, c: u128, Q: &BigUint) -> BigUint {
let r = (BigUint::from_u128(a).unwrap()
* BigUint::from_u128(b).unwrap()
* BigUint::from_u128(c).unwrap());
// % Q;
dbg!(&r);
let r = r % Q;
dbg!(&r);
r
// let max_u64 = BigUint::from_u128(1_u128 << 64).unwrap();
// dbg!(&r % &max_u64);
// (r % max_u64).to_u128().unwrap()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -317,13 +219,36 @@ mod tests {
#[test] #[test]
fn test_inv_mod() -> Result<()> { fn test_inv_mod() -> Result<()> {
// test vectors used in this test generated in SageMath
let p: u64 = 0x0FFFFFFF0000001;
let x = 3;
let x_inv = inv_mod(p as u128, x);
assert_eq!(x_inv, 48038395846328321);
let x = (1_u128 << 64) - 1;
let x_inv = inv_mod(p as u128, x);
assert_eq!(x_inv, 25433122709808565);
let x = 1_u128 << 120;
let x_inv = inv_mod(p as u128, x);
assert_eq!(x_inv, 281474976710656);
// other prime
let p: u64 = ((1u128 << 64) - (1u128 << 28) + 1u128) as u64;
let x = 3; let x = 3;
let x_inv = inv_mod(P0 as u128, x);
dbg!(&P0);
dbg!(&x_inv);
let x_inv = inv_mod(p as u128, x);
assert_eq!(x_inv, 12297829382294077441);
// let r = x_inv * x;
// dbg!(&r);
let x = (1_u128 << 64) - 1;
let x_inv = inv_mod(p as u128, x);
assert_eq!(x_inv, 10530692608818076599);
let x = 1_u128 << 120;
let x_inv = inv_mod(p as u128, x);
assert_eq!(x_inv, 18374686616574755070);
Ok(()) Ok(())
} }
@ -332,20 +257,30 @@ mod tests {
fn test_reconstruct() -> Result<()> { fn test_reconstruct() -> Result<()> {
let n: usize = 16; let n: usize = 16;
use num_bigint::BigUint;
use num_traits::{FromPrimitive, ToPrimitive};
let Q = BigUint::from_u64(P0).unwrap()
* BigUint::from_u64(P1).unwrap()
* BigUint::from_u64(P2).unwrap();
let q = BigUint::from_u128(1_u128 << BITS).unwrap();
let N = BigUint::from_usize(n).unwrap();
let big2 = BigUint::from_u64(2).unwrap();
assert!(Q > (N * (&q * &q)) / big2); // sanity check
use rand::Rng; use rand::Rng;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let a: Vec<u64> = (0..n) let a: Vec<u64> = (0..n)
// .map(|_| rng.gen_range(0..(1 << 64)))
.map(|_| rng.gen_range(0..(1 << 32)))
// .map(|_| rng.gen_range(0..16))
// .map(|_| rng.sample(rand::distributions::Standard))
.map(|_| rng.gen_range(0..((1u128 << 64) - 1) as u64))
.collect(); .collect();
dbg!(a.len());
// let a = vec![14713100818624219214];
dbg!(&a);
let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect(); let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect();
let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect(); let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect();
let a_2: Vec<u64> = a.iter().map(|a_i| a_i % P2).collect(); let a_2: Vec<u64> = a.iter().map(|a_i| a_i % P2).collect();
// let a_0: Vec<u64> = a.iter().map(|a_i| (a_i % P0 + P0) % P0).collect();
// let a_1: Vec<u64> = a.iter().map(|a_i| (a_i % P1 + P1) % P1).collect();
// let a_2: Vec<u64> = a.iter().map(|a_i| (a_i % P2 + P2) % P2).collect();
dbg!(&a_0, &a_1, &a_2); dbg!(&a_0, &a_1, &a_2);
let a_reconstructed = reconstruct(a_0, a_1, a_2); let a_reconstructed = reconstruct(a_0, a_1, a_2);
@ -356,12 +291,12 @@ mod tests {
} }
#[test] #[test]
fn test_dbg() -> Result<()> {
println!("{}", 1u128 << 64);
let n: usize = 2;
fn test_ntt() -> Result<()> {
// println!("{}", 1u128 << 64);
let n: usize = 1;
println!("{}", P0);
println!("{}", P1);
// println!("{}", P0);
// println!("{}", P1);
// let q = 1u128 << 64; // let q = 1u128 << 64;
// assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2); // assert!(P0 as u128 * P1 as u128 > (n as u128 * (q * q)) / 2);
@ -371,26 +306,15 @@ mod tests {
use rand::Rng; use rand::Rng;
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let a: Vec<u64> = (0..n) let a: Vec<u64> = (0..n)
// .map(|_| rng.gen_range(0..(1 << 64)))
// .map(|_| rng.gen_range(0..(1 << 32)))
.map(|_| rng.gen_range(0..16))
// .map(|_| rng.sample(rand::distributions::Standard))
.map(|_| rng.gen_range(0..((1u128 << 64) - 1) as u64))
.collect(); .collect();
dbg!(a.len());
dbg!(&a);
let a_0: Vec<u64> = a.iter().map(|a_i| a_i % P0).collect();
let a_1: Vec<u64> = a.iter().map(|a_i| a_i % P1).collect();
let a_2: Vec<u64> = a.iter().map(|a_i| a_i % P2).collect();
dbg!(&a_0, &a_1, &a_2);
// let a_0 = vec![3];
// let a_1 = vec![3];
// let a_2 = vec![3];
let a_intt = reconstruct(a_0, a_1, a_2);
// let a_ntt = NTT::ntt(n, &a);
// dbg!(&a_ntt);
//
// let a_intt = NTT::intt(n, &a_ntt);
let a_ntt = NTT::ntt(n, &a);
dbg!(&a_ntt);
let a_intt = NTT::intt(n, &a_ntt);
dbg!(&a_intt); dbg!(&a_intt);
assert_eq!(a_intt, a); assert_eq!(a_intt, a);

Loading…
Cancel
Save