Browse Source

Simplify prover client, use big.Int in Proof

feature/sql-semaphore1
Eduard S 4 years ago
parent
commit
f5818711dc
3 changed files with 73 additions and 99 deletions
  1. +4
    -3
      config/config.go
  2. +66
    -93
      prover/prover.go
  3. +3
    -3
      prover/prover_test.go

+ 4
- 3
config/config.go

@ -36,15 +36,16 @@ type ServerProof struct {
// Coordinator is the coordinator specific configuration. // Coordinator is the coordinator specific configuration.
type Coordinator struct { type Coordinator struct {
// ForgerAddress is the address under which this coordinator is forging // ForgerAddress is the address under which this coordinator is forging
ForgerAddress ethCommon.Address `validate:"required"`
ForgeLoopInterval Duration `validate:"required"`
ForgerAddress ethCommon.Address `validate:"required"`
// ConfirmBlocks is the number of confirmation blocks to wait for sent // ConfirmBlocks is the number of confirmation blocks to wait for sent
// ethereum transactions before forgetting about them // ethereum transactions before forgetting about them
ConfirmBlocks int64 `validate:"required"` ConfirmBlocks int64 `validate:"required"`
// L1BatchTimeoutPerc is the portion of the range before the L1Batch // L1BatchTimeoutPerc is the portion of the range before the L1Batch
// timeout that will trigger a schedule to forge an L1Batch // timeout that will trigger a schedule to forge an L1Batch
L1BatchTimeoutPerc float64 `validate:"required"`
// ProofServerPollInterval is the waiting interval between polling the
// ProofServer while waiting for a particular status
ProofServerPollInterval Duration `validate:"required"` ProofServerPollInterval Duration `validate:"required"`
L1BatchTimeoutPerc float64 `validate:"required"`
L2DB struct { L2DB struct {
SafetyPeriod common.BatchNum `validate:"required"` SafetyPeriod common.BatchNum `validate:"required"`
MaxTxs uint32 `validate:"required"` MaxTxs uint32 `validate:"required"`

+ 66
- 93
prover/prover.go

@ -1,13 +1,10 @@
package prover package prover
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"mime/multipart"
"math/big"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -19,10 +16,46 @@ import (
// Proof TBD this type will be received from the proof server // Proof TBD this type will be received from the proof server
type Proof struct { type Proof struct {
PiA []string `json:"pi_a"`
PiB [][]string `json:"pi_b"`
PiC []string `json:"pi_c"`
Protocol string `json:"protocol"`
PiA [2]*big.Int `json:"pi_a"`
PiB [3][2]*big.Int `json:"pi_b"`
PiC [2]*big.Int `json:"pi_c"`
Protocol string `json:"protocol"`
}
type bigInt big.Int
func (b *bigInt) UnmarshalText(text []byte) error {
_, ok := (*big.Int)(b).SetString(string(text), 10)
if !ok {
return fmt.Errorf("invalid big int: \"%v\"", string(text))
}
return nil
}
// UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
// ints as strings
func (p *Proof) UnmarshalJSON(data []byte) error {
proof := struct {
PiA [2]*bigInt `json:"pi_a"`
PiB [3][2]*bigInt `json:"pi_b"`
PiC [2]*bigInt `json:"pi_c"`
Protocol string `json:"protocol"`
}{}
if err := json.Unmarshal(data, &proof); err != nil {
return err
}
p.PiA[0] = (*big.Int)(proof.PiA[0])
p.PiA[1] = (*big.Int)(proof.PiA[1])
p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
p.PiC[0] = (*big.Int)(proof.PiC[0])
p.PiC[1] = (*big.Int)(proof.PiC[1])
p.Protocol = proof.Protocol
return nil
} }
// Client is the interface to a ServerProof that calculates zk proofs // Client is the interface to a ServerProof that calculates zk proofs
@ -113,55 +146,20 @@ const (
// ProofServerClient contains the data related to a ProofServerClient // ProofServerClient contains the data related to a ProofServerClient
type ProofServerClient struct { type ProofServerClient struct {
URL string
client *sling.Sling
timeCons time.Duration
URL string
client *sling.Sling
pollInterval time.Duration
} }
// NewProofServerClient creates a new ServerProof // NewProofServerClient creates a new ServerProof
func NewProofServerClient(URL string, timeCons time.Duration) *ProofServerClient {
func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
if URL[len(URL)-1] != '/' { if URL[len(URL)-1] != '/' {
URL += "/" URL += "/"
} }
client := sling.New().Base(URL) client := sling.New().Base(URL)
return &ProofServerClient{URL: URL, client: client, timeCons: timeCons}
return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
} }
//nolint:unused
type formFileProvider struct {
writer *multipart.Writer
body []byte
}
//nolint:unused,deadcode
func newFormFileProvider(payload interface{}) (*formFileProvider, error) {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", "file.json")
if err != nil {
return nil, tracerr.Wrap(err)
}
if err := json.NewEncoder(part).Encode(payload); err != nil {
return nil, tracerr.Wrap(err)
}
if err := writer.Close(); err != nil {
return nil, tracerr.Wrap(err)
}
return &formFileProvider{
writer: writer,
body: body.Bytes(),
}, nil
}
func (p formFileProvider) ContentType() string {
return p.writer.FormDataContentType()
}
func (p formFileProvider) Body() (io.Reader, error) {
return bytes.NewReader(p.body), nil
}
//nolint:unused
func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string, func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
body interface{}, ret interface{}) error { body interface{}, ret interface{}) error {
path = strings.TrimPrefix(path, "/") path = strings.TrimPrefix(path, "/")
@ -190,44 +188,31 @@ func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, pa
return nil return nil
} }
//nolint:unused
func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) { func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
var status Status var status Status
if err := p.apiRequest(ctx, GET, "/status", nil, &status); err != nil {
return nil, tracerr.Wrap(err)
}
return &status, nil
return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
} }
//nolint:unused
func (p *ProofServerClient) apiCancel(ctx context.Context) error { func (p *ProofServerClient) apiCancel(ctx context.Context) error {
if err := p.apiRequest(ctx, POST, "/cancel", nil, nil); err != nil {
return tracerr.Wrap(err)
}
return nil
return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
} }
//nolint:unused
func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error { func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
if err := p.apiRequest(ctx, POST, "/input", zkInputs, nil); err != nil {
return tracerr.Wrap(err)
}
return nil
return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
} }
// CalculateProof sends the *common.ZKInputs to the ServerProof to compute the // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
// Proof // Proof
func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error { func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
err := p.apiInput(ctx, zkInputs)
if err != nil {
return tracerr.Wrap(err)
}
return nil
return tracerr.Wrap(p.apiInput(ctx, zkInputs))
} }
// GetProof retreives the Proof from the ServerProof, blocking until the proof // GetProof retreives the Proof from the ServerProof, blocking until the proof
// is ready. // is ready.
func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) { func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) {
if err := p.WaitReady(ctx); err != nil {
return nil, err
}
status, err := p.apiStatus(ctx) status, err := p.apiStatus(ctx)
if err != nil { if err != nil {
return nil, tracerr.Wrap(err) return nil, tracerr.Wrap(err)
@ -240,43 +225,31 @@ func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) {
} }
return &proof, nil return &proof, nil
} }
return nil, errors.New("State is not Success")
return nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status)
} }
// Cancel cancels any current proof computation // Cancel cancels any current proof computation
func (p *ProofServerClient) Cancel(ctx context.Context) error { func (p *ProofServerClient) Cancel(ctx context.Context) error {
err := p.apiCancel(ctx)
if err != nil {
return tracerr.Wrap(err)
}
return nil
return tracerr.Wrap(p.apiCancel(ctx))
} }
// WaitReady waits until the serverProof is ready // WaitReady waits until the serverProof is ready
func (p *ProofServerClient) WaitReady(ctx context.Context) error { func (p *ProofServerClient) WaitReady(ctx context.Context) error {
status, err := p.apiStatus(ctx)
if err != nil {
return tracerr.Wrap(err)
}
if !status.Status.IsInitialized() {
err := errors.New("Proof Server is not initialized")
return err
}
if status.Status.IsReady() {
return nil
}
for { for {
status, err := p.apiStatus(ctx)
if err != nil {
return tracerr.Wrap(err)
}
if !status.Status.IsInitialized() {
return fmt.Errorf("Proof Server is not initialized")
}
if status.Status.IsReady() {
return nil
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return tracerr.Wrap(common.ErrDone) return tracerr.Wrap(common.ErrDone)
case <-time.After(p.timeCons):
status, err := p.apiStatus(ctx)
if err != nil {
return tracerr.Wrap(err)
}
if status.Status.IsReady() {
return nil
}
case <-time.After(p.pollInterval):
} }
} }
} }

