Browse Source

Merge pull request #333 from hermeznetwork/feature/client-prover

Update prover & add test
feature/sql-semaphore1
Eduard S 3 years ago
committed by GitHub
parent
commit
30c494b547
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 90 deletions
  1. +8
    -6
      config/config.go
  2. +1
    -1
      coordinator/coordinator.go
  3. +4
    -3
      node/node.go
  4. +28
    -0
      prover/README.md
  5. +87
    -80
      prover/prover.go
  6. +81
    -0
      prover/prover_test.go

+ 8
- 6
config/config.go

@ -36,15 +36,17 @@ type ServerProof struct {
// Coordinator is the coordinator specific configuration. // Coordinator is the coordinator specific configuration.
type Coordinator struct { type Coordinator struct {
// ForgerAddress is the address under which this coordinator is forging // ForgerAddress is the address under which this coordinator is forging
ForgerAddress ethCommon.Address `validate:"required"`
ForgeLoopInterval Duration `validate:"required"`
ForgerAddress ethCommon.Address `validate:"required"`
// ConfirmBlocks is the number of confirmation blocks to wait for sent // ConfirmBlocks is the number of confirmation blocks to wait for sent
// ethereum transactions before forgetting about them // ethereum transactions before forgetting about them
ConfirmBlocks int64 `validate:"required"` ConfirmBlocks int64 `validate:"required"`
// L1BatchTimeoutPerc is the portion of the range before the L1Batch // L1BatchTimeoutPerc is the portion of the range before the L1Batch
// timeout that will trigger a schedule to forge an L1Batch // timeout that will trigger a schedule to forge an L1Batch
L1BatchTimeoutPerc float64 `validate:"required"` L1BatchTimeoutPerc float64 `validate:"required"`
L2DB struct {
// ProofServerPollInterval is the waiting interval between polling the
// ProofServer while waiting for a particular status
ProofServerPollInterval Duration `validate:"required"`
L2DB struct {
SafetyPeriod common.BatchNum `validate:"required"` SafetyPeriod common.BatchNum `validate:"required"`
MaxTxs uint32 `validate:"required"` MaxTxs uint32 `validate:"required"`
TTL Duration `validate:"required"` TTL Duration `validate:"required"`
@ -69,10 +71,10 @@ type Coordinator struct {
DeployGasLimit uint64 `validate:"required"` DeployGasLimit uint64 `validate:"required"`
GasPriceDiv uint64 `validate:"required"` GasPriceDiv uint64 `validate:"required"`
ReceiptTimeout Duration `validate:"required"` ReceiptTimeout Duration `validate:"required"`
IntervalReceiptLoop Duration `validate:"required"`
// IntervalCheckLoop is the waiting interval between receipt
ReceiptLoopInterval Duration `validate:"required"`
// CheckLoopInterval is the waiting interval between receipt
// checks of ethereum transactions in the TxManager // checks of ethereum transactions in the TxManager
IntervalCheckLoop Duration `validate:"required"`
CheckLoopInterval Duration `validate:"required"`
// Attempts is the number of attempts to do an eth client RPC // Attempts is the number of attempts to do an eth client RPC
// call before giving up // call before giving up
Attempts int `validate:"required"` Attempts int `validate:"required"`

+ 1
- 1
coordinator/coordinator.go

@ -807,7 +807,7 @@ func (p *Pipeline) forgeSendServerProof(ctx context.Context, batchNum common.Bat
// 7. Call the selected idle server proof with BatchBuilder output, // 7. Call the selected idle server proof with BatchBuilder output,
// save server proof info for batchNum // save server proof info for batchNum
err = batchInfo.ServerProof.CalculateProof(zkInputs)
err = batchInfo.ServerProof.CalculateProof(ctx, zkInputs)
if err != nil { if err != nil {
return nil, tracerr.Wrap(err) return nil, tracerr.Wrap(err)
} }

+ 4
- 3
node/node.go

@ -96,7 +96,7 @@ func NewNode(mode Mode, cfg *config.Node, coordCfg *config.Coordinator) (*Node,
DeployGasLimit: coordCfg.EthClient.DeployGasLimit, DeployGasLimit: coordCfg.EthClient.DeployGasLimit,
GasPriceDiv: coordCfg.EthClient.GasPriceDiv, GasPriceDiv: coordCfg.EthClient.GasPriceDiv,
ReceiptTimeout: coordCfg.EthClient.ReceiptTimeout.Duration, ReceiptTimeout: coordCfg.EthClient.ReceiptTimeout.Duration,
IntervalReceiptLoop: coordCfg.EthClient.IntervalReceiptLoop.Duration,
IntervalReceiptLoop: coordCfg.EthClient.ReceiptLoopInterval.Duration,
} }
} }
client, err := eth.NewClient(ethClient, nil, nil, &eth.ClientConfig{ client, err := eth.NewClient(ethClient, nil, nil, &eth.ClientConfig{
@ -165,7 +165,8 @@ func NewNode(mode Mode, cfg *config.Node, coordCfg *config.Coordinator) (*Node,
} }
serverProofs := make([]prover.Client, len(coordCfg.ServerProofs)) serverProofs := make([]prover.Client, len(coordCfg.ServerProofs))
for i, serverProofCfg := range coordCfg.ServerProofs { for i, serverProofCfg := range coordCfg.ServerProofs {
serverProofs[i] = prover.NewProofServerClient(serverProofCfg.URL)
serverProofs[i] = prover.NewProofServerClient(serverProofCfg.URL,
coordCfg.ProofServerPollInterval.Duration)
} }
coord, err = coordinator.NewCoordinator( coord, err = coordinator.NewCoordinator(
@ -175,7 +176,7 @@ func NewNode(mode Mode, cfg *config.Node, coordCfg *config.Coordinator) (*Node,
L1BatchTimeoutPerc: coordCfg.L1BatchTimeoutPerc, L1BatchTimeoutPerc: coordCfg.L1BatchTimeoutPerc,
EthClientAttempts: coordCfg.EthClient.Attempts, EthClientAttempts: coordCfg.EthClient.Attempts,
EthClientAttemptsDelay: coordCfg.EthClient.AttemptsDelay.Duration, EthClientAttemptsDelay: coordCfg.EthClient.AttemptsDelay.Duration,
TxManagerCheckInterval: coordCfg.EthClient.IntervalCheckLoop.Duration,
TxManagerCheckInterval: coordCfg.EthClient.CheckLoopInterval.Duration,
DebugBatchPath: coordCfg.Debug.BatchPath, DebugBatchPath: coordCfg.Debug.BatchPath,
Purger: coordinator.PurgerCfg{ Purger: coordinator.PurgerCfg{
PurgeBatchDelay: coordCfg.L2DB.PurgeBatchDelay, PurgeBatchDelay: coordCfg.L2DB.PurgeBatchDelay,

+ 28
- 0
prover/README.md

@ -0,0 +1,28 @@
## Test Prover
### Server Proof API
It is necessary to have a docker with server locally.
The instructions in the following link can be followed:
https://github.com/hermeznetwork/test-info/tree/main/cli-prover
> It is necessary to consult the pre-requirements to follow the steps of the next summary
A summary of the steps to follow to run docker would be:
- Clone the repository: https://github.com/hermeznetwork/test-info
- `cd cli-prover`
- `./cli-prover.sh -s localhost -v ~/prover_data -r 22`
- To enter docker: `docker exec -ti docker_cusnarks bash`
- Inside the docker: `cd cusnarks; make docker_all FORCE_CPU=1`
- Inside the docker: `cd config; python3 cusnarks_config.py 22 BN256`
- To exit docker: `exit`
- Now, the server API can be used. Helper can be consulted with: `./cli-prover.sh -h`
- Is necessary to initialize the server with: `./cli-prover.sh --post-start <session>`
- When `./cli-prover.sh --get-status <session>` is `ready` can be run the test.
> The session can be consulted with `tmux ls`. The session will be the number of the last session on the list.
### Test
`INTEGRATION=1 go test`

+ 87
- 80
prover/prover.go

@ -1,30 +1,67 @@
package prover package prover
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"mime/multipart"
"math/big"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/dghubble/sling" "github.com/dghubble/sling"
"github.com/hermeznetwork/hermez-node/common" "github.com/hermeznetwork/hermez-node/common"
"github.com/hermeznetwork/hermez-node/log"
"github.com/hermeznetwork/tracerr" "github.com/hermeznetwork/tracerr"
) )
// Proof TBD this type will be received from the proof server // Proof TBD this type will be received from the proof server
type Proof struct { type Proof struct {
PiA [2]*big.Int `json:"pi_a"`
PiB [3][2]*big.Int `json:"pi_b"`
PiC [2]*big.Int `json:"pi_c"`
Protocol string `json:"protocol"`
}
type bigInt big.Int
func (b *bigInt) UnmarshalText(text []byte) error {
_, ok := (*big.Int)(b).SetString(string(text), 10)
if !ok {
return fmt.Errorf("invalid big int: \"%v\"", string(text))
}
return nil
}
// UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
// ints as strings
func (p *Proof) UnmarshalJSON(data []byte) error {
proof := struct {
PiA [2]*bigInt `json:"pi_a"`
PiB [3][2]*bigInt `json:"pi_b"`
PiC [2]*bigInt `json:"pi_c"`
Protocol string `json:"protocol"`
}{}
if err := json.Unmarshal(data, &proof); err != nil {
return err
}
p.PiA[0] = (*big.Int)(proof.PiA[0])
p.PiA[1] = (*big.Int)(proof.PiA[1])
p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
p.PiC[0] = (*big.Int)(proof.PiC[0])
p.PiC[1] = (*big.Int)(proof.PiC[1])
p.Protocol = proof.Protocol
return nil
} }
// Client is the interface to a ServerProof that calculates zk proofs // Client is the interface to a ServerProof that calculates zk proofs
type Client interface { type Client interface {
// Non-blocking // Non-blocking
CalculateProof(zkInputs *common.ZKInputs) error
CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
// Blocking // Blocking
GetProof(ctx context.Context) (*Proof, error) GetProof(ctx context.Context) (*Proof, error)
// Non-Blocking // Non-Blocking
@ -105,60 +142,24 @@ const (
GET apiMethod = "GET" GET apiMethod = "GET"
// POST is an HTTP POST with maybe JSON body // POST is an HTTP POST with maybe JSON body
POST apiMethod = "POST" POST apiMethod = "POST"
// POSTFILE is an HTTP POST with a form file
POSTFILE apiMethod = "POSTFILE"
) )
// ProofServerClient contains the data related to a ProofServerClient // ProofServerClient contains the data related to a ProofServerClient
type ProofServerClient struct { type ProofServerClient struct {
URL string
client *sling.Sling
URL string
client *sling.Sling
pollInterval time.Duration
} }
// NewProofServerClient creates a new ServerProof // NewProofServerClient creates a new ServerProof
func NewProofServerClient(URL string) *ProofServerClient {
func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
if URL[len(URL)-1] != '/' { if URL[len(URL)-1] != '/' {
URL += "/" URL += "/"
} }
client := sling.New().Base(URL) client := sling.New().Base(URL)
return &ProofServerClient{URL: URL, client: client}
}
//nolint:unused
type formFileProvider struct {
writer *multipart.Writer
body []byte
}
//nolint:unused
func newFormFileProvider(payload interface{}) (*formFileProvider, error) {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", "file.json")
if err != nil {
return nil, tracerr.Wrap(err)
}
if err := json.NewEncoder(part).Encode(payload); err != nil {
return nil, tracerr.Wrap(err)
}
if err := writer.Close(); err != nil {
return nil, tracerr.Wrap(err)
}
return &formFileProvider{
writer: writer,
body: body.Bytes(),
}, nil
return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
} }
func (p formFileProvider) ContentType() string {
return p.writer.FormDataContentType()
}
func (p formFileProvider) Body() (io.Reader, error) {
return bytes.NewReader(p.body), nil
}
//nolint:unused
func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string, func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
body interface{}, ret interface{}) error { body interface{}, ret interface{}) error {
path = strings.TrimPrefix(path, "/") path = strings.TrimPrefix(path, "/")
@ -170,15 +171,6 @@ func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, pa
req, err = p.client.New().Get(path).Request() req, err = p.client.New().Get(path).Request()
case POST: case POST:
req, err = p.client.New().Post(path).BodyJSON(body).Request() req, err = p.client.New().Post(path).BodyJSON(body).Request()
case POSTFILE:
provider, err := newFormFileProvider(body)
if err != nil {
return tracerr.Wrap(err)
}
req, err = p.client.New().Post(path).BodyProvider(provider).Request()
if err != nil {
return tracerr.Wrap(err)
}
default: default:
return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method)) return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
} }
@ -196,55 +188,70 @@ func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, pa
return nil return nil
} }
//nolint:unused
func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) { func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
var status Status var status Status
if err := p.apiRequest(ctx, GET, "/status", nil, &status); err != nil {
return nil, tracerr.Wrap(err)
}
return &status, nil
return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
} }
//nolint:unused
func (p *ProofServerClient) apiCancel(ctx context.Context) error { func (p *ProofServerClient) apiCancel(ctx context.Context) error {
if err := p.apiRequest(ctx, POST, "/cancel", nil, nil); err != nil {
return tracerr.Wrap(err)
}
return nil
return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
} }
//nolint:unused
func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error { func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
if err := p.apiRequest(ctx, POSTFILE, "/input", zkInputs, nil); err != nil {
return tracerr.Wrap(err)
}
return nil
return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
} }
// CalculateProof sends the *common.ZKInputs to the ServerProof to compute the // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
// Proof // Proof
func (p *ProofServerClient) CalculateProof(zkInputs *common.ZKInputs) error {
log.Error("TODO")
return tracerr.Wrap(common.ErrTODO)
func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
return tracerr.Wrap(p.apiInput(ctx, zkInputs))
} }
// GetProof retreives the Proof from the ServerProof, blocking until the proof // GetProof retreives the Proof from the ServerProof, blocking until the proof
// is ready. // is ready.
func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) { func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) {
log.Error("TODO")
return nil, tracerr.Wrap(common.ErrTODO)
if err := p.WaitReady(ctx); err != nil {
return nil, err
}
status, err := p.apiStatus(ctx)
if err != nil {
return nil, tracerr.Wrap(err)
}
if status.Status == StatusCodeSuccess {
var proof Proof
err := json.Unmarshal([]byte(status.Proof), &proof)
if err != nil {
return nil, tracerr.Wrap(err)
}
return &proof, nil
}
return nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status)
} }
// Cancel cancels any current proof computation // Cancel cancels any current proof computation
func (p *ProofServerClient) Cancel(ctx context.Context) error { func (p *ProofServerClient) Cancel(ctx context.Context) error {
log.Error("TODO")
return tracerr.Wrap(common.ErrTODO)
return tracerr.Wrap(p.apiCancel(ctx))
} }
// WaitReady waits until the serverProof is ready // WaitReady waits until the serverProof is ready
func (p *ProofServerClient) WaitReady(ctx context.Context) error { func (p *ProofServerClient) WaitReady(ctx context.Context) error {
log.Error("TODO")
return tracerr.Wrap(common.ErrTODO)
for {
status, err := p.apiStatus(ctx)
if err != nil {
return tracerr.Wrap(err)
}
if !status.Status.IsInitialized() {
return fmt.Errorf("Proof Server is not initialized")
}
if status.Status.IsReady() {
return nil
}
select {
case <-ctx.Done():
return tracerr.Wrap(common.ErrDone)
case <-time.After(p.pollInterval):
}
}
} }
// MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
@ -253,7 +260,7 @@ type MockClient struct {
// CalculateProof sends the *common.ZKInputs to the ServerProof to compute the // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
// Proof // Proof
func (p *MockClient) CalculateProof(zkInputs *common.ZKInputs) error {
func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
return nil return nil
} }

