From 4ad67a3d181c0c3d2b1b4bac25c2e572125c4ab6 Mon Sep 17 00:00:00 2001 From: laisolizq Date: Mon, 7 Dec 2020 15:25:26 +0100 Subject: [PATCH 1/3] Update prover & add test --- prover/README.md | 28 ++++++++++++++ prover/prover.go | 90 ++++++++++++++++++++++++++++++------------- prover/prover_test.go | 83 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 27 deletions(-) create mode 100644 prover/README.md create mode 100644 prover/prover_test.go diff --git a/prover/README.md b/prover/README.md new file mode 100644 index 0000000..9cd630d --- /dev/null +++ b/prover/README.md @@ -0,0 +1,28 @@ +## Test Prover + +### Server Proof API +It is necessary to have a docker with server locally. + +The instructions in the following link can be followed: +https://github.com/hermeznetwork/test-info/tree/main/cli-prover + +> It is necessary to consult the pre-requirements to follow the steps of the next summary + +A summary of the steps to follow to run docker would be: + +- Clone the repository: https://github.com/hermeznetwork/test-info +- `cd cli-prover` +- `./cli-prover.sh -s localhost -v ~/prover_data -r 22` +- To enter docker: `docker exec -ti docker_cusnarks bash` +- Inside the docker: `cd cusnarks; make docker_all FORCE_CPU=1` +- Inside the docker: `cd config; python3 cusnarks_config.py 22 BN256` +- To exit docker: `exit` +- Now, the server API can be used. Helper can be consulted with: `./cli-prover.sh -h` +- Is necessary to initialize the server with: `./cli-prover.sh --post-start ` +- When `./cli-prover.sh --get-status ` is `ready` can be run the test. + +> The session can be consulted with `tmux ls`. The session will be the number of the last session on the list. + +### Test + +`INTEGRATION=1 go test` \ No newline at end of file diff --git a/prover/prover.go b/prover/prover.go index acec14f..124cb4b 100644 --- a/prover/prover.go +++ b/prover/prover.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "mime/multipart" @@ -13,18 +14,21 @@ import ( "github.com/dghubble/sling" "github.com/hermeznetwork/hermez-node/common" - "github.com/hermeznetwork/hermez-node/log" "github.com/hermeznetwork/tracerr" ) // Proof TBD this type will be received from the proof server type Proof struct { + PiA []string `json:"pi_a"` + PiB [][]string `json:"pi_b"` + PiC []string `json:"pi_c"` + Protocol string `json:"protocol"` } // Client is the interface to a ServerProof that calculates zk proofs type Client interface { // Non-blocking - CalculateProof(zkInputs *common.ZKInputs) error + CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error // Blocking GetProof(ctx context.Context) (*Proof, error) // Non-Blocking @@ -105,23 +109,22 @@ const ( GET apiMethod = "GET" // POST is an HTTP POST with maybe JSON body POST apiMethod = "POST" - // POSTFILE is an HTTP POST with a form file - POSTFILE apiMethod = "POSTFILE" ) // ProofServerClient contains the data related to a ProofServerClient type ProofServerClient struct { - URL string - client *sling.Sling + URL string + client *sling.Sling + timeCons time.Duration } // NewProofServerClient creates a new ServerProof -func NewProofServerClient(URL string) *ProofServerClient { +func NewProofServerClient(URL string, timeCons time.Duration) *ProofServerClient { if URL[len(URL)-1] != '/' { URL += "/" } client := sling.New().Base(URL) - return &ProofServerClient{URL: URL, client: client} + return &ProofServerClient{URL: URL, client: client, timeCons: timeCons} } //nolint:unused @@ -170,15 +173,6 @@ func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, pa req, err = p.client.New().Get(path).Request() case POST: req, err = p.client.New().Post(path).BodyJSON(body).Request() - case POSTFILE: - provider, err := newFormFileProvider(body) - if err != nil { - return tracerr.Wrap(err) - } - req, err = p.client.New().Post(path).BodyProvider(provider).Request() - if err != nil { - return tracerr.Wrap(err) - } default: return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method)) } @@ -215,7 +209,7 @@ func (p *ProofServerClient) apiCancel(ctx context.Context) error { //nolint:unused func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error { - if err := p.apiRequest(ctx, POSTFILE, "/input", zkInputs, nil); err != nil { + if err := p.apiRequest(ctx, POST, "/input", zkInputs, nil); err != nil { return tracerr.Wrap(err) } return nil @@ -223,28 +217,70 @@ func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInp // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the // Proof -func (p *ProofServerClient) CalculateProof(zkInputs *common.ZKInputs) error { - log.Error("TODO") - return tracerr.Wrap(common.ErrTODO) +func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error { + err := p.apiInput(ctx, zkInputs) + if err != nil { + return tracerr.Wrap(err) + } + return nil } // GetProof retreives the Proof from the ServerProof, blocking until the proof // is ready. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) { - log.Error("TODO") - return nil, tracerr.Wrap(common.ErrTODO) + status, err := p.apiStatus(ctx) + if err != nil { + return nil, tracerr.Wrap(err) + } + if status.Status == StatusCodeSuccess { + var proof Proof + err := json.Unmarshal([]byte(status.Proof), &proof) + if err != nil { + return nil, tracerr.Wrap(err) + } + return &proof, nil + } else { + return nil, errors.New("State is not Success") + } } // Cancel cancels any current proof computation func (p *ProofServerClient) Cancel(ctx context.Context) error { - log.Error("TODO") - return tracerr.Wrap(common.ErrTODO) + err := p.apiCancel(ctx) + if err != nil { + return tracerr.Wrap(err) + } + return nil } // WaitReady waits until the serverProof is ready func (p *ProofServerClient) WaitReady(ctx context.Context) error { - log.Error("TODO") - return tracerr.Wrap(common.ErrTODO) + status, err := p.apiStatus(ctx) + if err != nil { + return tracerr.Wrap(err) + } + if !status.Status.IsInitialized() { + err := errors.New("Proof Server is not initialized") + return err + } else { + if status.Status.IsReady() { + return nil + } + for { + select { + case <-ctx.Done(): + return tracerr.Wrap(common.ErrDone) + case <-time.After(p.timeCons): + status, err := p.apiStatus(ctx) + if err != nil { + return tracerr.Wrap(err) + } + if status.Status.IsReady() { + return nil + } + } + } + } } // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything diff --git a/prover/prover_test.go b/prover/prover_test.go new file mode 100644 index 0000000..6bfedc3 --- /dev/null +++ b/prover/prover_test.go @@ -0,0 +1,83 @@ +package prover + +import ( + "context" + "math/big" + "os" + "testing" + "time" + + "github.com/hermeznetwork/hermez-node/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const apiURL = "http://localhost:3000/api" +const timeCons = 1 * time.Second + +var proofServerClient *ProofServerClient + +func TestMain(m *testing.M) { + exitVal := 0 + if os.Getenv("INTEGRATION") != "" { + proofServerClient = NewProofServerClient(apiURL, timeCons) + err := proofServerClient.WaitReady(context.Background()) + if err != nil { + panic(err) + } + exitVal = m.Run() + } + os.Exit(exitVal) +} + +func TestApiServer(t *testing.T) { + t.Run("testAPIStatus", testAPIStatus) + t.Run("testCalculateProof", testCalculateProof) + time.Sleep(time.Second / 4) + err := proofServerClient.WaitReady(context.Background()) + require.NoError(t, err) + t.Run("testGetProof", testGetProof) + t.Run("testCancel", testCancel) +} + +func testAPIStatus(t *testing.T) { + status, err := proofServerClient.apiStatus(context.Background()) + require.NoError(t, err) + assert.Equal(t, true, status.Status.IsReady()) +} + +func testCalculateProof(t *testing.T) { + var zkInputs *common.ZKInputs + zkInputs = common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1)) + err := proofServerClient.CalculateProof(context.Background(), zkInputs) + require.NoError(t, err) +} + +func testGetProof(t *testing.T) { + proof, err := proofServerClient.GetProof(context.Background()) + require.NoError(t, err) + require.NotNil(t, proof) + require.NotNil(t, proof.PiA) + require.NotNil(t, proof.PiB) + require.NotNil(t, proof.PiC) + require.NotNil(t, proof.Protocol) +} + +func testCancel(t *testing.T) { + var zkInputs *common.ZKInputs + zkInputs = common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1)) + err := proofServerClient.CalculateProof(context.Background(), zkInputs) + require.NoError(t, err) + // TODO: remove sleep when the server has been reviewed + time.Sleep(time.Second / 4) + err = proofServerClient.Cancel(context.Background()) + require.NoError(t, err) + status, err := proofServerClient.apiStatus(context.Background()) + require.NoError(t, err) + for status.Status == StatusCodeBusy { + time.Sleep(proofServerClient.timeCons) + status, err = proofServerClient.apiStatus(context.Background()) + require.NoError(t, err) + } + assert.Equal(t, StatusCodeAborted, status.Status) +} From d8050dd0a62ecf72f7737225bcb026a1f9cc0b6d Mon Sep 17 00:00:00 2001 From: Eduard S Date: Wed, 9 Dec 2020 16:22:31 +0100 Subject: [PATCH 2/3] Update node and coordinator, fix linters --- config/config.go | 11 ++++++----- coordinator/coordinator.go | 2 +- node/node.go | 7 ++++--- prover/prover.go | 38 ++++++++++++++++++-------------------- prover/prover_test.go | 6 ++---- 5 files changed, 31 insertions(+), 33 deletions(-) diff --git a/config/config.go b/config/config.go index 71489a8..e4de25a 100644 --- a/config/config.go +++ b/config/config.go @@ -43,8 +43,9 @@ type Coordinator struct { ConfirmBlocks int64 `validate:"required"` // L1BatchTimeoutPerc is the portion of the range before the L1Batch // timeout that will trigger a schedule to forge an L1Batch - L1BatchTimeoutPerc float64 `validate:"required"` - L2DB struct { + ProofServerPollInterval Duration `validate:"required"` + L1BatchTimeoutPerc float64 `validate:"required"` + L2DB struct { SafetyPeriod common.BatchNum `validate:"required"` MaxTxs uint32 `validate:"required"` TTL Duration `validate:"required"` @@ -69,10 +70,10 @@ type Coordinator struct { DeployGasLimit uint64 `validate:"required"` GasPriceDiv uint64 `validate:"required"` ReceiptTimeout Duration `validate:"required"` - IntervalReceiptLoop Duration `validate:"required"` - // IntervalCheckLoop is the waiting interval between receipt + ReceiptLoopInterval Duration `validate:"required"` + // CheckLoopInterval is the waiting interval between receipt // checks of ethereum transactions in the TxManager - IntervalCheckLoop Duration `validate:"required"` + CheckLoopInterval Duration `validate:"required"` // Attempts is the number of attempts to do an eth client RPC // call before giving up Attempts int `validate:"required"` diff --git a/coordinator/coordinator.go b/coordinator/coordinator.go index 50e6ee5..fa2428d 100644 --- a/coordinator/coordinator.go +++ b/coordinator/coordinator.go @@ -807,7 +807,7 @@ func (p *Pipeline) forgeSendServerProof(ctx context.Context, batchNum common.Bat // 7. Call the selected idle server proof with BatchBuilder output, // save server proof info for batchNum - err = batchInfo.ServerProof.CalculateProof(zkInputs) + err = batchInfo.ServerProof.CalculateProof(ctx, zkInputs) if err != nil { return nil, tracerr.Wrap(err) } diff --git a/node/node.go b/node/node.go index 435bc9b..c34994b 100644 --- a/node/node.go +++ b/node/node.go @@ -96,7 +96,7 @@ func NewNode(mode Mode, cfg *config.Node, coordCfg *config.Coordinator) (*Node, DeployGasLimit: coordCfg.EthClient.DeployGasLimit, GasPriceDiv: coordCfg.EthClient.GasPriceDiv, ReceiptTimeout: coordCfg.EthClient.ReceiptTimeout.Duration, - IntervalReceiptLoop: coordCfg.EthClient.IntervalReceiptLoop.Duration, + IntervalReceiptLoop: coordCfg.EthClient.ReceiptLoopInterval.Duration, } } client, err := eth.NewClient(ethClient, nil, nil, ð.ClientConfig{ @@ -165,7 +165,8 @@ func NewNode(mode Mode, cfg *config.Node, coordCfg *config.Coordinator) (*Node, } serverProofs := make([]prover.Client, len(coordCfg.ServerProofs)) for i, serverProofCfg := range coordCfg.ServerProofs { - serverProofs[i] = prover.NewProofServerClient(serverProofCfg.URL) + serverProofs[i] = prover.NewProofServerClient(serverProofCfg.URL, + coordCfg.ProofServerPollInterval.Duration) } coord, err = coordinator.NewCoordinator( @@ -175,7 +176,7 @@ func NewNode(mode Mode, cfg *config.Node, coordCfg *config.Coordinator) (*Node, L1BatchTimeoutPerc: coordCfg.L1BatchTimeoutPerc, EthClientAttempts: coordCfg.EthClient.Attempts, EthClientAttemptsDelay: coordCfg.EthClient.AttemptsDelay.Duration, - TxManagerCheckInterval: coordCfg.EthClient.IntervalCheckLoop.Duration, + TxManagerCheckInterval: coordCfg.EthClient.CheckLoopInterval.Duration, DebugBatchPath: coordCfg.Debug.BatchPath, Purger: coordinator.PurgerCfg{ PurgeBatchDelay: coordCfg.L2DB.PurgeBatchDelay, diff --git a/prover/prover.go b/prover/prover.go index 124cb4b..4053013 100644 --- a/prover/prover.go +++ b/prover/prover.go @@ -133,7 +133,7 @@ type formFileProvider struct { body []byte } -//nolint:unused +//nolint:unused,deadcode func newFormFileProvider(payload interface{}) (*formFileProvider, error) { body := new(bytes.Buffer) writer := multipart.NewWriter(body) @@ -239,9 +239,8 @@ func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) { return nil, tracerr.Wrap(err) } return &proof, nil - } else { - return nil, errors.New("State is not Success") } + return nil, errors.New("State is not Success") } // Cancel cancels any current proof computation @@ -262,22 +261,21 @@ func (p *ProofServerClient) WaitReady(ctx context.Context) error { if !status.Status.IsInitialized() { err := errors.New("Proof Server is not initialized") return err - } else { - if status.Status.IsReady() { - return nil - } - for { - select { - case <-ctx.Done(): - return tracerr.Wrap(common.ErrDone) - case <-time.After(p.timeCons): - status, err := p.apiStatus(ctx) - if err != nil { - return tracerr.Wrap(err) - } - if status.Status.IsReady() { - return nil - } + } + if status.Status.IsReady() { + return nil + } + for { + select { + case <-ctx.Done(): + return tracerr.Wrap(common.ErrDone) + case <-time.After(p.timeCons): + status, err := p.apiStatus(ctx) + if err != nil { + return tracerr.Wrap(err) + } + if status.Status.IsReady() { + return nil } } } @@ -289,7 +287,7 @@ type MockClient struct { // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the // Proof -func (p *MockClient) CalculateProof(zkInputs *common.ZKInputs) error { +func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error { return nil } diff --git a/prover/prover_test.go b/prover/prover_test.go index 6bfedc3..b83b766 100644 --- a/prover/prover_test.go +++ b/prover/prover_test.go @@ -47,8 +47,7 @@ func testAPIStatus(t *testing.T) { } func testCalculateProof(t *testing.T) { - var zkInputs *common.ZKInputs - zkInputs = common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1)) + zkInputs := common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1)) err := proofServerClient.CalculateProof(context.Background(), zkInputs) require.NoError(t, err) } @@ -64,8 +63,7 @@ func testGetProof(t *testing.T) { } func testCancel(t *testing.T) { - var zkInputs *common.ZKInputs - zkInputs = common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1)) + zkInputs := common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1)) err := proofServerClient.CalculateProof(context.Background(), zkInputs) require.NoError(t, err) // TODO: remove sleep when the server has been reviewed From f5818711dcaf23ced03982d9d92ab32537df4716 Mon Sep 17 00:00:00 2001 From: Eduard S Date: Thu, 10 Dec 2020 12:34:29 +0100 Subject: [PATCH 3/3] Simplify prover client, use big.Int in Proof --- config/config.go | 7 +- prover/prover.go | 159 ++++++++++++++++++------------------------ prover/prover_test.go | 6 +- 3 files changed, 73 insertions(+), 99 deletions(-) diff --git a/config/config.go b/config/config.go index e4de25a..43efa82 100644 --- a/config/config.go +++ b/config/config.go @@ -36,15 +36,16 @@ type ServerProof struct { // Coordinator is the coordinator specific configuration. type Coordinator struct { // ForgerAddress is the address under which this coordinator is forging - ForgerAddress ethCommon.Address `validate:"required"` - ForgeLoopInterval Duration `validate:"required"` + ForgerAddress ethCommon.Address `validate:"required"` // ConfirmBlocks is the number of confirmation blocks to wait for sent // ethereum transactions before forgetting about them ConfirmBlocks int64 `validate:"required"` // L1BatchTimeoutPerc is the portion of the range before the L1Batch // timeout that will trigger a schedule to forge an L1Batch + L1BatchTimeoutPerc float64 `validate:"required"` + // ProofServerPollInterval is the waiting interval between polling the + // ProofServer while waiting for a particular status ProofServerPollInterval Duration `validate:"required"` - L1BatchTimeoutPerc float64 `validate:"required"` L2DB struct { SafetyPeriod common.BatchNum `validate:"required"` MaxTxs uint32 `validate:"required"` diff --git a/prover/prover.go b/prover/prover.go index 4053013..91abd87 100644 --- a/prover/prover.go +++ b/prover/prover.go @@ -1,13 +1,10 @@ package prover import ( - "bytes" "context" "encoding/json" - "errors" "fmt" - "io" - "mime/multipart" + "math/big" "net/http" "strings" "time" @@ -19,10 +16,46 @@ import ( // Proof TBD this type will be received from the proof server type Proof struct { - PiA []string `json:"pi_a"` - PiB [][]string `json:"pi_b"` - PiC []string `json:"pi_c"` - Protocol string `json:"protocol"` + PiA [2]*big.Int `json:"pi_a"` + PiB [3][2]*big.Int `json:"pi_b"` + PiC [2]*big.Int `json:"pi_c"` + Protocol string `json:"protocol"` +} + +type bigInt big.Int + +func (b *bigInt) UnmarshalText(text []byte) error { + _, ok := (*big.Int)(b).SetString(string(text), 10) + if !ok { + return fmt.Errorf("invalid big int: \"%v\"", string(text)) + } + return nil +} + +// UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big +// ints as strings +func (p *Proof) UnmarshalJSON(data []byte) error { + proof := struct { + PiA [2]*bigInt `json:"pi_a"` + PiB [3][2]*bigInt `json:"pi_b"` + PiC [2]*bigInt `json:"pi_c"` + Protocol string `json:"protocol"` + }{} + if err := json.Unmarshal(data, &proof); err != nil { + return err + } + p.PiA[0] = (*big.Int)(proof.PiA[0]) + p.PiA[1] = (*big.Int)(proof.PiA[1]) + p.PiB[0][0] = (*big.Int)(proof.PiB[0][0]) + p.PiB[0][1] = (*big.Int)(proof.PiB[0][1]) + p.PiB[1][0] = (*big.Int)(proof.PiB[1][0]) + p.PiB[1][1] = (*big.Int)(proof.PiB[1][1]) + p.PiB[2][0] = (*big.Int)(proof.PiB[2][0]) + p.PiB[2][1] = (*big.Int)(proof.PiB[2][1]) + p.PiC[0] = (*big.Int)(proof.PiC[0]) + p.PiC[1] = (*big.Int)(proof.PiC[1]) + p.Protocol = proof.Protocol + return nil } // Client is the interface to a ServerProof that calculates zk proofs @@ -113,55 +146,20 @@ const ( // ProofServerClient contains the data related to a ProofServerClient type ProofServerClient struct { - URL string - client *sling.Sling - timeCons time.Duration + URL string + client *sling.Sling + pollInterval time.Duration } // NewProofServerClient creates a new ServerProof -func NewProofServerClient(URL string, timeCons time.Duration) *ProofServerClient { +func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient { if URL[len(URL)-1] != '/' { URL += "/" } client := sling.New().Base(URL) - return &ProofServerClient{URL: URL, client: client, timeCons: timeCons} + return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval} } -//nolint:unused -type formFileProvider struct { - writer *multipart.Writer - body []byte -} - -//nolint:unused,deadcode -func newFormFileProvider(payload interface{}) (*formFileProvider, error) { - body := new(bytes.Buffer) - writer := multipart.NewWriter(body) - part, err := writer.CreateFormFile("file", "file.json") - if err != nil { - return nil, tracerr.Wrap(err) - } - if err := json.NewEncoder(part).Encode(payload); err != nil { - return nil, tracerr.Wrap(err) - } - if err := writer.Close(); err != nil { - return nil, tracerr.Wrap(err) - } - return &formFileProvider{ - writer: writer, - body: body.Bytes(), - }, nil -} - -func (p formFileProvider) ContentType() string { - return p.writer.FormDataContentType() -} - -func (p formFileProvider) Body() (io.Reader, error) { - return bytes.NewReader(p.body), nil -} - -//nolint:unused func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string, body interface{}, ret interface{}) error { path = strings.TrimPrefix(path, "/") @@ -190,44 +188,31 @@ func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, pa return nil } -//nolint:unused func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) { var status Status - if err := p.apiRequest(ctx, GET, "/status", nil, &status); err != nil { - return nil, tracerr.Wrap(err) - } - return &status, nil + return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status)) } -//nolint:unused func (p *ProofServerClient) apiCancel(ctx context.Context) error { - if err := p.apiRequest(ctx, POST, "/cancel", nil, nil); err != nil { - return tracerr.Wrap(err) - } - return nil + return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil)) } -//nolint:unused func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error { - if err := p.apiRequest(ctx, POST, "/input", zkInputs, nil); err != nil { - return tracerr.Wrap(err) - } - return nil + return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil)) } // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the // Proof func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error { - err := p.apiInput(ctx, zkInputs) - if err != nil { - return tracerr.Wrap(err) - } - return nil + return tracerr.Wrap(p.apiInput(ctx, zkInputs)) } // GetProof retreives the Proof from the ServerProof, blocking until the proof // is ready. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) { + if err := p.WaitReady(ctx); err != nil { + return nil, err + } status, err := p.apiStatus(ctx) if err != nil { return nil, tracerr.Wrap(err) @@ -240,43 +225,31 @@ func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) { } return &proof, nil } - return nil, errors.New("State is not Success") + return nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status) } // Cancel cancels any current proof computation func (p *ProofServerClient) Cancel(ctx context.Context) error { - err := p.apiCancel(ctx) - if err != nil { - return tracerr.Wrap(err) - } - return nil + return tracerr.Wrap(p.apiCancel(ctx)) } // WaitReady waits until the serverProof is ready func (p *ProofServerClient) WaitReady(ctx context.Context) error { - status, err := p.apiStatus(ctx) - if err != nil { - return tracerr.Wrap(err) - } - if !status.Status.IsInitialized() { - err := errors.New("Proof Server is not initialized") - return err - } - if status.Status.IsReady() { - return nil - } for { + status, err := p.apiStatus(ctx) + if err != nil { + return tracerr.Wrap(err) + } + if !status.Status.IsInitialized() { + return fmt.Errorf("Proof Server is not initialized") + } + if status.Status.IsReady() { + return nil + } select { case <-ctx.Done(): return tracerr.Wrap(common.ErrDone) - case <-time.After(p.timeCons): - status, err := p.apiStatus(ctx) - if err != nil { - return tracerr.Wrap(err) - } - if status.Status.IsReady() { - return nil - } + case <-time.After(p.pollInterval): } } } diff --git a/prover/prover_test.go b/prover/prover_test.go index b83b766..77378c6 100644 --- a/prover/prover_test.go +++ b/prover/prover_test.go @@ -13,14 +13,14 @@ import ( ) const apiURL = "http://localhost:3000/api" -const timeCons = 1 * time.Second +const pollInterval = 1 * time.Second var proofServerClient *ProofServerClient func TestMain(m *testing.M) { exitVal := 0 if os.Getenv("INTEGRATION") != "" { - proofServerClient = NewProofServerClient(apiURL, timeCons) + proofServerClient = NewProofServerClient(apiURL, pollInterval) err := proofServerClient.WaitReady(context.Background()) if err != nil { panic(err) @@ -73,7 +73,7 @@ func testCancel(t *testing.T) { status, err := proofServerClient.apiStatus(context.Background()) require.NoError(t, err) for status.Status == StatusCodeBusy { - time.Sleep(proofServerClient.timeCons) + time.Sleep(proofServerClient.pollInterval) status, err = proofServerClient.apiStatus(context.Background()) require.NoError(t, err) }