You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

298 lines
8.4 KiB

  1. package prover
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "math/big"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/dghubble/sling"
  11. "github.com/hermeznetwork/hermez-node/common"
  12. "github.com/hermeznetwork/tracerr"
  13. )
  14. // Proof TBD this type will be received from the proof server
  15. type Proof struct {
  16. PiA [2]*big.Int `json:"pi_a"`
  17. PiB [3][2]*big.Int `json:"pi_b"`
  18. PiC [2]*big.Int `json:"pi_c"`
  19. Protocol string `json:"protocol"`
  20. }
  21. type bigInt big.Int
  22. func (b *bigInt) UnmarshalText(text []byte) error {
  23. _, ok := (*big.Int)(b).SetString(string(text), 10)
  24. if !ok {
  25. return fmt.Errorf("invalid big int: \"%v\"", string(text))
  26. }
  27. return nil
  28. }
  29. // UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
  30. // ints as strings
  31. func (p *Proof) UnmarshalJSON(data []byte) error {
  32. proof := struct {
  33. PiA [2]*bigInt `json:"pi_a"`
  34. PiB [3][2]*bigInt `json:"pi_b"`
  35. PiC [2]*bigInt `json:"pi_c"`
  36. Protocol string `json:"protocol"`
  37. }{}
  38. if err := json.Unmarshal(data, &proof); err != nil {
  39. return err
  40. }
  41. p.PiA[0] = (*big.Int)(proof.PiA[0])
  42. p.PiA[1] = (*big.Int)(proof.PiA[1])
  43. p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
  44. p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
  45. p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
  46. p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
  47. p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
  48. p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
  49. p.PiC[0] = (*big.Int)(proof.PiC[0])
  50. p.PiC[1] = (*big.Int)(proof.PiC[1])
  51. p.Protocol = proof.Protocol
  52. return nil
  53. }
  54. // Client is the interface to a ServerProof that calculates zk proofs
  55. type Client interface {
  56. // Non-blocking
  57. CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
  58. // Blocking
  59. GetProof(ctx context.Context) (*Proof, error)
  60. // Non-Blocking
  61. Cancel(ctx context.Context) error
  62. // Blocking
  63. WaitReady(ctx context.Context) error
  64. }
  65. // StatusCode is the status string of the ProofServer
  66. type StatusCode string
  67. const (
  68. // StatusCodeAborted means prover is ready to take new proof. Previous
  69. // proof was aborted.
  70. StatusCodeAborted StatusCode = "aborted"
  71. // StatusCodeBusy means prover is busy computing proof.
  72. StatusCodeBusy StatusCode = "busy"
  73. // StatusCodeFailed means prover is ready to take new proof. Previous
  74. // proof failed
  75. StatusCodeFailed StatusCode = "failed"
  76. // StatusCodeSuccess means prover is ready to take new proof. Previous
  77. // proof succeeded
  78. StatusCodeSuccess StatusCode = "success"
  79. // StatusCodeUnverified means prover is ready to take new proof.
  80. // Previous proof was unverified
  81. StatusCodeUnverified StatusCode = "unverified"
  82. // StatusCodeUninitialized means prover is not initialized
  83. StatusCodeUninitialized StatusCode = "uninitialized"
  84. // StatusCodeUndefined means prover is in an undefined state. Most
  85. // likely is booting up. Keep trying
  86. StatusCodeUndefined StatusCode = "undefined"
  87. // StatusCodeInitializing means prover is initializing and not ready yet
  88. StatusCodeInitializing StatusCode = "initializing"
  89. // StatusCodeReady means prover initialized and ready to do first proof
  90. StatusCodeReady StatusCode = "ready"
  91. )
  92. // IsReady returns true when the prover is ready
  93. func (status StatusCode) IsReady() bool {
  94. if status == StatusCodeAborted || status == StatusCodeFailed || status == StatusCodeSuccess ||
  95. status == StatusCodeUnverified || status == StatusCodeReady {
  96. return true
  97. }
  98. return false
  99. }
  100. // IsInitialized returns true when the prover is initialized
  101. func (status StatusCode) IsInitialized() bool {
  102. if status == StatusCodeUninitialized || status == StatusCodeUndefined ||
  103. status == StatusCodeInitializing {
  104. return false
  105. }
  106. return true
  107. }
  108. // Status is the return struct for the status API endpoint
  109. type Status struct {
  110. Status StatusCode `json:"status"`
  111. Proof string `json:"proof"`
  112. PubData string `json:"pubData"`
  113. }
  114. // ErrorServer is the return struct for an API error
  115. type ErrorServer struct {
  116. Status StatusCode `json:"status"`
  117. Message string `json:"msg"`
  118. }
  119. // Error message for ErrorServer
  120. func (e ErrorServer) Error() string {
  121. return fmt.Sprintf("server proof status (%v): %v", e.Status, e.Message)
  122. }
  123. type apiMethod string
  124. const (
  125. // GET is an HTTP GET
  126. GET apiMethod = "GET"
  127. // POST is an HTTP POST with maybe JSON body
  128. POST apiMethod = "POST"
  129. )
  130. // ProofServerClient contains the data related to a ProofServerClient
  131. type ProofServerClient struct {
  132. URL string
  133. client *sling.Sling
  134. pollInterval time.Duration
  135. }
  136. // NewProofServerClient creates a new ServerProof
  137. func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
  138. if URL[len(URL)-1] != '/' {
  139. URL += "/"
  140. }
  141. client := sling.New().Base(URL)
  142. return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
  143. }
  144. func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
  145. body interface{}, ret interface{}) error {
  146. path = strings.TrimPrefix(path, "/")
  147. var errSrv ErrorServer
  148. var req *http.Request
  149. var err error
  150. switch method {
  151. case GET:
  152. req, err = p.client.New().Get(path).Request()
  153. case POST:
  154. req, err = p.client.New().Post(path).BodyJSON(body).Request()
  155. default:
  156. return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
  157. }
  158. if err != nil {
  159. return tracerr.Wrap(err)
  160. }
  161. res, err := p.client.Do(req.WithContext(ctx), ret, &errSrv)
  162. if err != nil {
  163. return tracerr.Wrap(err)
  164. }
  165. defer res.Body.Close() //nolint:errcheck
  166. if !(200 <= res.StatusCode && res.StatusCode < 300) {
  167. return tracerr.Wrap(errSrv)
  168. }
  169. return nil
  170. }
  171. func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
  172. var status Status
  173. return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
  174. }
  175. func (p *ProofServerClient) apiCancel(ctx context.Context) error {
  176. return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
  177. }
  178. func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
  179. return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
  180. }
  181. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  182. // Proof
  183. func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  184. return tracerr.Wrap(p.apiInput(ctx, zkInputs))
  185. }
  186. // GetProof retreives the Proof from the ServerProof, blocking until the proof
  187. // is ready.
  188. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) {
  189. if err := p.WaitReady(ctx); err != nil {
  190. return nil, err
  191. }
  192. status, err := p.apiStatus(ctx)
  193. if err != nil {
  194. return nil, tracerr.Wrap(err)
  195. }
  196. if status.Status == StatusCodeSuccess {
  197. var proof Proof
  198. err := json.Unmarshal([]byte(status.Proof), &proof)
  199. if err != nil {
  200. return nil, tracerr.Wrap(err)
  201. }
  202. return &proof, nil
  203. }
  204. return nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status)
  205. }
  206. // Cancel cancels any current proof computation
  207. func (p *ProofServerClient) Cancel(ctx context.Context) error {
  208. return tracerr.Wrap(p.apiCancel(ctx))
  209. }
  210. // WaitReady waits until the serverProof is ready
  211. func (p *ProofServerClient) WaitReady(ctx context.Context) error {
  212. for {
  213. status, err := p.apiStatus(ctx)
  214. if err != nil {
  215. return tracerr.Wrap(err)
  216. }
  217. if !status.Status.IsInitialized() {
  218. return fmt.Errorf("Proof Server is not initialized")
  219. }
  220. if status.Status.IsReady() {
  221. return nil
  222. }
  223. select {
  224. case <-ctx.Done():
  225. return tracerr.Wrap(common.ErrDone)
  226. case <-time.After(p.pollInterval):
  227. }
  228. }
  229. }
  230. // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
  231. type MockClient struct {
  232. }
  233. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  234. // Proof
  235. func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  236. return nil
  237. }
  238. // GetProof retreives the Proof from the ServerProof
  239. func (p *MockClient) GetProof(ctx context.Context) (*Proof, error) {
  240. // Simulate a delay
  241. select {
  242. case <-time.After(500 * time.Millisecond): //nolint:gomnd
  243. return &Proof{}, nil
  244. case <-ctx.Done():
  245. return nil, tracerr.Wrap(common.ErrDone)
  246. }
  247. }
  248. // Cancel cancels any current proof computation
  249. func (p *MockClient) Cancel(ctx context.Context) error {
  250. // Simulate a delay
  251. select {
  252. case <-time.After(80 * time.Millisecond): //nolint:gomnd
  253. return nil
  254. case <-ctx.Done():
  255. return tracerr.Wrap(common.ErrDone)
  256. }
  257. }
  258. // WaitReady waits until the prover is ready
  259. func (p *MockClient) WaitReady(ctx context.Context) error {
  260. // Simulate a delay
  261. select {
  262. case <-time.After(200 * time.Millisecond): //nolint:gomnd
  263. return nil
  264. case <-ctx.Done():
  265. return tracerr.Wrap(common.ErrDone)
  266. }
  267. }