+ 3
- 3
prover/prover_test.go

@ -13,14 +13,14 @@ import (
) )
const apiURL = "http://localhost:3000/api" const apiURL = "http://localhost:3000/api"
const timeCons = 1 * time.Second
const pollInterval = 1 * time.Second
var proofServerClient *ProofServerClient var proofServerClient *ProofServerClient
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
exitVal := 0 exitVal := 0
if os.Getenv("INTEGRATION") != "" { if os.Getenv("INTEGRATION") != "" {
proofServerClient = NewProofServerClient(apiURL, timeCons)
proofServerClient = NewProofServerClient(apiURL, pollInterval)
err := proofServerClient.WaitReady(context.Background()) err := proofServerClient.WaitReady(context.Background())
if err != nil { if err != nil {
panic(err) panic(err)
@ -73,7 +73,7 @@ func testCancel(t *testing.T) {
status, err := proofServerClient.apiStatus(context.Background()) status, err := proofServerClient.apiStatus(context.Background())
require.NoError(t, err) require.NoError(t, err)
for status.Status == StatusCodeBusy { for status.Status == StatusCodeBusy {
time.Sleep(proofServerClient.timeCons)
time.Sleep(proofServerClient.pollInterval)
status, err = proofServerClient.apiStatus(context.Background()) status, err = proofServerClient.apiStatus(context.Background())
require.NoError(t, err) require.NoError(t, err)
} }

Loading…
Cancel
Save