@ -61,6 +61,34 @@ impl SumcheckProof {
Ok ( ( e , r ) )
Ok ( ( e , r ) )
}
}
#[ inline ]
fn compute_eval_points < F > (
poly_A : & MultilinearPolynomial < G ::Scalar > ,
poly_B : & MultilinearPolynomial < G ::Scalar > ,
comb_func : & F ,
) -> ( G ::Scalar , G ::Scalar )
where
F : Fn ( & G ::Scalar , & G ::Scalar ) -> G ::Scalar + Sync ,
{
let len = poly_A . len ( ) / 2 ;
( 0 . . len )
. into_par_iter ( )
. map ( | i | {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func ( & poly_A [ i ] , & poly_B [ i ] ) ;
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A [ len + i ] + poly_A [ len + i ] - poly_A [ i ] ;
let poly_B_bound_point = poly_B [ len + i ] + poly_B [ len + i ] - poly_B [ i ] ;
let eval_point_2 = comb_func ( & poly_A_bound_point , & poly_B_bound_point ) ;
( eval_point_0 , eval_point_2 )
} )
. reduce (
| | ( G ::Scalar ::ZERO , G ::Scalar ::ZERO ) ,
| a , b | ( a . 0 + b . 0 , a . 1 + b . 1 ) ,
)
}
pub fn prove_quad < F > (
pub fn prove_quad < F > (
claim : & G ::Scalar ,
claim : & G ::Scalar ,
num_rounds : usize ,
num_rounds : usize ,
@ -77,25 +105,7 @@ impl SumcheckProof {
let mut claim_per_round = * claim ;
let mut claim_per_round = * claim ;
for _ in 0 . . num_rounds {
for _ in 0 . . num_rounds {
let poly = {
let poly = {
let len = poly_A . len ( ) / 2 ;
// Make an iterator returning the contributions to the evaluations
let ( eval_point_0 , eval_point_2 ) = ( 0 . . len )
. into_par_iter ( )
. map ( | i | {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func ( & poly_A [ i ] , & poly_B [ i ] ) ;
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A [ len + i ] + poly_A [ len + i ] - poly_A [ i ] ;
let poly_B_bound_point = poly_B [ len + i ] + poly_B [ len + i ] - poly_B [ i ] ;
let eval_point_2 = comb_func ( & poly_A_bound_point , & poly_B_bound_point ) ;
( eval_point_0 , eval_point_2 )
} )
. reduce (
| | ( G ::Scalar ::ZERO , G ::Scalar ::ZERO ) ,
| a , b | ( a . 0 + b . 0 , a . 1 + b . 1 ) ,
) ;
let ( eval_point_0 , eval_point_2 ) = Self ::compute_eval_points ( poly_A , poly_B , & comb_func ) ;
let evals = vec ! [ eval_point_0 , claim_per_round - eval_point_0 , eval_point_2 ] ;
let evals = vec ! [ eval_point_0 , claim_per_round - eval_point_0 , eval_point_2 ] ;
UniPoly ::from_evals ( & evals )
UniPoly ::from_evals ( & evals )
@ -136,7 +146,7 @@ impl SumcheckProof {
transcript : & mut G ::TE ,
transcript : & mut G ::TE ,
) -> Result < ( Self , Vec < G ::Scalar > , ( Vec < G ::Scalar > , Vec < G ::Scalar > ) ) , NovaError >
) -> Result < ( Self , Vec < G ::Scalar > , ( Vec < G ::Scalar > , Vec < G ::Scalar > ) ) , NovaError >
where
where
F : Fn ( & G ::Scalar , & G ::Scalar ) -> G ::Scalar ,
F : Fn ( & G ::Scalar , & G ::Scalar ) -> G ::Scalar + Sync ,
{
{
let mut e = * claim ;
let mut e = * claim ;
let mut r : Vec < G ::Scalar > = Vec ::new ( ) ;
let mut r : Vec < G ::Scalar > = Vec ::new ( ) ;
@ -146,20 +156,7 @@ impl SumcheckProof {
let mut evals : Vec < ( G ::Scalar , G ::Scalar ) > = Vec ::new ( ) ;
let mut evals : Vec < ( G ::Scalar , G ::Scalar ) > = Vec ::new ( ) ;
for ( poly_A , poly_B ) in poly_A_vec . iter ( ) . zip ( poly_B_vec . iter ( ) ) {
for ( poly_A , poly_B ) in poly_A_vec . iter ( ) . zip ( poly_B_vec . iter ( ) ) {
let mut eval_point_0 = G ::Scalar ::ZERO ;
let mut eval_point_2 = G ::Scalar ::ZERO ;
let len = poly_A . len ( ) / 2 ;
for i in 0 . . len {
// eval 0: bound_func is A(low)
eval_point_0 + = comb_func ( & poly_A [ i ] , & poly_B [ i ] ) ;
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A [ len + i ] + poly_A [ len + i ] - poly_A [ i ] ;
let poly_B_bound_point = poly_B [ len + i ] + poly_B [ len + i ] - poly_B [ i ] ;
eval_point_2 + = comb_func ( & poly_A_bound_point , & poly_B_bound_point ) ;
}
let ( eval_point_0 , eval_point_2 ) = Self ::compute_eval_points ( poly_A , poly_B , & comb_func ) ;
evals . push ( ( eval_point_0 , eval_point_2 ) ) ;
evals . push ( ( eval_point_0 , eval_point_2 ) ) ;
}
}