You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

264 lines
8.0 KiB

  1. package apitypes
  2. import (
  3. "database/sql/driver"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "errors"
  7. "fmt"
  8. "math/big"
  9. "strconv"
  10. "strings"
  11. ethCommon "github.com/ethereum/go-ethereum/common"
  12. "github.com/hermeznetwork/hermez-node/common"
  13. "github.com/hermeznetwork/tracerr"
  14. "github.com/iden3/go-iden3-crypto/babyjub"
  15. )
  16. // BigIntStr is used to scan/value *big.Int directly into strings from/to sql DBs.
  17. // It assumes that *big.Int are inserted/fetched to/from the DB using the BigIntMeddler meddler
  18. // defined at github.com/hermeznetwork/hermez-node/db. Since *big.Int is
  19. // stored as DECIMAL in SQL, there's no need to implement Scan()/Value()
  20. // because DECIMALS are encoded/decoded as strings by the sql driver, and
  21. // BigIntStr is already a string.
  22. type BigIntStr string
  23. // NewBigIntStr creates a *BigIntStr from a *big.Int.
  24. // If the provided bigInt is nil the returned *BigIntStr will also be nil
  25. func NewBigIntStr(bigInt *big.Int) *BigIntStr {
  26. if bigInt == nil {
  27. return nil
  28. }
  29. bigIntStr := BigIntStr(bigInt.String())
  30. return &bigIntStr
  31. }
  32. // StrBigInt is used to unmarshal BigIntStr directly into an alias of big.Int
  33. type StrBigInt big.Int
  34. // UnmarshalText unmarshals a StrBigInt
  35. func (s *StrBigInt) UnmarshalText(text []byte) error {
  36. bi, ok := (*big.Int)(s).SetString(string(text), 10)
  37. if !ok {
  38. return tracerr.Wrap(fmt.Errorf("could not unmarshal %s into a StrBigInt", text))
  39. }
  40. *s = StrBigInt(*bi)
  41. return nil
  42. }
  43. // CollectedFeesAPI is send common.batch.CollectedFee through the API
  44. type CollectedFeesAPI map[common.TokenID]BigIntStr
  45. // NewCollectedFeesAPI creates a new CollectedFeesAPI from a *big.Int map
  46. func NewCollectedFeesAPI(m map[common.TokenID]*big.Int) CollectedFeesAPI {
  47. c := CollectedFeesAPI(make(map[common.TokenID]BigIntStr))
  48. for k, v := range m {
  49. c[k] = *NewBigIntStr(v)
  50. }
  51. return c
  52. }
  53. // HezEthAddr is used to scan/value Ethereum Address directly into strings that follow the Ethereum address hez format (^hez:0x[a-fA-F0-9]{40}$) from/to sql DBs.
  54. // It assumes that Ethereum Address are inserted/fetched to/from the DB using the default Scan/Value interface
  55. type HezEthAddr string
  56. // NewHezEthAddr creates a HezEthAddr from an Ethereum addr
  57. func NewHezEthAddr(addr ethCommon.Address) HezEthAddr {
  58. return HezEthAddr("hez:" + addr.String())
  59. }
  60. // ToEthAddr returns an Ethereum Address created from HezEthAddr
  61. func (a HezEthAddr) ToEthAddr() (ethCommon.Address, error) {
  62. addrStr := strings.TrimPrefix(string(a), "hez:")
  63. var addr ethCommon.Address
  64. return addr, addr.UnmarshalText([]byte(addrStr))
  65. }
  66. // Scan implements Scanner for database/sql
  67. func (a *HezEthAddr) Scan(src interface{}) error {
  68. ethAddr := &ethCommon.Address{}
  69. if err := ethAddr.Scan(src); err != nil {
  70. return tracerr.Wrap(err)
  71. }
  72. if ethAddr == nil {
  73. return nil
  74. }
  75. *a = NewHezEthAddr(*ethAddr)
  76. return nil
  77. }
  78. // Value implements valuer for database/sql
  79. func (a HezEthAddr) Value() (driver.Value, error) {
  80. ethAddr, err := a.ToEthAddr()
  81. if err != nil {
  82. return nil, tracerr.Wrap(err)
  83. }
  84. return ethAddr.Value()
  85. }
  86. // StrHezEthAddr is used to unmarshal HezEthAddr directly into an alias of ethCommon.Address
  87. type StrHezEthAddr ethCommon.Address
  88. // UnmarshalText unmarshals a StrHezEthAddr
  89. func (s *StrHezEthAddr) UnmarshalText(text []byte) error {
  90. withoutHez := strings.TrimPrefix(string(text), "hez:")
  91. var addr ethCommon.Address
  92. if err := addr.UnmarshalText([]byte(withoutHez)); err != nil {
  93. return tracerr.Wrap(err)
  94. }
  95. *s = StrHezEthAddr(addr)
  96. return nil
  97. }
  98. // HezBJJ is used to scan/value *babyjub.PublicKeyComp directly into strings that follow the BJJ public key hez format (^hez:[A-Za-z0-9_-]{44}$) from/to sql DBs.
  99. // It assumes that *babyjub.PublicKeyComp are inserted/fetched to/from the DB using the default Scan/Value interface
  100. type HezBJJ string
  101. // NewHezBJJ creates a HezBJJ from a *babyjub.PublicKeyComp.
  102. // Calling this method with a nil bjj causes panic
  103. func NewHezBJJ(pkComp babyjub.PublicKeyComp) HezBJJ {
  104. sum := pkComp[0]
  105. for i := 1; i < len(pkComp); i++ {
  106. sum += pkComp[i]
  107. }
  108. bjjSum := append(pkComp[:], sum)
  109. return HezBJJ("hez:" + base64.RawURLEncoding.EncodeToString(bjjSum))
  110. }
  111. func hezStrToBJJ(s string) (babyjub.PublicKeyComp, error) {
  112. const decodedLen = 33
  113. const encodedLen = 44
  114. formatErr := errors.New("invalid BJJ format. Must follow this regex: ^hez:[A-Za-z0-9_-]{44}$")
  115. encoded := strings.TrimPrefix(s, "hez:")
  116. if len(encoded) != encodedLen {
  117. return common.EmptyBJJComp, formatErr
  118. }
  119. decoded, err := base64.RawURLEncoding.DecodeString(encoded)
  120. if err != nil {
  121. return common.EmptyBJJComp, formatErr
  122. }
  123. if len(decoded) != decodedLen {
  124. return common.EmptyBJJComp, formatErr
  125. }
  126. bjjBytes := [decodedLen - 1]byte{}
  127. copy(bjjBytes[:decodedLen-1], decoded[:decodedLen-1])
  128. sum := bjjBytes[0]
  129. for i := 1; i < len(bjjBytes); i++ {
  130. sum += bjjBytes[i]
  131. }
  132. if decoded[decodedLen-1] != sum {
  133. return common.EmptyBJJComp, tracerr.Wrap(errors.New("checksum verification failed"))
  134. }
  135. bjjComp := babyjub.PublicKeyComp(bjjBytes)
  136. return bjjComp, nil
  137. }
  138. // ToBJJ returns a babyjub.PublicKeyComp created from HezBJJ
  139. func (b HezBJJ) ToBJJ() (babyjub.PublicKeyComp, error) {
  140. return hezStrToBJJ(string(b))
  141. }
  142. // Scan implements Scanner for database/sql
  143. func (b *HezBJJ) Scan(src interface{}) error {
  144. bjj := &babyjub.PublicKeyComp{}
  145. if err := bjj.Scan(src); err != nil {
  146. return tracerr.Wrap(err)
  147. }
  148. if bjj == nil {
  149. return nil
  150. }
  151. *b = NewHezBJJ(*bjj)
  152. return nil
  153. }
  154. // Value implements valuer for database/sql
  155. func (b HezBJJ) Value() (driver.Value, error) {
  156. bjj, err := b.ToBJJ()
  157. if err != nil {
  158. return nil, tracerr.Wrap(err)
  159. }
  160. return bjj.Value()
  161. }
  162. // StrHezBJJ is used to unmarshal HezBJJ directly into an alias of babyjub.PublicKeyComp
  163. type StrHezBJJ babyjub.PublicKeyComp
  164. // UnmarshalText unmarshalls a StrHezBJJ
  165. func (s *StrHezBJJ) UnmarshalText(text []byte) error {
  166. bjj, err := hezStrToBJJ(string(text))
  167. if err != nil {
  168. return tracerr.Wrap(err)
  169. }
  170. *s = StrHezBJJ(bjj)
  171. return nil
  172. }
  173. // HezIdx is used to value common.Idx directly into strings that follow the Idx key hez format (hez:tokenSymbol:idx) to sql DBs.
  174. // Note that this can only be used to insert to DB since there is no way to automatically read from the DB since it needs the tokenSymbol
  175. type HezIdx string
  176. // StrHezIdx is used to unmarshal HezIdx directly into an alias of common.Idx
  177. type StrHezIdx common.Idx
  178. // UnmarshalText unmarshals a StrHezIdx
  179. func (s *StrHezIdx) UnmarshalText(text []byte) error {
  180. withoutHez := strings.TrimPrefix(string(text), "hez:")
  181. splitted := strings.Split(withoutHez, ":")
  182. const expectedLen = 2
  183. if len(splitted) != expectedLen {
  184. return tracerr.Wrap(fmt.Errorf("can not unmarshal %s into StrHezIdx", text))
  185. }
  186. idxInt, err := strconv.Atoi(splitted[1])
  187. if err != nil {
  188. return tracerr.Wrap(err)
  189. }
  190. *s = StrHezIdx(common.Idx(idxInt))
  191. return nil
  192. }
  193. // EthSignature is used to scan/value []byte representing an Ethereum signature directly into strings from/to sql DBs.
  194. type EthSignature string
  195. // NewEthSignature creates a *EthSignature from []byte
  196. // If the provided signature is nil the returned *EthSignature will also be nil
  197. func NewEthSignature(signature []byte) *EthSignature {
  198. if signature == nil {
  199. return nil
  200. }
  201. ethSignature := EthSignature("0x" + hex.EncodeToString(signature))
  202. return &ethSignature
  203. }
  204. // Scan implements Scanner for database/sql
  205. func (e *EthSignature) Scan(src interface{}) error {
  206. if srcStr, ok := src.(string); ok {
  207. // src is a string
  208. *e = *(NewEthSignature([]byte(srcStr)))
  209. return nil
  210. } else if srcBytes, ok := src.([]byte); ok {
  211. // src is []byte
  212. *e = *(NewEthSignature(srcBytes))
  213. return nil
  214. } else {
  215. // unexpected src
  216. return tracerr.Wrap(fmt.Errorf("can't scan %T into apitypes.EthSignature", src))
  217. }
  218. }
  219. // Value implements valuer for database/sql
  220. func (e EthSignature) Value() (driver.Value, error) {
  221. without0x := strings.TrimPrefix(string(e), "0x")
  222. return hex.DecodeString(without0x)
  223. }
  224. // UnmarshalText unmarshals a StrEthSignature
  225. func (e *EthSignature) UnmarshalText(text []byte) error {
  226. without0x := strings.TrimPrefix(string(text), "0x")
  227. signature, err := hex.DecodeString(without0x)
  228. if err != nil {
  229. return tracerr.Wrap(err)
  230. }
  231. *e = EthSignature([]byte(signature))
  232. return nil
  233. }