From 885f584fd222d667c854a5c31987cf9d094f2737 Mon Sep 17 00:00:00 2001 From: Eduard S Date: Wed, 16 Dec 2020 13:16:24 +0100 Subject: [PATCH] Add mock proof server --- prover/prover.go | 36 ++++-- prover/prover_test.go | 15 ++- test/proofserver/cli/main.go | 47 ++++++++ test/proofserver/proofserver.go | 191 ++++++++++++++++++++++++++++++++ 4 files changed, 275 insertions(+), 14 deletions(-) create mode 100644 test/proofserver/cli/main.go create mode 100644 test/proofserver/proofserver.go diff --git a/prover/prover.go b/prover/prover.go index e360f86..09d6324 100644 --- a/prover/prover.go +++ b/prover/prover.go @@ -58,11 +58,28 @@ func (p *Proof) UnmarshalJSON(data []byte) error { return nil } +// PublicInputs are the public inputs of the proof +type PublicInputs []*big.Int + +// UnmarshalJSON unmarshals the JSON into the public inputs where the bigInts +// are in decimal as quoted strings +func (p *PublicInputs) UnmarshalJSON(data []byte) error { + pubInputs := []*bigInt{} + if err := json.Unmarshal(data, &pubInputs); err != nil { + return err + } + *p = make([]*big.Int, len(pubInputs)) + for i, v := range pubInputs { + ([]*big.Int)(*p)[i] = (*big.Int)(v) + } + return nil +} + // Client is the interface to a ServerProof that calculates zk proofs type Client interface { // Non-blocking CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error - // Blocking + // Blocking. Returns the Proof and Public Data (public inputs) GetProof(ctx context.Context) (*Proof, []*big.Int, error) // Non-Blocking Cancel(ctx context.Context) error @@ -207,11 +224,11 @@ func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common return tracerr.Wrap(p.apiInput(ctx, zkInputs)) } -// GetProof retreives the Proof from the ServerProof, blocking until the proof -// is ready. +// GetProof retreives the Proof and Public Data (public inputs) from the +// ServerProof, blocking until the proof is ready. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) { if err := p.WaitReady(ctx); err != nil { - return nil, nil, err + return nil, nil, tracerr.Wrap(err) } status, err := p.apiStatus(ctx) if err != nil { @@ -219,11 +236,14 @@ func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, e } if status.Status == StatusCodeSuccess { var proof Proof - err := json.Unmarshal([]byte(status.Proof), &proof) - if err != nil { + if err := json.Unmarshal([]byte(status.Proof), &proof); err != nil { + return nil, nil, tracerr.Wrap(err) + } + var pubInputs PublicInputs + if err := json.Unmarshal([]byte(status.PubData), &pubInputs); err != nil { return nil, nil, tracerr.Wrap(err) } - return &proof, nil, nil + return &proof, pubInputs, nil } return nil, nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status) } @@ -269,7 +289,7 @@ func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) { // Simulate a delay select { case <-time.After(500 * time.Millisecond): //nolint:gomnd - return &Proof{}, nil, nil + return &Proof{}, []*big.Int{big.NewInt(1234)}, nil //nolint:gomnd case <-ctx.Done(): return nil, nil, tracerr.Wrap(common.ErrDone) } diff --git a/prover/prover_test.go b/prover/prover_test.go index a3eb7e0..99e7ba9 100644 --- a/prover/prover_test.go +++ b/prover/prover_test.go @@ -53,13 +53,16 @@ func testCalculateProof(t *testing.T) { } func testGetProof(t *testing.T) { - proof, _, err := proofServerClient.GetProof(context.Background()) + proof, pubInputs, err := proofServerClient.GetProof(context.Background()) require.NoError(t, err) - require.NotNil(t, proof) - require.NotNil(t, proof.PiA) - require.NotNil(t, proof.PiB) - require.NotNil(t, proof.PiC) - require.NotNil(t, proof.Protocol) + assert.NotNil(t, proof.PiA) + assert.NotEqual(t, [2]*big.Int{}, proof.PiA) + assert.NotNil(t, proof.PiB) + assert.NotEqual(t, [3][2]*big.Int{}, proof.PiB) + assert.NotNil(t, proof.PiC) + assert.NotEqual(t, [2]*big.Int{}, proof.PiC) + assert.NotNil(t, proof.Protocol) + assert.NotEqual(t, 0, len(pubInputs)) } func testCancel(t *testing.T) { diff --git a/test/proofserver/cli/main.go b/test/proofserver/cli/main.go new file mode 100644 index 0000000..e1013b5 --- /dev/null +++ b/test/proofserver/cli/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + "os/signal" + "sync" + "time" + + "github.com/hermeznetwork/hermez-node/test/proofserver" +) + +func main() { + var addr string + flag.StringVar(&addr, "a", "localhost:3000", "listen address") + var provingDuration time.Duration + flag.DurationVar(&provingDuration, "d", 2*time.Second, "proving time duration") //nolint:gomnd + flag.Parse() + + mock := proofserver.NewMock(addr, provingDuration) + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + if err := mock.Run(ctx); err != nil { + log.Fatal(err) + } + wg.Done() + }() + + stopCh := make(chan interface{}) + // catch ^C to send the stop signal + ossig := make(chan os.Signal, 1) + signal.Notify(ossig, os.Interrupt) + go func() { + for sig := range ossig { + if sig == os.Interrupt { + stopCh <- nil + } + } + }() + <-stopCh + cancel() + wg.Wait() +} diff --git a/test/proofserver/proofserver.go b/test/proofserver/proofserver.go new file mode 100644 index 0000000..2586049 --- /dev/null +++ b/test/proofserver/proofserver.go @@ -0,0 +1,191 @@ +package proofserver + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "sync" + "time" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" + "github.com/hermeznetwork/hermez-node/log" + "github.com/hermeznetwork/hermez-node/prover" + "github.com/hermeznetwork/tracerr" +) + +type msg struct { + value string + ackCh chan bool +} + +func newMsg(value string) msg { + return msg{ + value: value, + ackCh: make(chan bool), + } +} + +// Mock proof server +type Mock struct { + addr string + status prover.StatusCode + sync.RWMutex + proof string + pubData string + counter int + msgCh chan msg + wg sync.WaitGroup + provingDuration time.Duration +} + +// NewMock creates a new mock server +func NewMock(addr string, provingDuration time.Duration) *Mock { + return &Mock{ + addr: addr, + status: prover.StatusCodeReady, + proof: "", + pubData: "", + counter: 0, + msgCh: make(chan msg), + provingDuration: provingDuration, + } +} + +func (s *Mock) err(c *gin.Context, err error) { + c.JSON(http.StatusInternalServerError, prover.ErrorServer{ + Status: "error", + Message: err.Error(), + }) +} + +func (s *Mock) handleCancel(c *gin.Context) { + msg := newMsg("cancel") + s.msgCh <- msg + <-msg.ackCh + c.JSON(http.StatusOK, "OK") +} + +func (s *Mock) handleStatus(c *gin.Context) { + s.RLock() + c.JSON(http.StatusOK, prover.Status{ + Status: s.status, + Proof: s.proof, + PubData: s.pubData, + }) + s.RUnlock() +} + +func (s *Mock) handleInput(c *gin.Context) { + s.RLock() + if !s.status.IsReady() { + s.err(c, fmt.Errorf("not ready")) + s.RUnlock() + return + } + s.RUnlock() + _, err := ioutil.ReadAll(c.Request.Body) + if err != nil { + s.err(c, err) + return + } + msg := newMsg("prove") + s.msgCh <- msg + <-msg.ackCh + c.JSON(http.StatusOK, "OK") +} + +const longWaitDuration = 999 * time.Hour + +// const provingDuration = 2 * time.Second + +func (s *Mock) runProver(ctx context.Context) { + waitDuration := longWaitDuration + for { + select { + case <-ctx.Done(): + return + case msg := <-s.msgCh: + switch msg.value { + case "cancel": + waitDuration = longWaitDuration + s.Lock() + if !s.status.IsReady() { + s.status = prover.StatusCodeAborted + } + s.Unlock() + case "prove": + waitDuration = s.provingDuration + s.Lock() + s.status = prover.StatusCodeBusy + s.Unlock() + } + msg.ackCh <- true + case <-time.After(waitDuration): + waitDuration = longWaitDuration + s.Lock() + if s.status != prover.StatusCodeBusy { + s.Unlock() + continue + } + i := s.counter * 100 //nolint:gomnd + s.counter++ + // Mock data + s.proof = fmt.Sprintf(`{ + "pi_a": ["%v", "%v"], + "pi_b": [["%v", "%v"],["%v", "%v"],["%v", "%v"]], + "pi_c": ["%v", "%v"], + "protocol": "groth16" + }`, i, i+1, i+2, i+3, i+4, i+5, i+6, i+7, i+8, i+9) //nolint:gomnd + s.pubData = fmt.Sprintf(`[ + "%v" + ]`, i+42) //nolint:gomnd + s.status = prover.StatusCodeSuccess + s.Unlock() + } + } +} + +// Run the mock server. Use ctx to stop it via cancel +func (s *Mock) Run(ctx context.Context) error { + api := gin.Default() + api.Use(cors.Default()) + + apiGroup := api.Group("/api") + apiGroup.GET("/status", s.handleStatus) + apiGroup.POST("/input", s.handleInput) + apiGroup.POST("/cancel", s.handleCancel) + + debugAPIServer := &http.Server{ + Addr: s.addr, + Handler: api, + // Use some hardcoded numberes that are suitable for testing + ReadTimeout: 30 * time.Second, //nolint:gomnd + WriteTimeout: 30 * time.Second, //nolint:gomnd + MaxHeaderBytes: 1 << 20, //nolint:gomnd + } + go func() { + log.Infof("prover.MockServer is ready at %v", s.addr) + if err := debugAPIServer.ListenAndServe(); err != nil && tracerr.Unwrap(err) != http.ErrServerClosed { + log.Fatalf("Listen: %s\n", err) + } + }() + s.wg.Add(1) + go func() { + s.runProver(ctx) + s.wg.Done() + }() + + <-ctx.Done() + log.Info("Stopping prover.MockServer...") + + s.wg.Wait() + ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) //nolint:gomnd + defer cancel() + if err := debugAPIServer.Shutdown(ctxTimeout); err != nil { + return tracerr.Wrap(err) + } + log.Info("prover.MockServer done") + return nil +}