Browse Source

Add mock proof server

feature/sql-semaphore1
Eduard S 3 years ago
parent
commit
885f584fd2
4 changed files with 275 additions and 14 deletions
  1. +28
    -8
      prover/prover.go
  2. +9
    -6
      prover/prover_test.go
  3. +47
    -0
      test/proofserver/cli/main.go
  4. +191
    -0
      test/proofserver/proofserver.go

+ 28
- 8
prover/prover.go

@ -58,11 +58,28 @@ func (p *Proof) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// PublicInputs are the public inputs of the proof
type PublicInputs []*big.Int
// UnmarshalJSON unmarshals the JSON into the public inputs where the bigInts
// are in decimal as quoted strings
func (p *PublicInputs) UnmarshalJSON(data []byte) error {
pubInputs := []*bigInt{}
if err := json.Unmarshal(data, &pubInputs); err != nil {
return err
}
*p = make([]*big.Int, len(pubInputs))
for i, v := range pubInputs {
([]*big.Int)(*p)[i] = (*big.Int)(v)
}
return nil
}
// Client is the interface to a ServerProof that calculates zk proofs // Client is the interface to a ServerProof that calculates zk proofs
type Client interface { type Client interface {
// Non-blocking // Non-blocking
CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
// Blocking
// Blocking. Returns the Proof and Public Data (public inputs)
GetProof(ctx context.Context) (*Proof, []*big.Int, error) GetProof(ctx context.Context) (*Proof, []*big.Int, error)
// Non-Blocking // Non-Blocking
Cancel(ctx context.Context) error Cancel(ctx context.Context) error
@ -207,11 +224,11 @@ func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common
return tracerr.Wrap(p.apiInput(ctx, zkInputs)) return tracerr.Wrap(p.apiInput(ctx, zkInputs))
} }
// GetProof retreives the Proof from the ServerProof, blocking until the proof
// is ready.
// GetProof retreives the Proof and Public Data (public inputs) from the
// ServerProof, blocking until the proof is ready.
func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) { func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
if err := p.WaitReady(ctx); err != nil { if err := p.WaitReady(ctx); err != nil {
return nil, nil, err
return nil, nil, tracerr.Wrap(err)
} }
status, err := p.apiStatus(ctx) status, err := p.apiStatus(ctx)
if err != nil { if err != nil {
@ -219,11 +236,14 @@ func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, e
} }
if status.Status == StatusCodeSuccess { if status.Status == StatusCodeSuccess {
var proof Proof var proof Proof
err := json.Unmarshal([]byte(status.Proof), &proof)
if err != nil {
if err := json.Unmarshal([]byte(status.Proof), &proof); err != nil {
return nil, nil, tracerr.Wrap(err)
}
var pubInputs PublicInputs
if err := json.Unmarshal([]byte(status.PubData), &pubInputs); err != nil {
return nil, nil, tracerr.Wrap(err) return nil, nil, tracerr.Wrap(err)
} }
return &proof, nil, nil
return &proof, pubInputs, nil
} }
return nil, nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status) return nil, nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status)
} }
@ -269,7 +289,7 @@ func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
// Simulate a delay // Simulate a delay
select { select {
case <-time.After(500 * time.Millisecond): //nolint:gomnd case <-time.After(500 * time.Millisecond): //nolint:gomnd
return &Proof{}, nil, nil
return &Proof{}, []*big.Int{big.NewInt(1234)}, nil //nolint:gomnd
case <-ctx.Done(): case <-ctx.Done():
return nil, nil, tracerr.Wrap(common.ErrDone) return nil, nil, tracerr.Wrap(common.ErrDone)
} }

+ 9
- 6
prover/prover_test.go

