|
package prover
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dghubble/sling"
|
|
"github.com/hermeznetwork/hermez-node/common"
|
|
"github.com/hermeznetwork/tracerr"
|
|
)
|
|
|
|
// Proof TBD this type will be received from the proof server
|
|
type Proof struct {
|
|
PiA [3]*big.Int `json:"pi_a"`
|
|
PiB [3][2]*big.Int `json:"pi_b"`
|
|
PiC [3]*big.Int `json:"pi_c"`
|
|
Protocol string `json:"protocol"`
|
|
}
|
|
|
|
type bigInt big.Int
|
|
|
|
func (b *bigInt) UnmarshalText(text []byte) error {
|
|
_, ok := (*big.Int)(b).SetString(string(text), 10)
|
|
if !ok {
|
|
return tracerr.Wrap(fmt.Errorf("invalid big int: \"%v\"", string(text)))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
|
|
// ints as strings
|
|
func (p *Proof) UnmarshalJSON(data []byte) error {
|
|
proof := struct {
|
|
PiA [3]*bigInt `json:"pi_a"`
|
|
PiB [3][2]*bigInt `json:"pi_b"`
|
|
PiC [3]*bigInt `json:"pi_c"`
|
|
Protocol string `json:"protocol"`
|
|
}{}
|
|
if err := json.Unmarshal(data, &proof); err != nil {
|
|
return tracerr.Wrap(err)
|
|
}
|
|
p.PiA[0] = (*big.Int)(proof.PiA[0])
|
|
p.PiA[1] = (*big.Int)(proof.PiA[1])
|
|
p.PiA[2] = (*big.Int)(proof.PiA[2])
|
|
if p.PiA[2].Int64() != 1 {
|
|
return tracerr.Wrap(fmt.Errorf("Expected PiA[2] == 1, but got %v", p.PiA[2]))
|
|
}
|
|
p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
|
|
p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
|
|
p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
|
|
p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
|
|
p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
|
|
p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
|
|
if p.PiB[2][0].Int64() != 1 || p.PiB[2][1].Int64() != 0 {
|
|
return tracerr.Wrap(fmt.Errorf("Expected PiB[2] == [1, 0], but got %v", p.PiB[2]))
|
|
}
|
|
p.PiC[0] = (*big.Int)(proof.PiC[0])
|
|
p.PiC[1] = (*big.Int)(proof.PiC[1])
|
|
p.PiC[2] = (*big.Int)(proof.PiC[2])
|
|
if p.PiC[2].Int64() != 1 {
|
|
return tracerr.Wrap(fmt.Errorf("Expected PiC[2] == 1, but got %v", p.PiC[2]))
|
|
}
|
|
p.Protocol = proof.Protocol
|
|
return nil
|
|
}
|
|
|
|
// PublicInputs are the public inputs of the proof
|
|
type PublicInputs []*big.Int
|
|
|
|
// UnmarshalJSON unmarshals the JSON into the public inputs where the bigInts
|
|
// are in decimal as quoted strings
|
|
func (p *PublicInputs) UnmarshalJSON(data []byte) error {
|
|
pubInputs := []*bigInt{}
|
|
if err := json.Unmarshal(data, &pubInputs); err != nil {
|
|
return tracerr.Wrap(err)
|
|
}
|
|
*p = make([]*big.Int, len(pubInputs))
|
|
for i, v := range pubInputs {
|
|
([]*big.Int)(*p)[i] = (*big.Int)(v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Client is the interface to a ServerProof that calculates zk proofs
|
|
type Client interface {
|
|
// Non-blocking
|
|
CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
|
|
// Blocking. Returns the Proof and Public Data (public inputs)
|
|
GetProof(ctx context.Context) (*Proof, []*big.Int, error)
|
|
// Non-Blocking
|
|
Cancel(ctx context.Context) error
|
|
// Blocking
|
|
WaitReady(ctx context.Context) error
|
|
}
|
|
|
|
// StatusCode is the status string of the ProofServer
|
|
type StatusCode string
|
|
|
|
const (
|
|
// StatusCodeAborted means prover is ready to take new proof. Previous
|
|
// proof was aborted.
|
|
StatusCodeAborted StatusCode = "aborted"
|
|
// StatusCodeBusy means prover is busy computing proof.
|
|
StatusCodeBusy StatusCode = "busy"
|
|
// StatusCodeFailed means prover is ready to take new proof. Previous
|
|
// proof failed
|
|
StatusCodeFailed StatusCode = "failed"
|
|
// StatusCodeSuccess means prover is ready to take new proof. Previous
|
|
// proof succeeded
|
|
StatusCodeSuccess StatusCode = "success"
|
|
// StatusCodeUnverified means prover is ready to take new proof.
|
|
// Previous proof was unverified
|
|
StatusCodeUnverified StatusCode = "unverified"
|
|
// StatusCodeUninitialized means prover is not initialized
|
|
StatusCodeUninitialized StatusCode = "uninitialized"
|
|
// StatusCodeUndefined means prover is in an undefined state. Most
|
|
// likely is booting up. Keep trying
|
|
StatusCodeUndefined StatusCode = "undefined"
|
|
// StatusCodeInitializing means prover is initializing and not ready yet
|
|
StatusCodeInitializing StatusCode = "initializing"
|
|
// StatusCodeReady means prover initialized and ready to do first proof
|
|
StatusCodeReady StatusCode = "ready"
|
|
)
|
|
|
|
// IsReady returns true when the prover is ready
|
|
func (status StatusCode) IsReady() bool {
|
|
if status == StatusCodeAborted || status == StatusCodeFailed || status == StatusCodeSuccess ||
|
|
status == StatusCodeUnverified || status == StatusCodeReady {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// IsInitialized returns true when the prover is initialized
|
|
func (status StatusCode) IsInitialized() bool {
|
|
if status == StatusCodeUninitialized || status == StatusCodeUndefined ||
|
|
status == StatusCodeInitializing {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Status is the return struct for the status API endpoint
|
|
type Status struct {
|
|
Status StatusCode `json:"status"`
|
|
Proof string `json:"proof"`
|
|
PubData string `json:"pubData"`
|
|
}
|
|
|
|
// ErrorServer is the return struct for an API error
|
|
type ErrorServer struct {
|
|
Status StatusCode `json:"status"`
|
|
Message string `json:"msg"`
|
|
}
|
|
|
|
// Error message for ErrorServer
|
|
func (e ErrorServer) Error() string {
|
|
return fmt.Sprintf("server proof status (%v): %v", e.Status, e.Message)
|
|
}
|
|
|
|
type apiMethod string
|
|
|
|
const (
|
|
// GET is an HTTP GET
|
|
GET apiMethod = "GET"
|
|
// POST is an HTTP POST with maybe JSON body
|
|
POST apiMethod = "POST"
|
|
)
|
|
|
|
// ProofServerClient contains the data related to a ProofServerClient
|
|
type ProofServerClient struct {
|
|
URL string
|
|
client *sling.Sling
|
|
pollInterval time.Duration
|
|
}
|
|
|
|
// NewProofServerClient creates a new ServerProof
|
|
func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
|
|
if URL[len(URL)-1] != '/' {
|
|
URL += "/"
|
|
}
|
|
client := sling.New().Base(URL)
|
|
return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
|
|
}
|
|
|
|
func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
|
|
body interface{}, ret interface{}) error {
|
|
path = strings.TrimPrefix(path, "/")
|
|
var errSrv ErrorServer
|
|
var req *http.Request
|
|
var err error
|
|
switch method {
|
|
case GET:
|
|
req, err = p.client.New().Get(path).Request()
|
|
case POST:
|
|
req, err = p.client.New().Post(path).BodyJSON(body).Request()
|
|
default:
|
|
return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
|
|
}
|
|
if err != nil {
|
|
return tracerr.Wrap(err)
|
|
}
|
|
res, err := p.client.Do(req.WithContext(ctx), ret, &errSrv)
|
|
if err != nil {
|
|
return tracerr.Wrap(err)
|
|
}
|
|
defer res.Body.Close() //nolint:errcheck
|
|
if !(200 <= res.StatusCode && res.StatusCode < 300) {
|
|
return tracerr.Wrap(errSrv)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
|
|
var status Status
|
|
return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
|
|
}
|
|
|
|
func (p *ProofServerClient) apiCancel(ctx context.Context) error {
|
|
return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
|
|
}
|
|
|
|
func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
|
|
return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
|
|
}
|
|
|
|
// CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
|
|
// Proof
|
|
func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
|
|
return tracerr.Wrap(p.apiInput(ctx, zkInputs))
|
|
}
|
|
|
|
// GetProof retreives the Proof and Public Data (public inputs) from the
|
|
// ServerProof, blocking until the proof is ready.
|
|
func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
|
|
if err := p.WaitReady(ctx); err != nil {
|
|
return nil, nil, tracerr.Wrap(err)
|
|
}
|
|
status, err := p.apiStatus(ctx)
|
|
if err != nil {
|
|
return nil, nil, tracerr.Wrap(err)
|
|
}
|
|
if status.Status == StatusCodeSuccess {
|
|
var proof Proof
|
|
if err := json.Unmarshal([]byte(status.Proof), &proof); err != nil {
|
|
return nil, nil, tracerr.Wrap(err)
|
|
}
|
|
var pubInputs PublicInputs
|
|
if err := json.Unmarshal([]byte(status.PubData), &pubInputs); err != nil {
|
|
return nil, nil, tracerr.Wrap(err)
|
|
}
|
|
return &proof, pubInputs, nil
|
|
}
|
|
return nil, nil, tracerr.Wrap(fmt.Errorf("status != %v, status = %v", StatusCodeSuccess, status.Status))
|
|
}
|
|
|
|
// Cancel cancels any current proof computation
|
|
func (p *ProofServerClient) Cancel(ctx context.Context) error {
|
|
return tracerr.Wrap(p.apiCancel(ctx))
|
|
}
|
|
|
|
// WaitReady waits until the serverProof is ready
|
|
func (p *ProofServerClient) WaitReady(ctx context.Context) error {
|
|
for {
|
|
status, err := p.apiStatus(ctx)
|
|
if err != nil {
|
|
return tracerr.Wrap(err)
|
|
}
|
|
if !status.Status.IsInitialized() {
|
|
return tracerr.Wrap(fmt.Errorf("Proof Server is not initialized"))
|
|
}
|
|
if status.Status.IsReady() {
|
|
return nil
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return tracerr.Wrap(common.ErrDone)
|
|
case <-time.After(p.pollInterval):
|
|
}
|
|
}
|
|
}
|
|
|
|
// MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
|
|
type MockClient struct {
|
|
counter int64
|
|
Delay time.Duration
|
|
}
|
|
|
|
// CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
|
|
// Proof
|
|
func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
|
|
return nil
|
|
}
|
|
|
|
// GetProof retreives the Proof from the ServerProof
|
|
func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
|
|
// Simulate a delay
|
|
select {
|
|
case <-time.After(p.Delay): //nolint:gomnd
|
|
i := p.counter * 100 //nolint:gomnd
|
|
p.counter++
|
|
return &Proof{
|
|
PiA: [3]*big.Int{
|
|
big.NewInt(i), big.NewInt(i + 1), big.NewInt(1), //nolint:gomnd
|
|
},
|
|
PiB: [3][2]*big.Int{
|
|
{big.NewInt(i + 2), big.NewInt(i + 3)}, //nolint:gomnd
|
|
{big.NewInt(i + 4), big.NewInt(i + 5)}, //nolint:gomnd
|
|
{big.NewInt(1), big.NewInt(0)}, //nolint:gomnd
|
|
},
|
|
PiC: [3]*big.Int{
|
|
big.NewInt(i + 6), big.NewInt(i + 7), big.NewInt(1), //nolint:gomnd
|
|
},
|
|
Protocol: "groth",
|
|
},
|
|
[]*big.Int{big.NewInt(i + 42)}, //nolint:gomnd
|
|
nil
|
|
case <-ctx.Done():
|
|
return nil, nil, tracerr.Wrap(common.ErrDone)
|
|
}
|
|
}
|
|
|
|
// Cancel cancels any current proof computation
|
|
func (p *MockClient) Cancel(ctx context.Context) error {
|
|
// Simulate a delay
|
|
select {
|
|
case <-time.After(80 * time.Millisecond): //nolint:gomnd
|
|
return nil
|
|
case <-ctx.Done():
|
|
return tracerr.Wrap(common.ErrDone)
|
|
}
|
|
}
|
|
|
|
// WaitReady waits until the prover is ready
|
|
func (p *MockClient) WaitReady(ctx context.Context) error {
|
|
// Simulate a delay
|
|
select {
|
|
case <-time.After(200 * time.Millisecond): //nolint:gomnd
|
|
return nil
|
|
case <-ctx.Done():
|
|
return tracerr.Wrap(common.ErrDone)
|
|
}
|
|
}
|