+ 81
- 0
prover/prover_test.go

@ -0,0 +1,81 @@
package prover
import (
"context"
"math/big"
"os"
"testing"
"time"
"github.com/hermeznetwork/hermez-node/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const apiURL = "http://localhost:3000/api"
const pollInterval = 1 * time.Second
var proofServerClient *ProofServerClient
func TestMain(m *testing.M) {
exitVal := 0
if os.Getenv("INTEGRATION") != "" {
proofServerClient = NewProofServerClient(apiURL, pollInterval)
err := proofServerClient.WaitReady(context.Background())
if err != nil {
panic(err)
}
exitVal = m.Run()
}
os.Exit(exitVal)
}
func TestApiServer(t *testing.T) {
t.Run("testAPIStatus", testAPIStatus)
t.Run("testCalculateProof", testCalculateProof)
time.Sleep(time.Second / 4)
err := proofServerClient.WaitReady(context.Background())
require.NoError(t, err)
t.Run("testGetProof", testGetProof)
t.Run("testCancel", testCancel)
}
func testAPIStatus(t *testing.T) {
status, err := proofServerClient.apiStatus(context.Background())
require.NoError(t, err)
assert.Equal(t, true, status.Status.IsReady())
}
func testCalculateProof(t *testing.T) {
zkInputs := common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1))
err := proofServerClient.CalculateProof(context.Background(), zkInputs)
require.NoError(t, err)
}
func testGetProof(t *testing.T) {
proof, err := proofServerClient.GetProof(context.Background())
require.NoError(t, err)
require.NotNil(t, proof)
require.NotNil(t, proof.PiA)
require.NotNil(t, proof.PiB)
require.NotNil(t, proof.PiC)
require.NotNil(t, proof.Protocol)
}
func testCancel(t *testing.T) {
zkInputs := common.NewZKInputs(100, 16, 512, 24, 32, big.NewInt(1))
err := proofServerClient.CalculateProof(context.Background(), zkInputs)
require.NoError(t, err)
// TODO: remove sleep when the server has been reviewed
time.Sleep(time.Second / 4)
err = proofServerClient.Cancel(context.Background())
require.NoError(t, err)
status, err := proofServerClient.apiStatus(context.Background())
require.NoError(t, err)
for status.Status == StatusCodeBusy {
time.Sleep(proofServerClient.pollInterval)
status, err = proofServerClient.apiStatus(context.Background())
require.NoError(t, err)
}
assert.Equal(t, StatusCodeAborted, status.Status)
}

Loading…
Cancel
Save