You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

318 lines
9.3 KiB

  1. package prover
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "math/big"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/dghubble/sling"
  11. "github.com/hermeznetwork/hermez-node/common"
  12. "github.com/hermeznetwork/tracerr"
  13. )
  14. // Proof TBD this type will be received from the proof server
  15. type Proof struct {
  16. PiA [2]*big.Int `json:"pi_a"`
  17. PiB [3][2]*big.Int `json:"pi_b"`
  18. PiC [2]*big.Int `json:"pi_c"`
  19. Protocol string `json:"protocol"`
  20. }
  21. type bigInt big.Int
  22. func (b *bigInt) UnmarshalText(text []byte) error {
  23. _, ok := (*big.Int)(b).SetString(string(text), 10)
  24. if !ok {
  25. return fmt.Errorf("invalid big int: \"%v\"", string(text))
  26. }
  27. return nil
  28. }
  29. // UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
  30. // ints as strings
  31. func (p *Proof) UnmarshalJSON(data []byte) error {
  32. proof := struct {
  33. PiA [2]*bigInt `json:"pi_a"`
  34. PiB [3][2]*bigInt `json:"pi_b"`
  35. PiC [2]*bigInt `json:"pi_c"`
  36. Protocol string `json:"protocol"`
  37. }{}
  38. if err := json.Unmarshal(data, &proof); err != nil {
  39. return err
  40. }
  41. p.PiA[0] = (*big.Int)(proof.PiA[0])
  42. p.PiA[1] = (*big.Int)(proof.PiA[1])
  43. p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
  44. p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
  45. p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
  46. p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
  47. p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
  48. p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
  49. p.PiC[0] = (*big.Int)(proof.PiC[0])
  50. p.PiC[1] = (*big.Int)(proof.PiC[1])
  51. p.Protocol = proof.Protocol
  52. return nil
  53. }
  54. // PublicInputs are the public inputs of the proof
  55. type PublicInputs []*big.Int
  56. // UnmarshalJSON unmarshals the JSON into the public inputs where the bigInts
  57. // are in decimal as quoted strings
  58. func (p *PublicInputs) UnmarshalJSON(data []byte) error {
  59. pubInputs := []*bigInt{}
  60. if err := json.Unmarshal(data, &pubInputs); err != nil {
  61. return err
  62. }
  63. *p = make([]*big.Int, len(pubInputs))
  64. for i, v := range pubInputs {
  65. ([]*big.Int)(*p)[i] = (*big.Int)(v)
  66. }
  67. return nil
  68. }
  69. // Client is the interface to a ServerProof that calculates zk proofs
  70. type Client interface {
  71. // Non-blocking
  72. CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
  73. // Blocking. Returns the Proof and Public Data (public inputs)
  74. GetProof(ctx context.Context) (*Proof, []*big.Int, error)
  75. // Non-Blocking
  76. Cancel(ctx context.Context) error
  77. // Blocking
  78. WaitReady(ctx context.Context) error
  79. }
  80. // StatusCode is the status string of the ProofServer
  81. type StatusCode string
  82. const (
  83. // StatusCodeAborted means prover is ready to take new proof. Previous
  84. // proof was aborted.
  85. StatusCodeAborted StatusCode = "aborted"
  86. // StatusCodeBusy means prover is busy computing proof.
  87. StatusCodeBusy StatusCode = "busy"
  88. // StatusCodeFailed means prover is ready to take new proof. Previous
  89. // proof failed
  90. StatusCodeFailed StatusCode = "failed"
  91. // StatusCodeSuccess means prover is ready to take new proof. Previous
  92. // proof succeeded
  93. StatusCodeSuccess StatusCode = "success"
  94. // StatusCodeUnverified means prover is ready to take new proof.
  95. // Previous proof was unverified
  96. StatusCodeUnverified StatusCode = "unverified"
  97. // StatusCodeUninitialized means prover is not initialized
  98. StatusCodeUninitialized StatusCode = "uninitialized"
  99. // StatusCodeUndefined means prover is in an undefined state. Most
  100. // likely is booting up. Keep trying
  101. StatusCodeUndefined StatusCode = "undefined"
  102. // StatusCodeInitializing means prover is initializing and not ready yet
  103. StatusCodeInitializing StatusCode = "initializing"
  104. // StatusCodeReady means prover initialized and ready to do first proof
  105. StatusCodeReady StatusCode = "ready"
  106. )
  107. // IsReady returns true when the prover is ready
  108. func (status StatusCode) IsReady() bool {
  109. if status == StatusCodeAborted || status == StatusCodeFailed || status == StatusCodeSuccess ||
  110. status == StatusCodeUnverified || status == StatusCodeReady {
  111. return true
  112. }
  113. return false
  114. }
  115. // IsInitialized returns true when the prover is initialized
  116. func (status StatusCode) IsInitialized() bool {
  117. if status == StatusCodeUninitialized || status == StatusCodeUndefined ||
  118. status == StatusCodeInitializing {
  119. return false
  120. }
  121. return true
  122. }
  123. // Status is the return struct for the status API endpoint
  124. type Status struct {
  125. Status StatusCode `json:"status"`
  126. Proof string `json:"proof"`
  127. PubData string `json:"pubData"`
  128. }
  129. // ErrorServer is the return struct for an API error
  130. type ErrorServer struct {
  131. Status StatusCode `json:"status"`
  132. Message string `json:"msg"`
  133. }
  134. // Error message for ErrorServer
  135. func (e ErrorServer) Error() string {
  136. return fmt.Sprintf("server proof status (%v): %v", e.Status, e.Message)
  137. }
  138. type apiMethod string
  139. const (
  140. // GET is an HTTP GET
  141. GET apiMethod = "GET"
  142. // POST is an HTTP POST with maybe JSON body
  143. POST apiMethod = "POST"
  144. )
  145. // ProofServerClient contains the data related to a ProofServerClient
  146. type ProofServerClient struct {
  147. URL string
  148. client *sling.Sling
  149. pollInterval time.Duration
  150. }
  151. // NewProofServerClient creates a new ServerProof
  152. func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
  153. if URL[len(URL)-1] != '/' {
  154. URL += "/"
  155. }
  156. client := sling.New().Base(URL)
  157. return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
  158. }
  159. func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
  160. body interface{}, ret interface{}) error {
  161. path = strings.TrimPrefix(path, "/")
  162. var errSrv ErrorServer
  163. var req *http.Request
  164. var err error
  165. switch method {
  166. case GET:
  167. req, err = p.client.New().Get(path).Request()
  168. case POST:
  169. req, err = p.client.New().Post(path).BodyJSON(body).Request()
  170. default:
  171. return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
  172. }
  173. if err != nil {
  174. return tracerr.Wrap(err)
  175. }
  176. res, err := p.client.Do(req.WithContext(ctx), ret, &errSrv)
  177. if err != nil {
  178. return tracerr.Wrap(err)
  179. }
  180. defer res.Body.Close() //nolint:errcheck
  181. if !(200 <= res.StatusCode && res.StatusCode < 300) {
  182. return tracerr.Wrap(errSrv)
  183. }
  184. return nil
  185. }
  186. func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
  187. var status Status
  188. return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
  189. }
  190. func (p *ProofServerClient) apiCancel(ctx context.Context) error {
  191. return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
  192. }
  193. func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
  194. return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
  195. }
  196. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  197. // Proof
  198. func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  199. return tracerr.Wrap(p.apiInput(ctx, zkInputs))
  200. }
  201. // GetProof retreives the Proof and Public Data (public inputs) from the
  202. // ServerProof, blocking until the proof is ready.
  203. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
  204. if err := p.WaitReady(ctx); err != nil {
  205. return nil, nil, tracerr.Wrap(err)
  206. }
  207. status, err := p.apiStatus(ctx)
  208. if err != nil {
  209. return nil, nil, tracerr.Wrap(err)
  210. }
  211. if status.Status == StatusCodeSuccess {
  212. var proof Proof
  213. if err := json.Unmarshal([]byte(status.Proof), &proof); err != nil {
  214. return nil, nil, tracerr.Wrap(err)
  215. }
  216. var pubInputs PublicInputs
  217. if err := json.Unmarshal([]byte(status.PubData), &pubInputs); err != nil {
  218. return nil, nil, tracerr.Wrap(err)
  219. }
  220. return &proof, pubInputs, nil
  221. }
  222. return nil, nil, fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status)
  223. }
  224. // Cancel cancels any current proof computation
  225. func (p *ProofServerClient) Cancel(ctx context.Context) error {
  226. return tracerr.Wrap(p.apiCancel(ctx))
  227. }
  228. // WaitReady waits until the serverProof is ready
  229. func (p *ProofServerClient) WaitReady(ctx context.Context) error {
  230. for {
  231. status, err := p.apiStatus(ctx)
  232. if err != nil {
  233. return tracerr.Wrap(err)
  234. }
  235. if !status.Status.IsInitialized() {
  236. return fmt.Errorf("Proof Server is not initialized")
  237. }
  238. if status.Status.IsReady() {
  239. return nil
  240. }
  241. select {
  242. case <-ctx.Done():
  243. return tracerr.Wrap(common.ErrDone)
  244. case <-time.After(p.pollInterval):
  245. }
  246. }
  247. }
  248. // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
  249. type MockClient struct {
  250. }
  251. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  252. // Proof
  253. func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  254. return nil
  255. }
  256. // GetProof retreives the Proof from the ServerProof
  257. func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
  258. // Simulate a delay
  259. select {
  260. case <-time.After(500 * time.Millisecond): //nolint:gomnd
  261. return &Proof{}, []*big.Int{big.NewInt(1234)}, nil //nolint:gomnd
  262. case <-ctx.Done():
  263. return nil, nil, tracerr.Wrap(common.ErrDone)
  264. }
  265. }
  266. // Cancel cancels any current proof computation
  267. func (p *MockClient) Cancel(ctx context.Context) error {
  268. // Simulate a delay
  269. select {
  270. case <-time.After(80 * time.Millisecond): //nolint:gomnd
  271. return nil
  272. case <-ctx.Done():
  273. return tracerr.Wrap(common.ErrDone)
  274. }
  275. }
  276. // WaitReady waits until the prover is ready
  277. func (p *MockClient) WaitReady(ctx context.Context) error {
  278. // Simulate a delay
  279. select {
  280. case <-time.After(200 * time.Millisecond): //nolint:gomnd
  281. return nil
  282. case <-ctx.Done():
  283. return tracerr.Wrap(common.ErrDone)
  284. }
  285. }