You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

289 lines
8.6 KiB

  1. package apitypes
  2. import (
  3. "database/sql/driver"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "errors"
  7. "fmt"
  8. "math/big"
  9. "strconv"
  10. "strings"
  11. ethCommon "github.com/ethereum/go-ethereum/common"
  12. "github.com/hermeznetwork/hermez-node/common"
  13. "github.com/hermeznetwork/tracerr"
  14. "github.com/iden3/go-iden3-crypto/babyjub"
  15. )
  // BigIntStr holds the base-10 string representation of a *big.Int so it can be
  // scanned from / valued into sql DBs directly as a string.
  // It assumes that *big.Int are inserted/fetched to/from the DB using the BigIntMeddler meddler
  // defined at github.com/hermeznetwork/hermez-node/db
  type BigIntStr string

  // NewBigIntStr converts bigInt into its base-10 string form wrapped as a
  // *BigIntStr. A nil bigInt yields a nil result.
  func NewBigIntStr(bigInt *big.Int) *BigIntStr {
  	var result *BigIntStr
  	if bigInt != nil {
  		decimal := BigIntStr(bigInt.String())
  		result = &decimal
  	}
  	return result
  }
  29. // Scan implements Scanner for database/sql
  30. func (b *BigIntStr) Scan(src interface{}) error {
  31. srcBytes, ok := src.([]byte)
  32. if !ok {
  33. return tracerr.Wrap(fmt.Errorf("can't scan %T into apitypes.BigIntStr", src))
  34. }
  35. // bytes to *big.Int
  36. bigInt := new(big.Int).SetBytes(srcBytes)
  37. // *big.Int to BigIntStr
  38. bigIntStr := NewBigIntStr(bigInt)
  39. if bigIntStr == nil {
  40. return nil
  41. }
  42. *b = *bigIntStr
  43. return nil
  44. }
  45. // Value implements valuer for database/sql
  46. func (b BigIntStr) Value() (driver.Value, error) {
  47. // string to *big.Int
  48. bigInt, ok := new(big.Int).SetString(string(b), 10)
  49. if !ok || bigInt == nil {
  50. return nil, tracerr.Wrap(errors.New("invalid representation of a *big.Int"))
  51. }
  52. // *big.Int to bytes
  53. return bigInt.Bytes(), nil
  54. }
  55. // StrBigInt is used to unmarshal BigIntStr directly into an alias of big.Int
  56. type StrBigInt big.Int
  57. // UnmarshalText unmarshals a StrBigInt
  58. func (s *StrBigInt) UnmarshalText(text []byte) error {
  59. bi, ok := (*big.Int)(s).SetString(string(text), 10)
  60. if !ok {
  61. return tracerr.Wrap(fmt.Errorf("could not unmarshal %s into a StrBigInt", text))
  62. }
  63. *s = StrBigInt(*bi)
  64. return nil
  65. }
  66. // CollectedFeesAPI is send common.batch.CollectedFee through the API
  67. type CollectedFeesAPI map[common.TokenID]BigIntStr
  68. // NewCollectedFeesAPI creates a new CollectedFeesAPI from a *big.Int map
  69. func NewCollectedFeesAPI(m map[common.TokenID]*big.Int) CollectedFeesAPI {
  70. c := CollectedFeesAPI(make(map[common.TokenID]BigIntStr))
  71. for k, v := range m {
  72. c[k] = *NewBigIntStr(v)
  73. }
  74. return c
  75. }
  76. // HezEthAddr is used to scan/value Ethereum Address directly into strings that follow the Ethereum address hez format (^hez:0x[a-fA-F0-9]{40}$) from/to sql DBs.
  77. // It assumes that Ethereum Address are inserted/fetched to/from the DB using the default Scan/Value interface
  78. type HezEthAddr string
  79. // NewHezEthAddr creates a HezEthAddr from an Ethereum addr
  80. func NewHezEthAddr(addr ethCommon.Address) HezEthAddr {
  81. return HezEthAddr("hez:" + addr.String())
  82. }
  83. // ToEthAddr returns an Ethereum Address created from HezEthAddr
  84. func (a HezEthAddr) ToEthAddr() (ethCommon.Address, error) {
  85. addrStr := strings.TrimPrefix(string(a), "hez:")
  86. var addr ethCommon.Address
  87. return addr, addr.UnmarshalText([]byte(addrStr))
  88. }
  89. // Scan implements Scanner for database/sql
  90. func (a *HezEthAddr) Scan(src interface{}) error {
  91. ethAddr := &ethCommon.Address{}
  92. if err := ethAddr.Scan(src); err != nil {
  93. return tracerr.Wrap(err)
  94. }
  95. if ethAddr == nil {
  96. return nil
  97. }
  98. *a = NewHezEthAddr(*ethAddr)
  99. return nil
  100. }
  101. // Value implements valuer for database/sql
  102. func (a HezEthAddr) Value() (driver.Value, error) {
  103. ethAddr, err := a.ToEthAddr()
  104. if err != nil {
  105. return nil, tracerr.Wrap(err)
  106. }
  107. return ethAddr.Value()
  108. }
  109. // StrHezEthAddr is used to unmarshal HezEthAddr directly into an alias of ethCommon.Address
  110. type StrHezEthAddr ethCommon.Address
  111. // UnmarshalText unmarshals a StrHezEthAddr
  112. func (s *StrHezEthAddr) UnmarshalText(text []byte) error {
  113. withoutHez := strings.TrimPrefix(string(text), "hez:")
  114. var addr ethCommon.Address
  115. if err := addr.UnmarshalText([]byte(withoutHez)); err != nil {
  116. return tracerr.Wrap(err)
  117. }
  118. *s = StrHezEthAddr(addr)
  119. return nil
  120. }
  121. // HezBJJ is used to scan/value *babyjub.PublicKeyComp directly into strings that follow the BJJ public key hez format (^hez:[A-Za-z0-9_-]{44}$) from/to sql DBs.
  122. // It assumes that *babyjub.PublicKeyComp are inserted/fetched to/from the DB using the default Scan/Value interface
  123. type HezBJJ string
  124. // NewHezBJJ creates a HezBJJ from a *babyjub.PublicKeyComp.
  125. // Calling this method with a nil bjj causes panic
  126. func NewHezBJJ(pkComp babyjub.PublicKeyComp) HezBJJ {
  127. sum := pkComp[0]
  128. for i := 1; i < len(pkComp); i++ {
  129. sum += pkComp[i]
  130. }
  131. bjjSum := append(pkComp[:], sum)
  132. return HezBJJ("hez:" + base64.RawURLEncoding.EncodeToString(bjjSum))
  133. }
  134. func hezStrToBJJ(s string) (babyjub.PublicKeyComp, error) {
  135. const decodedLen = 33
  136. const encodedLen = 44
  137. formatErr := errors.New("invalid BJJ format. Must follow this regex: ^hez:[A-Za-z0-9_-]{44}$")
  138. encoded := strings.TrimPrefix(s, "hez:")
  139. if len(encoded) != encodedLen {
  140. return common.EmptyBJJComp, formatErr
  141. }
  142. decoded, err := base64.RawURLEncoding.DecodeString(encoded)
  143. if err != nil {
  144. return common.EmptyBJJComp, formatErr
  145. }
  146. if len(decoded) != decodedLen {
  147. return common.EmptyBJJComp, formatErr
  148. }
  149. bjjBytes := [decodedLen - 1]byte{}
  150. copy(bjjBytes[:decodedLen-1], decoded[:decodedLen-1])
  151. sum := bjjBytes[0]
  152. for i := 1; i < len(bjjBytes); i++ {
  153. sum += bjjBytes[i]
  154. }
  155. if decoded[decodedLen-1] != sum {
  156. return common.EmptyBJJComp, tracerr.Wrap(errors.New("checksum verification failed"))
  157. }
  158. bjjComp := babyjub.PublicKeyComp(bjjBytes)
  159. return bjjComp, nil
  160. }
  161. // ToBJJ returns a babyjub.PublicKeyComp created from HezBJJ
  162. func (b HezBJJ) ToBJJ() (babyjub.PublicKeyComp, error) {
  163. return hezStrToBJJ(string(b))
  164. }
  165. // Scan implements Scanner for database/sql
  166. func (b *HezBJJ) Scan(src interface{}) error {
  167. bjj := &babyjub.PublicKeyComp{}
  168. if err := bjj.Scan(src); err != nil {
  169. return tracerr.Wrap(err)
  170. }
  171. if bjj == nil {
  172. return nil
  173. }
  174. *b = NewHezBJJ(*bjj)
  175. return nil
  176. }
  177. // Value implements valuer for database/sql
  178. func (b HezBJJ) Value() (driver.Value, error) {
  179. bjj, err := b.ToBJJ()
  180. if err != nil {
  181. return nil, tracerr.Wrap(err)
  182. }
  183. return bjj.Value()
  184. }
  185. // StrHezBJJ is used to unmarshal HezBJJ directly into an alias of babyjub.PublicKeyComp
  186. type StrHezBJJ babyjub.PublicKeyComp
  187. // UnmarshalText unmarshalls a StrHezBJJ
  188. func (s *StrHezBJJ) UnmarshalText(text []byte) error {
  189. bjj, err := hezStrToBJJ(string(text))
  190. if err != nil {
  191. return tracerr.Wrap(err)
  192. }
  193. *s = StrHezBJJ(bjj)
  194. return nil
  195. }
  196. // HezIdx is used to value common.Idx directly into strings that follow the Idx key hez format (hez:tokenSymbol:idx) to sql DBs.
  197. // Note that this can only be used to insert to DB since there is no way to automatically read from the DB since it needs the tokenSymbol
  198. type HezIdx string
  199. // StrHezIdx is used to unmarshal HezIdx directly into an alias of common.Idx
  200. type StrHezIdx common.Idx
  201. // UnmarshalText unmarshals a StrHezIdx
  202. func (s *StrHezIdx) UnmarshalText(text []byte) error {
  203. withoutHez := strings.TrimPrefix(string(text), "hez:")
  204. splitted := strings.Split(withoutHez, ":")
  205. const expectedLen = 2
  206. if len(splitted) != expectedLen {
  207. return tracerr.Wrap(fmt.Errorf("can not unmarshal %s into StrHezIdx", text))
  208. }
  209. idxInt, err := strconv.Atoi(splitted[1])
  210. if err != nil {
  211. return tracerr.Wrap(err)
  212. }
  213. *s = StrHezIdx(common.Idx(idxInt))
  214. return nil
  215. }
  // EthSignature holds an Ethereum signature as a "0x"-prefixed hex string so it
  // can be scanned from / valued into sql DBs directly as a string.
  type EthSignature string

  // NewEthSignature hex-encodes signature with a "0x" prefix and returns it as a
  // *EthSignature. A nil signature yields nil; an empty non-nil slice yields "0x".
  func NewEthSignature(signature []byte) *EthSignature {
  	var result *EthSignature
  	if signature != nil {
  		encoded := EthSignature("0x" + hex.EncodeToString(signature))
  		result = &encoded
  	}
  	return result
  }
  227. // Scan implements Scanner for database/sql
  228. func (e *EthSignature) Scan(src interface{}) error {
  229. if srcStr, ok := src.(string); ok {
  230. // src is a string
  231. *e = *(NewEthSignature([]byte(srcStr)))
  232. return nil
  233. } else if srcBytes, ok := src.([]byte); ok {
  234. // src is []byte
  235. *e = *(NewEthSignature(srcBytes))
  236. return nil
  237. } else {
  238. // unexpected src
  239. return tracerr.Wrap(fmt.Errorf("can't scan %T into apitypes.EthSignature", src))
  240. }
  241. }
  242. // Value implements valuer for database/sql
  243. func (e EthSignature) Value() (driver.Value, error) {
  244. without0x := strings.TrimPrefix(string(e), "0x")
  245. return hex.DecodeString(without0x)
  246. }
  247. // UnmarshalText unmarshals a StrEthSignature
  248. func (e *EthSignature) UnmarshalText(text []byte) error {
  249. without0x := strings.TrimPrefix(string(text), "0x")
  250. signature, err := hex.DecodeString(without0x)
  251. if err != nil {
  252. return tracerr.Wrap(err)
  253. }
  254. *e = EthSignature([]byte(signature))
  255. return nil
  256. }