From 08d448e204afcff967eb59ec18d54068329e40f3 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 23 May 2022 11:52:15 -0400 Subject: [PATCH] optimized interpolation (#28) --- poly-iop/src/sum_check/verifier.rs | 150 ++++++++++++++++++++--------- 1 file changed, 106 insertions(+), 44 deletions(-) diff --git a/poly-iop/src/sum_check/verifier.rs b/poly-iop/src/sum_check/verifier.rs index 63cd6cb..2cf6fc3 100644 --- a/poly-iop/src/sum_check/verifier.rs +++ b/poly-iop/src/sum_check/verifier.rs @@ -181,67 +181,129 @@ impl SumCheckVerifier for IOPVerifierState { fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> Result { let start = start_timer!(|| "sum check interpolate uni poly opt"); - let mut res = F::zero(); - - // compute - // - prod = \prod (eval_at - j) - // - evals = [eval_at - j] - let mut evals = vec![]; let len = p_i.len(); + let mut evals = vec![]; let mut prod = eval_at; evals.push(eval_at); + // `prod = \prod_{j} (eval_at - j)` for e in 1..len { let tmp = eval_at - F::from(e as u64); evals.push(tmp); prod *= tmp; } + let mut res = F::zero(); + // we want to compute \prod (j!=i) (i-j) for a given i + // + // we start from the last step, which is + // denom[len-1] = (len-1) * (len-2) *... * 2 * 1 + // the step before that is + // denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1 + // and the step before that is + // denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2 + // + // i.e., for any i, the one before this will be derived from + // denom[i-1] = denom[i] * (len-i) / i + // + // that is, we only need to store + // - the last denom for i = len-1, and + // - the ratio between current step and fhe last step, which is the product of + // (len-i) / i from all previous steps and we store this product as a fraction + // number to reduce field divisions. - for i in 0..len { - // res += p_i * prod / (divisor * (eval_at - j)) - let divisor = get_divisor(i, len)?; - let divisor_f = { - if divisor < 0 { - -F::from((-divisor) as u128) + // We know + // - 2^61 < factorial(20) < 2^62 + // - 2^122 < factorial(33) < 2^123 + // so we will be able to compute the ratio + // - for len <= 20 with i64 + // - for len <= 33 with i128 + // - for len > 33 with BigInt + if p_i.len() <= 20 { + let last_denominator = F::from(u64_factorial(len - 1)); + let mut ratio_numerator = 1i64; + let mut ratio_enumerator = 1u64; + + for i in (0..len).rev() { + let ratio_numerator_f = if ratio_numerator < 0 { + -F::from((-ratio_numerator) as u64) } else { - F::from(divisor as u128) - } - }; - res += p_i[i] * prod / (divisor_f * evals[i]); - } + F::from(ratio_numerator as u64) + }; - end_timer!(start); - Ok(res) -} + res += p_i[i] * prod * F::from(ratio_enumerator) + / (last_denominator * ratio_numerator_f * evals[i]); -/// Compute \prod_{j!=i)^len (i-j). This function takes O(n^2) number of -/// primitive operations which is negligible compared to field operations. -// We know -// - factorial(20) ~ 2^61 -// - factorial(33) ~ 2^123 -// so we will be able to store the result for len<=20 with i64; -// for len<=33 with i128; and we do not currently support len>33. -#[inline] -fn get_divisor(i: usize, len: usize) -> Result { - if len <= 20 { - let mut res = 1i64; - for j in 0..len { - if j != i { - res *= i as i64 - j as i64; + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + ratio_numerator *= -(len as i64 - i as i64); + ratio_enumerator *= i as u64; } } - Ok(res as i128) - } else if len <= 33 { - let mut res = 1i128; - for j in 0..len { - if j != i { - res *= i as i128 - j as i128; + } else if p_i.len() <= 33 { + let last_denominator = F::from(u128_factorial(len - 1)); + let mut ratio_numerator = 1i128; + let mut ratio_enumerator = 1u128; + + for i in (0..len).rev() { + let ratio_numerator_f = if ratio_numerator < 0 { + -F::from((-ratio_numerator) as u128) + } else { + F::from(ratio_numerator as u128) + }; + + res += p_i[i] * prod * F::from(ratio_enumerator) + / (last_denominator * ratio_numerator_f * evals[i]); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + ratio_numerator *= -(len as i128 - i as i128); + ratio_enumerator *= i as u128; } } - Ok(res) } else { - Err(PolyIOPErrors::InvalidParameters( - "Do not support number variable > 33".to_string(), - )) + let mut denom_up = field_factorial::(len - 1); + let mut denom_down = F::one(); + + for i in (0..len).rev() { + res += p_i[i] * prod * denom_down / (denom_up * evals[i]); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + denom_up *= -F::from((len - i) as u64); + denom_down *= F::from(i as u64); + } + } + } + end_timer!(start); + Ok(res) +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn field_factorial(a: usize) -> F { + let mut res = 1u64; + for i in 1..=a { + res *= i as u64; + } + F::from(res) +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn u128_factorial(a: usize) -> u128 { + let mut res = 1u128; + for i in 1..=a { + res *= i as u128; + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn u64_factorial(a: usize) -> u64 { + let mut res = 1u64; + for i in 1..=a { + res *= i as u64; } + res }