@ -53,13 +53,16 @@ func testCalculateProof(t *testing.T) {
} }
func testGetProof(t *testing.T) { func testGetProof(t *testing.T) {
proof, _, err := proofServerClient.GetProof(context.Background())
proof, pubInputs, err := proofServerClient.GetProof(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, proof)
require.NotNil(t, proof.PiA)
require.NotNil(t, proof.PiB)
require.NotNil(t, proof.PiC)
require.NotNil(t, proof.Protocol)
assert.NotNil(t, proof.PiA)
assert.NotEqual(t, [2]*big.Int{}, proof.PiA)
assert.NotNil(t, proof.PiB)
assert.NotEqual(t, [3][2]*big.Int{}, proof.PiB)
assert.NotNil(t, proof.PiC)
assert.NotEqual(t, [2]*big.Int{}, proof.PiC)
assert.NotNil(t, proof.Protocol)
assert.NotEqual(t, 0, len(pubInputs))
} }
func testCancel(t *testing.T) { func testCancel(t *testing.T) {

+ 47
- 0
test/proofserver/cli/main.go

@ -0,0 +1,47 @@
package main
import (
"context"
"flag"
"log"
"os"
"os/signal"
"sync"
"time"
"github.com/hermeznetwork/hermez-node/test/proofserver"
)
func main() {
var addr string
flag.StringVar(&addr, "a", "localhost:3000", "listen address")
var provingDuration time.Duration
flag.DurationVar(&provingDuration, "d", 2*time.Second, "proving time duration") //nolint:gomnd
flag.Parse()
mock := proofserver.NewMock(addr, provingDuration)
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
if err := mock.Run(ctx); err != nil {
log.Fatal(err)
}
wg.Done()
}()
stopCh := make(chan interface{})
// catch ^C to send the stop signal
ossig := make(chan os.Signal, 1)
signal.Notify(ossig, os.Interrupt)
go func() {
for sig := range ossig {
if sig == os.Interrupt {
stopCh <- nil
}
}
}()
<-stopCh
cancel()
wg.Wait()
}

+ 191
- 0
test/proofserver/proofserver.go

@ -0,0 +1,191 @@
package proofserver
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/hermeznetwork/hermez-node/log"
"github.com/hermeznetwork/hermez-node/prover"
"github.com/hermeznetwork/tracerr"
)
type msg struct {
value string
ackCh chan bool
}
func newMsg(value string) msg {
return msg{
value: value,
ackCh: make(chan bool),
}
}
// Mock proof server
type Mock struct {
addr string
status prover.StatusCode
sync.RWMutex
proof string
pubData string
counter int
msgCh chan msg
wg sync.WaitGroup
provingDuration time.Duration
}
// NewMock creates a new mock server
func NewMock(addr string, provingDuration time.Duration) *Mock {
return &Mock{
addr: addr,
status: prover.StatusCodeReady,
proof: "",
pubData: "",
counter: 0,
msgCh: make(chan msg),
provingDuration: provingDuration,
}
}
func (s *Mock) err(c *gin.Context, err error) {
c.JSON(http.StatusInternalServerError, prover.ErrorServer{
Status: "error",
Message: err.Error(),
})
}
func (s *Mock) handleCancel(c *gin.Context) {
msg := newMsg("cancel")
s.msgCh <- msg
<-msg.ackCh
c.JSON(http.StatusOK, "OK")
}
func (s *Mock) handleStatus(c *gin.Context) {
s.RLock()
c.JSON(http.StatusOK, prover.Status{
Status: s.status,
Proof: s.proof,
PubData: s.pubData,
})
s.RUnlock()
}
func (s *Mock) handleInput(c *gin.Context) {
s.RLock()
if !s.status.IsReady() {
s.err(c, fmt.Errorf("not ready"))
s.RUnlock()
return
}
s.RUnlock()
_, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
s.err(c, err)
return
}
msg := newMsg("prove")
s.msgCh <- msg
<-msg.ackCh
c.JSON(http.StatusOK, "OK")
}
const longWaitDuration = 999 * time.Hour
// const provingDuration = 2 * time.Second
func (s *Mock) runProver(ctx context.Context) {
waitDuration := longWaitDuration
for {
select {
case <-ctx.Done():
return
case msg := <-s.msgCh:
switch msg.value {
case "cancel":
waitDuration = longWaitDuration
s.Lock()
if !s.status.IsReady() {
s.status = prover.StatusCodeAborted
}
s.Unlock()
case "prove":
waitDuration = s.provingDuration
s.Lock()
s.status = prover.StatusCodeBusy
s.Unlock()
}
msg.ackCh <- true
case <-time.After(waitDuration):
waitDuration = longWaitDuration
s.Lock()
if s.status != prover.StatusCodeBusy {
s.Unlock()
continue
}
i := s.counter * 100 //nolint:gomnd
s.counter++
// Mock data
s.proof = fmt.Sprintf(`{
"pi_a": ["%v", "%v"],
"pi_b": [["%v", "%v"],["%v", "%v"],["%v", "%v"]],
"pi_c": ["%v", "%v"],
"protocol": "groth16"
}`, i, i+1, i+2, i+3, i+4, i+5, i+6, i+7, i+8, i+9) //nolint:gomnd
s.pubData = fmt.Sprintf(`[
"%v"
]`, i+42) //nolint:gomnd
s.status = prover.StatusCodeSuccess
s.Unlock()
}
}
}
// Run the mock server. Use ctx to stop it via cancel
func (s *Mock) Run(ctx context.Context) error {
api := gin.Default()
api.Use(cors.Default())
apiGroup := api.Group("/api")
apiGroup.GET("/status", s.handleStatus)
apiGroup.POST("/input", s.handleInput)
apiGroup.POST("/cancel", s.handleCancel)
debugAPIServer := &http.Server{
Addr: s.addr,
Handler: api,
// Use some hardcoded numberes that are suitable for testing
ReadTimeout: 30 * time.Second, //nolint:gomnd
WriteTimeout: 30 * time.Second, //nolint:gomnd
MaxHeaderBytes: 1 << 20, //nolint:gomnd
}
go func() {
log.Infof("prover.MockServer is ready at %v", s.addr)
if err := debugAPIServer.ListenAndServe(); err != nil && tracerr.Unwrap(err) != http.ErrServerClosed {
log.Fatalf("Listen: %s\n", err)
}
}()
s.wg.Add(1)
go func() {
s.runProver(ctx)
s.wg.Done()
}()
<-ctx.Done()
log.Info("Stopping prover.MockServer...")
s.wg.Wait()
ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) //nolint:gomnd
defer cancel()
if err := debugAPIServer.Shutdown(ctxTimeout); err != nil {
return tracerr.Wrap(err)
}
log.Info("prover.MockServer done")
return nil
}

Loading…
Cancel
Save