You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

349 lines
10 KiB

  1. package prover
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "math/big"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/dghubble/sling"
  11. "github.com/hermeznetwork/hermez-node/common"
  12. "github.com/hermeznetwork/tracerr"
  13. )
  14. // Proof TBD this type will be received from the proof server
  15. type Proof struct {
  16. PiA [3]*big.Int `json:"pi_a"`
  17. PiB [3][2]*big.Int `json:"pi_b"`
  18. PiC [3]*big.Int `json:"pi_c"`
  19. Protocol string `json:"protocol"`
  20. }
  21. type bigInt big.Int
  22. func (b *bigInt) UnmarshalText(text []byte) error {
  23. _, ok := (*big.Int)(b).SetString(string(text), 10)
  24. if !ok {
  25. return tracerr.Wrap(fmt.Errorf("invalid big int: \"%v\"", string(text)))
  26. }
  27. return nil
  28. }
  29. // UnmarshalJSON unmarshals the proof from a JSON encoded proof with the big
  30. // ints as strings
  31. func (p *Proof) UnmarshalJSON(data []byte) error {
  32. proof := struct {
  33. PiA [3]*bigInt `json:"pi_a"`
  34. PiB [3][2]*bigInt `json:"pi_b"`
  35. PiC [3]*bigInt `json:"pi_c"`
  36. Protocol string `json:"protocol"`
  37. }{}
  38. if err := json.Unmarshal(data, &proof); err != nil {
  39. return tracerr.Wrap(err)
  40. }
  41. p.PiA[0] = (*big.Int)(proof.PiA[0])
  42. p.PiA[1] = (*big.Int)(proof.PiA[1])
  43. p.PiA[2] = (*big.Int)(proof.PiA[2])
  44. if p.PiA[2].Int64() != 1 {
  45. return tracerr.Wrap(fmt.Errorf("Expected PiA[2] == 1, but got %v", p.PiA[2]))
  46. }
  47. p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
  48. p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
  49. p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
  50. p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
  51. p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
  52. p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
  53. if p.PiB[2][0].Int64() != 1 || p.PiB[2][1].Int64() != 0 {
  54. return tracerr.Wrap(fmt.Errorf("Expected PiB[2] == [1, 0], but got %v", p.PiB[2]))
  55. }
  56. p.PiC[0] = (*big.Int)(proof.PiC[0])
  57. p.PiC[1] = (*big.Int)(proof.PiC[1])
  58. p.PiC[2] = (*big.Int)(proof.PiC[2])
  59. if p.PiC[2].Int64() != 1 {
  60. return tracerr.Wrap(fmt.Errorf("Expected PiC[2] == 1, but got %v", p.PiC[2]))
  61. }
  62. // TODO: Assert ones and zeroes
  63. p.Protocol = proof.Protocol
  64. return nil
  65. }
  66. // PublicInputs are the public inputs of the proof
  67. type PublicInputs []*big.Int
  68. // UnmarshalJSON unmarshals the JSON into the public inputs where the bigInts
  69. // are in decimal as quoted strings
  70. func (p *PublicInputs) UnmarshalJSON(data []byte) error {
  71. pubInputs := []*bigInt{}
  72. if err := json.Unmarshal(data, &pubInputs); err != nil {
  73. return tracerr.Wrap(err)
  74. }
  75. *p = make([]*big.Int, len(pubInputs))
  76. for i, v := range pubInputs {
  77. ([]*big.Int)(*p)[i] = (*big.Int)(v)
  78. }
  79. return nil
  80. }
  81. // Client is the interface to a ServerProof that calculates zk proofs
  82. type Client interface {
  83. // Non-blocking
  84. CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error
  85. // Blocking. Returns the Proof and Public Data (public inputs)
  86. GetProof(ctx context.Context) (*Proof, []*big.Int, error)
  87. // Non-Blocking
  88. Cancel(ctx context.Context) error
  89. // Blocking
  90. WaitReady(ctx context.Context) error
  91. }
  92. // StatusCode is the status string of the ProofServer
  93. type StatusCode string
  94. const (
  95. // StatusCodeAborted means prover is ready to take new proof. Previous
  96. // proof was aborted.
  97. StatusCodeAborted StatusCode = "aborted"
  98. // StatusCodeBusy means prover is busy computing proof.
  99. StatusCodeBusy StatusCode = "busy"
  100. // StatusCodeFailed means prover is ready to take new proof. Previous
  101. // proof failed
  102. StatusCodeFailed StatusCode = "failed"
  103. // StatusCodeSuccess means prover is ready to take new proof. Previous
  104. // proof succeeded
  105. StatusCodeSuccess StatusCode = "success"
  106. // StatusCodeUnverified means prover is ready to take new proof.
  107. // Previous proof was unverified
  108. StatusCodeUnverified StatusCode = "unverified"
  109. // StatusCodeUninitialized means prover is not initialized
  110. StatusCodeUninitialized StatusCode = "uninitialized"
  111. // StatusCodeUndefined means prover is in an undefined state. Most
  112. // likely is booting up. Keep trying
  113. StatusCodeUndefined StatusCode = "undefined"
  114. // StatusCodeInitializing means prover is initializing and not ready yet
  115. StatusCodeInitializing StatusCode = "initializing"
  116. // StatusCodeReady means prover initialized and ready to do first proof
  117. StatusCodeReady StatusCode = "ready"
  118. )
  119. // IsReady returns true when the prover is ready
  120. func (status StatusCode) IsReady() bool {
  121. if status == StatusCodeAborted || status == StatusCodeFailed || status == StatusCodeSuccess ||
  122. status == StatusCodeUnverified || status == StatusCodeReady {
  123. return true
  124. }
  125. return false
  126. }
  127. // IsInitialized returns true when the prover is initialized
  128. func (status StatusCode) IsInitialized() bool {
  129. if status == StatusCodeUninitialized || status == StatusCodeUndefined ||
  130. status == StatusCodeInitializing {
  131. return false
  132. }
  133. return true
  134. }
  135. // Status is the return struct for the status API endpoint
  136. type Status struct {
  137. Status StatusCode `json:"status"`
  138. Proof string `json:"proof"`
  139. PubData string `json:"pubData"`
  140. }
  141. // ErrorServer is the return struct for an API error
  142. type ErrorServer struct {
  143. Status StatusCode `json:"status"`
  144. Message string `json:"msg"`
  145. }
  146. // Error message for ErrorServer
  147. func (e ErrorServer) Error() string {
  148. return fmt.Sprintf("server proof status (%v): %v", e.Status, e.Message)
  149. }
  150. type apiMethod string
  151. const (
  152. // GET is an HTTP GET
  153. GET apiMethod = "GET"
  154. // POST is an HTTP POST with maybe JSON body
  155. POST apiMethod = "POST"
  156. )
  157. // ProofServerClient contains the data related to a ProofServerClient
  158. type ProofServerClient struct {
  159. URL string
  160. client *sling.Sling
  161. pollInterval time.Duration
  162. }
  163. // NewProofServerClient creates a new ServerProof
  164. func NewProofServerClient(URL string, pollInterval time.Duration) *ProofServerClient {
  165. if URL[len(URL)-1] != '/' {
  166. URL += "/"
  167. }
  168. client := sling.New().Base(URL)
  169. return &ProofServerClient{URL: URL, client: client, pollInterval: pollInterval}
  170. }
  171. func (p *ProofServerClient) apiRequest(ctx context.Context, method apiMethod, path string,
  172. body interface{}, ret interface{}) error {
  173. path = strings.TrimPrefix(path, "/")
  174. var errSrv ErrorServer
  175. var req *http.Request
  176. var err error
  177. switch method {
  178. case GET:
  179. req, err = p.client.New().Get(path).Request()
  180. case POST:
  181. req, err = p.client.New().Post(path).BodyJSON(body).Request()
  182. default:
  183. return tracerr.Wrap(fmt.Errorf("invalid http method: %v", method))
  184. }
  185. if err != nil {
  186. return tracerr.Wrap(err)
  187. }
  188. res, err := p.client.Do(req.WithContext(ctx), ret, &errSrv)
  189. if err != nil {
  190. return tracerr.Wrap(err)
  191. }
  192. defer res.Body.Close() //nolint:errcheck
  193. if !(200 <= res.StatusCode && res.StatusCode < 300) {
  194. return tracerr.Wrap(errSrv)
  195. }
  196. return nil
  197. }
  198. func (p *ProofServerClient) apiStatus(ctx context.Context) (*Status, error) {
  199. var status Status
  200. return &status, tracerr.Wrap(p.apiRequest(ctx, GET, "/status", nil, &status))
  201. }
  202. func (p *ProofServerClient) apiCancel(ctx context.Context) error {
  203. return tracerr.Wrap(p.apiRequest(ctx, POST, "/cancel", nil, nil))
  204. }
  205. func (p *ProofServerClient) apiInput(ctx context.Context, zkInputs *common.ZKInputs) error {
  206. return tracerr.Wrap(p.apiRequest(ctx, POST, "/input", zkInputs, nil))
  207. }
  208. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  209. // Proof
  210. func (p *ProofServerClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  211. return tracerr.Wrap(p.apiInput(ctx, zkInputs))
  212. }
  213. // GetProof retreives the Proof and Public Data (public inputs) from the
  214. // ServerProof, blocking until the proof is ready.
  215. func (p *ProofServerClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
  216. if err := p.WaitReady(ctx); err != nil {
  217. return nil, nil, tracerr.Wrap(err)
  218. }
  219. status, err := p.apiStatus(ctx)
  220. if err != nil {
  221. return nil, nil, tracerr.Wrap(err)
  222. }
  223. if status.Status == StatusCodeSuccess {
  224. var proof Proof
  225. if err := json.Unmarshal([]byte(status.Proof), &proof); err != nil {
  226. return nil, nil, tracerr.Wrap(err)
  227. }
  228. var pubInputs PublicInputs
  229. if err := json.Unmarshal([]byte(status.PubData), &pubInputs); err != nil {
  230. return nil, nil, tracerr.Wrap(err)
  231. }
  232. return &proof, pubInputs, nil
  233. }
  234. return nil, nil, tracerr.Wrap(fmt.Errorf("status != StatusCodeSuccess, status = %v", status.Status))
  235. }
  236. // Cancel cancels any current proof computation
  237. func (p *ProofServerClient) Cancel(ctx context.Context) error {
  238. return tracerr.Wrap(p.apiCancel(ctx))
  239. }
  240. // WaitReady waits until the serverProof is ready
  241. func (p *ProofServerClient) WaitReady(ctx context.Context) error {
  242. for {
  243. status, err := p.apiStatus(ctx)
  244. if err != nil {
  245. return tracerr.Wrap(err)
  246. }
  247. if !status.Status.IsInitialized() {
  248. return tracerr.Wrap(fmt.Errorf("Proof Server is not initialized"))
  249. }
  250. if status.Status.IsReady() {
  251. return nil
  252. }
  253. select {
  254. case <-ctx.Done():
  255. return tracerr.Wrap(common.ErrDone)
  256. case <-time.After(p.pollInterval):
  257. }
  258. }
  259. }
  260. // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
  261. type MockClient struct {
  262. counter int64
  263. Delay time.Duration
  264. }
  265. // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
  266. // Proof
  267. func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInputs) error {
  268. return nil
  269. }
  270. // GetProof retreives the Proof from the ServerProof
  271. func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
  272. // Simulate a delay
  273. select {
  274. case <-time.After(p.Delay): //nolint:gomnd
  275. i := p.counter * 100 //nolint:gomnd
  276. p.counter++
  277. return &Proof{
  278. PiA: [3]*big.Int{
  279. big.NewInt(i), big.NewInt(i + 1), big.NewInt(1), //nolint:gomnd
  280. },
  281. PiB: [3][2]*big.Int{
  282. {big.NewInt(i + 2), big.NewInt(i + 3)}, //nolint:gomnd
  283. {big.NewInt(i + 4), big.NewInt(i + 5)}, //nolint:gomnd
  284. {big.NewInt(1), big.NewInt(0)}, //nolint:gomnd
  285. },
  286. PiC: [3]*big.Int{
  287. big.NewInt(i + 6), big.NewInt(i + 7), big.NewInt(1), //nolint:gomnd
  288. },
  289. Protocol: "groth",
  290. },
  291. []*big.Int{big.NewInt(i + 42)}, //nolint:gomnd
  292. nil
  293. case <-ctx.Done():
  294. return nil, nil, tracerr.Wrap(common.ErrDone)
  295. }
  296. }
  297. // Cancel cancels any current proof computation
  298. func (p *MockClient) Cancel(ctx context.Context) error {
  299. // Simulate a delay
  300. select {
  301. case <-time.After(80 * time.Millisecond): //nolint:gomnd
  302. return nil
  303. case <-ctx.Done():
  304. return tracerr.Wrap(common.ErrDone)
  305. }
  306. }
  307. // WaitReady waits until the prover is ready
  308. func (p *MockClient) WaitReady(ctx context.Context) error {
  309. // Simulate a delay
  310. select {
  311. case <-time.After(200 * time.Millisecond): //nolint:gomnd
  312. return nil
  313. case <-ctx.Done():
  314. return tracerr.Wrap(common.ErrDone)
  315. }
  316. }