@ -12,9 +12,9 @@ use crate::poly_iop::{
structs ::{ IOPProverMessage , IOPProverState } ,
structs ::{ IOPProverMessage , IOPProverState } ,
} ;
} ;
use arithmetic ::{ fix_variables , VirtualPolynomial } ;
use arithmetic ::{ fix_variables , VirtualPolynomial } ;
use ark_ff ::PrimeField ;
use ark_ff ::{ batch_inversion , PrimeField } ;
use ark_poly ::DenseMultilinearExtension ;
use ark_poly ::DenseMultilinearExtension ;
use ark_std ::{ end_timer , start_timer , vec ::Vec } ;
use ark_std ::{ cfg_into_iter , end_timer , start_timer , vec ::Vec } ;
use rayon ::prelude ::{ IntoParallelIterator , IntoParallelRefIterator } ;
use rayon ::prelude ::{ IntoParallelIterator , IntoParallelRefIterator } ;
use std ::sync ::Arc ;
use std ::sync ::Arc ;
@ -40,6 +40,13 @@ impl SumCheckProver for IOPProverState {
challenges : Vec ::with_capacity ( polynomial . aux_info . num_variables ) ,
challenges : Vec ::with_capacity ( polynomial . aux_info . num_variables ) ,
round : 0 ,
round : 0 ,
poly : polynomial . clone ( ) ,
poly : polynomial . clone ( ) ,
extrapolation_aux : ( 1 . . polynomial . aux_info . max_degree )
. map ( | degree | {
let points = ( 0 . . 1 + degree as u64 ) . map ( F ::from ) . collect ::< Vec < _ > > ( ) ;
let weights = barycentric_weights ( & points ) ;
( points , weights )
} )
. collect ( ) ,
} )
} )
}
}
@ -110,83 +117,56 @@ impl SumCheckProver for IOPProverState {
let products_list = self . poly . products . clone ( ) ;
let products_list = self . poly . products . clone ( ) ;
let mut products_sum = vec ! [ F ::zero ( ) ; self . poly . aux_info . max_degree + 1 ] ;
let mut products_sum = vec ! [ F ::zero ( ) ; self . poly . aux_info . max_degree + 1 ] ;
// let compute_sum = start_timer!(|| "compute sum");
// Step 2: generate sum for the partial evaluated polynomial:
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
// f(r_1, ... r_m,, x_{m+1}... x_n)
#[ cfg(feature = " parallel " ) ]
{
let flag = ( self . poly . aux_info . max_degree = = 2 )
& & ( products_list . len ( ) = = 1 )
& & ( products_list [ 0 ] . 0 = = F ::one ( ) ) ;
if flag {
for ( t , e ) in products_sum . iter_mut ( ) . enumerate ( ) {
let evals = ( 0 . . 1 < < ( self . poly . aux_info . num_variables - self . round ) )
. into_par_iter ( )
. map ( | b | {
// evaluate P_round(t)
let table0 = & flattened_ml_extensions [ products_list [ 0 ] . 1 [ 0 ] ] ;
let table1 = & flattened_ml_extensions [ products_list [ 0 ] . 1 [ 1 ] ] ;
if t = = 0 {
table0 [ b < < 1 ] * table1 [ b < < 1 ]
} else if t = = 1 {
table0 [ ( b < < 1 ) + 1 ] * table1 [ ( b < < 1 ) + 1 ]
} else {
( table0 [ ( b < < 1 ) + 1 ] + table0 [ ( b < < 1 ) + 1 ] - table0 [ b < < 1 ] )
* ( table1 [ ( b < < 1 ) + 1 ] + table1 [ ( b < < 1 ) + 1 ] - table1 [ b < < 1 ] )
}
} )
. collect ::< Vec < F > > ( ) ;
* e + = evals . par_iter ( ) . sum ::< F > ( ) ;
}
} else {
for ( t , e ) in products_sum . iter_mut ( ) . enumerate ( ) {
let t = F ::from ( t as u128 ) ;
let products = ( 0 . . 1 < < ( self . poly . aux_info . num_variables - self . round ) )
. into_par_iter ( )
. map ( | b | {
// evaluate P_round(t)
let mut tmp = F ::zero ( ) ;
products_list . iter ( ) . for_each ( | ( coefficient , products ) | {
let num_mles = products . len ( ) ;
let mut product = * coefficient ;
for & f in products . iter ( ) . take ( num_mles ) {
let table = & flattened_ml_extensions [ f ] ; // f's range is checked in init
// TODO: Could be done faster by cashing the results from the
// previous t and adding the diff
// Also possible to use Karatsuba multiplication
product * =
table [ b < < 1 ] + ( table [ ( b < < 1 ) + 1 ] - table [ b < < 1 ] ) * t ;
}
tmp + = product ;
products_list . iter ( ) . for_each ( | ( coefficient , products ) | {
let mut sum = cfg_into_iter ! ( 0 . . 1 < < ( self . poly . aux_info . num_variables - self . round ) )
. fold (
| | {
(
vec ! [ ( F ::zero ( ) , F ::zero ( ) ) ; products . len ( ) ] ,
vec ! [ F ::zero ( ) ; products . len ( ) + 1 ] ,
)
} ,
| ( mut buf , mut acc ) , b | {
buf . iter_mut ( )
. zip ( products . iter ( ) )
. for_each ( | ( ( eval , step ) , f ) | {
let table = & flattened_ml_extensions [ * f ] ;
* eval = table [ b < < 1 ] ;
* step = table [ ( b < < 1 ) + 1 ] - table [ b < < 1 ] ;
} ) ;
} ) ;
tmp
} )
. collect ::< Vec < F > > ( ) ;
* e + = products . par_iter ( ) . sum ::< F > ( ) ;
}
}
}
#[ cfg(not(feature = " parallel " )) ]
products_sum . iter_mut ( ) . enumerate ( ) . for_each ( | ( t , e ) | {
let t = F ::from ( t as u64 ) ;
let one_minus_t = F ::one ( ) - t ;
for b in 0 . . 1 < < ( self . poly . aux_info . num_variables - self . round ) {
// evaluate P_round(t)
for ( coefficient , products ) in products_list . iter ( ) {
let num_mles = products . len ( ) ;
let mut product = * coefficient ;
for & f in products . iter ( ) . take ( num_mles ) {
let table = & flattened_ml_extensions [ f ] ; // f's range is checked in init
product * = table [ b < < 1 ] + ( table [ ( b < < 1 ) + 1 ] - table [ b < < 1 ] ) * t ;
}
* e + = product ;
}
}
acc [ 0 ] + = buf . iter ( ) . map ( | ( eval , _ ) | eval ) . product ::< F > ( ) ;
acc [ 1 . . ] . iter_mut ( ) . for_each ( | acc | {
buf . iter_mut ( ) . for_each ( | ( eval , step ) | * eval + = step as & _ ) ;
* acc + = buf . iter ( ) . map ( | ( eval , _ ) | eval ) . product ::< F > ( ) ;
} ) ;
( buf , acc )
} ,
)
. map ( | ( _ , partial ) | partial )
. reduce (
| | vec ! [ F ::zero ( ) ; products . len ( ) + 1 ] ,
| mut sum , partial | {
sum . iter_mut ( )
. zip ( partial . iter ( ) )
. for_each ( | ( sum , partial ) | * sum + = partial ) ;
sum
} ,
) ;
sum . iter_mut ( ) . for_each ( | sum | * sum * = coefficient ) ;
let extraploation = cfg_into_iter ! ( 0 . . self . poly . aux_info . max_degree - products . len ( ) )
. map ( | i | {
let ( points , weights ) = & self . extrapolation_aux [ products . len ( ) - 1 ] ;
let at = F ::from ( ( products . len ( ) + 1 + i ) as u64 ) ;
extrapolate ( points , weights , & sum , & at )
} )
. collect ::< Vec < _ > > ( ) ;
products_sum
. iter_mut ( )
. zip ( sum . iter ( ) . chain ( extraploation . iter ( ) ) )
. for_each ( | ( products_sum , sum ) | * products_sum + = sum ) ;
} ) ;
} ) ;
// update prover's state to the partial evaluated polynomial
// update prover's state to the partial evaluated polynomial
@ -195,10 +175,43 @@ impl SumCheckProver for IOPProverState {
. map ( | x | Arc ::new ( x . clone ( ) ) )
. map ( | x | Arc ::new ( x . clone ( ) ) )
. collect ( ) ;
. collect ( ) ;
// end_timer!(compute_sum);
// end_timer!(start);
Ok ( IOPProverMessage {
Ok ( IOPProverMessage {
evaluations : products_sum ,
evaluations : products_sum ,
} )
} )
}
}
}
}
fn barycentric_weights < F : PrimeField > ( points : & [ F ] ) -> Vec < F > {
let mut weights = points
. iter ( )
. enumerate ( )
. map ( | ( j , point_j ) | {
points
. iter ( )
. enumerate ( )
. filter_map ( | ( i , point_i ) | ( i ! = j ) . then ( | | * point_j - point_i ) )
. reduce ( | acc , value | acc * value )
. unwrap_or_else ( F ::one )
} )
. collect ::< Vec < _ > > ( ) ;
batch_inversion ( & mut weights ) ;
weights
}
fn extrapolate < F : PrimeField > ( points : & [ F ] , weights : & [ F ] , evals : & [ F ] , at : & F ) -> F {
let ( coeffs , sum_inv ) = {
let mut coeffs = points . iter ( ) . map ( | point | * at - point ) . collect ::< Vec < _ > > ( ) ;
batch_inversion ( & mut coeffs ) ;
coeffs . iter_mut ( ) . zip ( weights ) . for_each ( | ( coeff , weight ) | {
* coeff * = weight ;
} ) ;
let sum_inv = coeffs . iter ( ) . sum ::< F > ( ) . inverse ( ) . unwrap_or_default ( ) ;
( coeffs , sum_inv )
} ;
coeffs
. iter ( )
. zip ( evals )
. map ( | ( coeff , eval ) | * coeff * eval )
. sum ::< F > ( )
* sum_inv
}