You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

348 lines
10 KiB

  1. package prover
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "math/big"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/dghubble/sling"
  11. "github.com/hermeznetwork/hermez-node/common"
  12. "github.com/hermeznetwork/tracerr"
  13. )
  14. // Proof TBD this type will be received from the proof server
  15. type Proof struct {
  16. PiA [3]*big.Int `json:"pi_a"`
  17. PiB [3][2]*big.Int `json:"pi_b"`
  18. PiC [3]*big.Int `json:"pi_c"`
  19. Protocol string `json:"protocol"`
  20. }
  21. type bigInt big.Int
  22. func (b *bigInt) UnmarshalText(text []byte) error {
  23. _, ok := (*big.Int)(b).SetString(string(text), 10)
  24. if !ok {
  25. return tracerr.Wrap(fmt.Errorf("invalid big int: \"%v\"", string(text)))
  26. }
  27. return nil
  28. }
  29. // UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
  30. // ints as strings
  31. func (p *Proof) UnmarshalJSON(data []byte) error {
  32. proof := struct {
  33. PiA [3]*bigInt `json:"pi_a"`
  34. PiB [3][2]*bigInt `json:"pi_b"`
  35. PiC [3]*bigInt `json:"pi_c"`
  36. Protocol string `json:"protocol"`
  37. }{}
  38. if err := json.Unmarshal(data, &proof); err != nil {
  39. return tracerr.Wrap(err)
  40. }
  41. p.PiA[0] = (*big.Int)(proof.PiA[0])
  42. p.PiA[1] = (*big.Int)(proof.PiA[1])
  43. p.PiA[2] = (*big.Int)(proof.PiA[2])
  44. if p.PiA[2].Int64() != 1 {
  45. return tracerr.Wrap(fmt.Errorf("Expected PiA[2] == 1, but got %v", p.PiA[2]))
  46. }
  47. p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
  48. p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
  49. p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
  50. p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
  51. p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
  52. p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
  53. if p.PiB[2][0].Int64() != 1 || p.PiB[2][1].Int64() != 0 {
  54. return tracerr.Wrap(fmt.Errorf("Expected PiB[2] == [1, 0], but got %v", p.PiB[2]))
  55. }
  56. p.PiC[0] = (*big.Int)(proof.PiC[0])
  57. p.PiC[1] = (*big.Int)(proof.PiC[1])
  58. p.PiC[2] = (*big.Int)(proof.PiC[2])
  59. if p.PiC[2].Int64() != 1 {
  60. return tracerr.Wrap(fmt.Errorf("Expected PiC[2] == 1, but got %v", p.PiC[2]))
  61. }
  62. p.Protocol = proof.Protocol
  63. return nil
  64. }
  65. // PublicInputs are the public inputs of the proof
  66. type PublicInputs []*big.Int
  67. // UnmarshalJSON unmarshals the JSON into the public inputs where the bigInts
  68. // are in decimal as quoted strings
  69. func (p *PublicInputs) UnmarshalJSON(data []byte) error {
  70. pubInputs := []*bigInt{}
  71. if err := json.Unmarshal(data, &pubInputs); err != nil {
  72. return tracerr.Wrap(err)
  73. }
  74. *p = make([]*big.Int, len(pubInputs))
  75. for i, v := range pubInputs {
  76. ([]*big.Int)(*p)[i] = (*big.Int)(v)
  77. }
  78. return nil
  79. }
  80. // Client is the interface to a ServerProof that calculates zk proofs
  81. type Client interface {
  82. // Non-blocking
  83. CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
  84. // Blocking. Returns the Proof and Public Data (public inputs)
  85. GetProof(ctx context.Context) (*Proof, []*big.Int, error)
  86. // Non-Blocking
  87. Cancel(ctx context.Context) error
  88. // Blocking
  89. WaitReady(ctx context.Context) error
  90. }
  91. // StatusCode is the status string of the ProofServer
  92. type StatusCode string
  93. const (
  94. // StatusCodeAborted means prover is ready to take new proof. Previous
  95. // proof was aborted.
  96. StatusCodeAborted StatusCode = "aborted"
  97. // StatusCodeBusy means prover is busy computing proof.
  98. StatusCodeBusy StatusCode = "busy"
  99. // StatusCodeFailed means prover is ready to take new proof. Previous
  100. // proof failed
  101. StatusCodeFailed StatusCode = "failed"
  102. // StatusCodeSuccess means prover is ready to take new proof. Previous
  103. // proof succeeded
  104. StatusCodeSuccess StatusCode = "success"
  105. // StatusCodeUnverified means prover is ready to take new proof.
  106. // Previous proof was unverified
  107. StatusCodeUnverified StatusCode = "unverified"
  108. // StatusCodeUninitialized means prover is not initialized
  109. StatusCodeUninitialized StatusCode = "uninitialized"
  110. // StatusCodeUndefined means prover is in an undefined state. Most
  111. // likely is booting up. Keep trying
  112. StatusCodeUndefined StatusCode = "undefined"
  113. // StatusCodeInitializing means prover is initializing and not ready yet
  114. StatusCodeInitializing StatusCode = "initializing"
  115. // StatusCodeReady means prover initialized and ready to do first proof
  116. StatusCodeReady StatusCode = "ready"
  117. )
  118. // IsReady returns true when the prover is ready
  119. func (status StatusCode) IsReady() bool {
  120. if status == StatusCodeAborted || status == StatusCodeFailed || status == StatusCodeSuccess ||
  121. status == StatusCodeUnverified || status == StatusCodeReady {
  122. return true
  123. }
  124. return false
  125. }
  126. // IsInitialized returns true when the prover is initialized
  127. func (status StatusCode) IsInitialized() bool {
  128. if status == StatusCodeUninitialized || status == StatusCodeUndefined ||
  129. status == StatusCodeInitializing {
  130. return false
  131. }
  132. return true
  133. }
  134. // Status is the return struct for the status API endpoint
  135. type Status struct {
  136. Status StatusCode `json:"status"`
  137. Proof string `json:"proof"`
  138. PubData string `json:"pubData"`
  139. }
  140. // ErrorServer is the return struct for an API error
  141. type ErrorServer struct {
  142. Status StatusCode `json:"status"`
  143. Message string `json:"msg"`
  144. }
  145. // Error message for ErrorServer
  146. func (e ErrorServer) Error() string {
  147. return fmt.Sprintf("server proof status (%v): %v", e.Status, e.Message)
  148. }
  149. type apiMethod string
  150. const (
  151. // GET is an HTTP GET
  152. GET apiMethod = "GET"
  153. // POST is an HTTP POST with maybe JSON body
  154. POST apiMethod = "POST"
  155. )
  156. // ProofServerClient contains the data related to a ProofServerClient
  157. type ProofServerClient struct {
  158. URL string
  159. client *sling.Sling
  160. pollInterval time.Duration
  161. }
  162. // NewProofServerClient creates a new ServerProof
  163. func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
  164. if URL[len(URL)-1] != '/' {
  165. URL += "/"
  166. }
  167. client := sling.New().Base(URL)
  168. return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
  169. }
  170. func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
  171. body interface{}, ret interface{}) error {
  172. path = strings.TrimPrefix(path, "/")
  173. var errSrv ErrorServer
  174. var req *http.Request
  175. var err error
  176. switch method {
  177. case GET:
  178. req, err = p.client.New().Get(path).Request()
  179. case POST:
  180. req, err = p.client.New().Post(path).BodyJSON(body).Request()
  181. default:
  182. return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
  183. }
  184. if err != nil {
  185. return tracerr.Wrap(err)
  186. }
  187. res, err := p.client.Do(req.WithContext(ctx), ret, &errSrv)
  188. if err != nil {
  189. return tracerr.Wrap(err)
  190. }
  191. defer res.Body.Close() //nolint:errcheck
  192. if !(200 <= res.StatusCode && res.StatusCode < 300) {
  193. return tracerr.Wrap(errSrv)
  194. }
  195. return nil
  196. }
  197. func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
  198. var status Status
  199. return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
  200. }
  201. func (p *ProofServerClient) apiCancel(ctx context.Context) error {
  202. return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
  203. }
  204. func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
  205. return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
  206. }
  207. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  208. // Proof
  209. func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  210. return tracerr.Wrap(p.apiInput(ctx, zkInputs))
  211. }
  212. // GetProof retreives the Proof and Public Data (public inputs) from the
  213. // ServerProof, blocking until the proof is ready.
  214. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
  215. if err := p.WaitReady(ctx); err != nil {
  216. return nil, nil, tracerr.Wrap(err)
  217. }
  218. status, err := p.apiStatus(ctx)
  219. if err != nil {
  220. return nil, nil, tracerr.Wrap(err)
  221. }
  222. if status.Status == StatusCodeSuccess {
  223. var proof Proof
  224. if err := json.Unmarshal([]byte(status.Proof), &proof); err != nil {
  225. return nil, nil, tracerr.Wrap(err)
  226. }
  227. var pubInputs PublicInputs
  228. if err := json.Unmarshal([]byte(status.PubData), &pubInputs); err != nil {
  229. return nil, nil, tracerr.Wrap(err)
  230. }
  231. return &proof, pubInputs, nil
  232. }
  233. return nil, nil, tracerr.Wrap(fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status))
  234. }
  235. // Cancel cancels any current proof computation
  236. func (p *ProofServerClient) Cancel(ctx context.Context) error {
  237. return tracerr.Wrap(p.apiCancel(ctx))
  238. }
  239. // WaitReady waits until the serverProof is ready
  240. func (p *ProofServerClient) WaitReady(ctx context.Context) error {
  241. for {
  242. status, err := p.apiStatus(ctx)
  243. if err != nil {
  244. return tracerr.Wrap(err)
  245. }
  246. if !status.Status.IsInitialized() {
  247. return tracerr.Wrap(fmt.Errorf("Proof Server is not initialized"))
  248. }
  249. if status.Status.IsReady() {
  250. return nil
  251. }
  252. select {
  253. case <-ctx.Done():
  254. return tracerr.Wrap(common.ErrDone)
  255. case <-time.After(p.pollInterval):
  256. }
  257. }
  258. }
  259. // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
  260. type MockClient struct {
  261. counter int64
  262. Delay time.Duration
  263. }
  264. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  265. // Proof
  266. func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  267. return nil
  268. }
  269. // GetProof retreives the Proof from the ServerProof
  270. func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
  271. // Simulate a delay
  272. select {
  273. case <-time.After(p.Delay): //nolint:gomnd
  274. i := p.counter * 100 //nolint:gomnd
  275. p.counter++
  276. return &Proof{
  277. PiA: [3]*big.Int{
  278. big.NewInt(i), big.NewInt(i + 1), big.NewInt(1), //nolint:gomnd
  279. },
  280. PiB: [3][2]*big.Int{
  281. {big.NewInt(i + 2), big.NewInt(i + 3)}, //nolint:gomnd
  282. {big.NewInt(i + 4), big.NewInt(i + 5)}, //nolint:gomnd
  283. {big.NewInt(1), big.NewInt(0)}, //nolint:gomnd
  284. },
  285. PiC: [3]*big.Int{
  286. big.NewInt(i + 6), big.NewInt(i + 7), big.NewInt(1), //nolint:gomnd
  287. },
  288. Protocol: "groth",
  289. },
  290. []*big.Int{big.NewInt(i + 42)}, //nolint:gomnd
  291. nil
  292. case <-ctx.Done():
  293. return nil, nil, tracerr.Wrap(common.ErrDone)
  294. }
  295. }
  296. // Cancel cancels any current proof computation
  297. func (p *MockClient) Cancel(ctx context.Context) error {
  298. // Simulate a delay
  299. select {
  300. case <-time.After(80 * time.Millisecond): //nolint:gomnd
  301. return nil
  302. case <-ctx.Done():
  303. return tracerr.Wrap(common.ErrDone)
  304. }
  305. }
  306. // WaitReady waits until the prover is ready
  307. func (p *MockClient) WaitReady(ctx context.Context) error {
  308. // Simulate a delay
  309. select {
  310. case <-time.After(200 * time.Millisecond): //nolint:gomnd
  311. return nil
  312. case <-ctx.Done():
  313. return tracerr.Wrap(common.ErrDone)
  314. }
  315. }