You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

291 lines
7.8 KiB

  1. package prover
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "mime/multipart"
  9. "net/http"
  10. "strings"
  11. "time"
  12. "github.com/dghubble/sling"
  13. "github.com/hermeznetwork/hermez-node/common"
  14. "github.com/hermeznetwork/hermez-node/log"
  15. "github.com/hermeznetwork/tracerr"
  16. )
  // Proof is a placeholder for the proof that will be received from the proof
  // server. Its fields are still to be defined (TBD).
  type Proof struct{}
  20. // Client is the interface to a ServerProof that calculates zk proofs
  21. type Client interface {
  22. // Non-blocking
  23. CalculateProof(zkInputs *common.ZKInputs) error
  24. // Blocking
  25. GetProof(ctx context.Context) (*Proof, error)
  26. // Non-Blocking
  27. Cancel(ctx context.Context) error
  28. // Blocking
  29. WaitReady(ctx context.Context) error
  30. }
  // StatusCode is the status string reported by the ProofServer.
  type StatusCode string

  // Known values of StatusCode.
  const (
  	// StatusCodeAborted: the prover is ready to take a new proof; the
  	// previous proof was aborted.
  	StatusCodeAborted StatusCode = "aborted"
  	// StatusCodeBusy: the prover is busy computing a proof.
  	StatusCodeBusy StatusCode = "busy"
  	// StatusCodeFailed: the prover is ready to take a new proof; the
  	// previous proof failed.
  	StatusCodeFailed StatusCode = "failed"
  	// StatusCodeSuccess: the prover is ready to take a new proof; the
  	// previous proof succeeded.
  	StatusCodeSuccess StatusCode = "success"
  	// StatusCodeUnverified: the prover is ready to take a new proof; the
  	// previous proof was unverified.
  	StatusCodeUnverified StatusCode = "unverified"
  	// StatusCodeUninitialized: the prover is not initialized.
  	StatusCodeUninitialized StatusCode = "uninitialized"
  	// StatusCodeUndefined: the prover is in an undefined state, most likely
  	// booting up. Keep trying.
  	StatusCodeUndefined StatusCode = "undefined"
  	// StatusCodeInitializing: the prover is initializing and not ready yet.
  	StatusCodeInitializing StatusCode = "initializing"
  	// StatusCodeReady: the prover is initialized and ready to do its first
  	// proof.
  	StatusCodeReady StatusCode = "ready"
  )

  // IsReady reports whether the status is one of the states in which the
  // prover can take a new proof: aborted, failed, success, unverified or ready.
  func (status StatusCode) IsReady() bool {
  	switch status {
  	case StatusCodeAborted,
  		StatusCodeFailed,
  		StatusCodeSuccess,
  		StatusCodeUnverified,
  		StatusCodeReady:
  		return true
  	default:
  		return false
  	}
  }

  // IsInitialized reports whether the prover is initialized. Every status
  // other than uninitialized, undefined and initializing counts as
  // initialized, including values not listed among the constants.
  func (status StatusCode) IsInitialized() bool {
  	switch status {
  	case StatusCodeUninitialized, StatusCodeUndefined, StatusCodeInitializing:
  		return false
  	}
  	return true
  }
  74. // Status is the return struct for the status API endpoint
  75. type Status struct {
  76. Status StatusCode `json:"status"`
  77. Proof string `json:"proof"`
  78. PubData string `json:"pubData"`
  79. }
  80. // ErrorServer is the return struct for an API error
  81. type ErrorServer struct {
  82. Status StatusCode `json:"status"`
  83. Message string `json:"msg"`
  84. }
  85. // Error message for ErrorServer
  86. func (e ErrorServer) Error() string {
  87. return fmt.Sprintf("server proof status (%v): %v", e.Status, e.Message)
  88. }
  // apiMethod names the kind of HTTP request to build for an API call.
  type apiMethod string

  // Supported request kinds.
  const (
  	// GET is an HTTP GET
  	GET = apiMethod("GET")
  	// POST is an HTTP POST with maybe JSON body
  	POST = apiMethod("POST")
  	// POSTFILE is an HTTP POST with a form file
  	POSTFILE = apiMethod("POSTFILE")
  )
  98. // ProofServerClient contains the data related to a ProofServerClient
  99. type ProofServerClient struct {
  100. URL string
  101. client *sling.Sling
  102. }
  103. // NewProofServerClient creates a new ServerProof
  104. func NewProofServerClient(URL string) *ProofServerClient {
  105. if URL[len(URL)-1] != '/' {
  106. URL += "/"
  107. }
  108. client := sling.New().Base(URL)
  109. return &ProofServerClient{URL: URL, client: client}
  110. }
  // formFileProvider carries an already-encoded multipart body together with
  // the multipart writer that produced it.
  //
  //nolint:unused
  type formFileProvider struct {
  	// writer is the multipart writer used to build body; it is kept to
  	// report the form-data content type.
  	writer *multipart.Writer
  	// body is the complete encoded multipart payload.
  	body []byte
  }
  116. //nolint:unused
  117. func newFormFileProvider(payload interface{}) (*formFileProvider, error) {
  118. body := new(bytes.Buffer)
  119. writer := multipart.NewWriter(body)
  120. part, err := writer.CreateFormFile("file", "file.json")
  121. if err != nil {
  122. return nil, tracerr.Wrap(err)
  123. }
  124. if err := json.NewEncoder(part).Encode(payload); err != nil {
  125. return nil, tracerr.Wrap(err)
  126. }
  127. if err := writer.Close(); err != nil {
  128. return nil, tracerr.Wrap(err)
  129. }
  130. return &formFileProvider{
  131. writer: writer,
  132. body: body.Bytes(),
  133. }, nil
  134. }
  135. func (p formFileProvider) ContentType() string {
  136. return p.writer.FormDataContentType()
  137. }
  138. func (p formFileProvider) Body() (io.Reader, error) {
  139. return bytes.NewReader(p.body), nil
  140. }
  141. //nolint:unused
  142. func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
  143. body interface{}, ret interface{}) error {
  144. path = strings.TrimPrefix(path, "/")
  145. var errSrv ErrorServer
  146. var req *http.Request
  147. var err error
  148. switch method {
  149. case GET:
  150. req, err = p.client.New().Get(path).Request()
  151. case POST:
  152. req, err = p.client.New().Post(path).BodyJSON(body).Request()
  153. case POSTFILE:
  154. provider, err := newFormFileProvider(body)
  155. if err != nil {
  156. return tracerr.Wrap(err)
  157. }
  158. req, err = p.client.New().Post(path).BodyProvider(provider).Request()
  159. if err != nil {
  160. return tracerr.Wrap(err)
  161. }
  162. default:
  163. return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
  164. }
  165. if err != nil {
  166. return tracerr.Wrap(err)
  167. }
  168. res, err := p.client.Do(req.WithContext(ctx), ret, &errSrv)
  169. if err != nil {
  170. return tracerr.Wrap(err)
  171. }
  172. defer res.Body.Close() //nolint:errcheck
  173. if !(200 <= res.StatusCode && res.StatusCode < 300) {
  174. return tracerr.Wrap(errSrv)
  175. }
  176. return nil
  177. }
  178. //nolint:unused
  179. func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
  180. var status Status
  181. if err := p.apiRequest(ctx, GET, "/status", nil, &status); err != nil {
  182. return nil, tracerr.Wrap(err)
  183. }
  184. return &status, nil
  185. }
  186. //nolint:unused
  187. func (p *ProofServerClient) apiCancel(ctx context.Context) error {
  188. if err := p.apiRequest(ctx, POST, "/cancel", nil, nil); err != nil {
  189. return tracerr.Wrap(err)
  190. }
  191. return nil
  192. }
  193. //nolint:unused
  194. func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
  195. if err := p.apiRequest(ctx, POSTFILE, "/input", zkInputs, nil); err != nil {
  196. return tracerr.Wrap(err)
  197. }
  198. return nil
  199. }
  200. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  201. // Proof
  202. func (p *ProofServerClient) CalculateProof(zkInputs *common.ZKInputs) error {
  203. log.Error("TODO")
  204. return tracerr.Wrap(common.ErrTODO)
  205. }
  206. // GetProof retreives the Proof from the ServerProof, blocking until the proof
  207. // is ready.
  208. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, error) {
  209. log.Error("TODO")
  210. return nil, tracerr.Wrap(common.ErrTODO)
  211. }
  212. // Cancel cancels any current proof computation
  213. func (p *ProofServerClient) Cancel(ctx context.Context) error {
  214. log.Error("TODO")
  215. return tracerr.Wrap(common.ErrTODO)
  216. }
  217. // WaitReady waits until the serverProof is ready
  218. func (p *ProofServerClient) WaitReady(ctx context.Context) error {
  219. log.Error("TODO")
  220. return tracerr.Wrap(common.ErrTODO)
  221. }
  // MockClient is a mock ServerProof to be used in tests. It doesn't calculate
  // anything.
  type MockClient struct{}
  225. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  226. // Proof
  227. func (p *MockClient) CalculateProof(zkInputs *common.ZKInputs) error {
  228. return nil
  229. }
  230. // GetProof retreives the Proof from the ServerProof
  231. func (p *MockClient) GetProof(ctx context.Context) (*Proof, error) {
  232. // Simulate a delay
  233. select {
  234. case <-time.After(500 * time.Millisecond): //nolint:gomnd
  235. return &Proof{}, nil
  236. case <-ctx.Done():
  237. return nil, tracerr.Wrap(common.ErrDone)
  238. }
  239. }
  240. // Cancel cancels any current proof computation
  241. func (p *MockClient) Cancel(ctx context.Context) error {
  242. // Simulate a delay
  243. select {
  244. case <-time.After(80 * time.Millisecond): //nolint:gomnd
  245. return nil
  246. case <-ctx.Done():
  247. return tracerr.Wrap(common.ErrDone)
  248. }
  249. }
  250. // WaitReady waits until the prover is ready
  251. func (p *MockClient) WaitReady(ctx context.Context) error {
  252. // Simulate a delay
  253. select {
  254. case <-time.After(200 * time.Millisecond): //nolint:gomnd
  255. return nil
  256. case <-ctx.Done():
  257. return tracerr.Wrap(common.ErrDone)
  258. }
  259. }