You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

263 lines
7.5 KiB

  1. package apitypes
  2. import (
  3. "database/sql/driver"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "math/big"
  9. "strconv"
  10. "strings"
  11. ethCommon "github.com/ethereum/go-ethereum/common"
  12. "github.com/hermeznetwork/hermez-node/common"
  13. "github.com/iden3/go-iden3-crypto/babyjub"
  14. )
  15. // BigIntStr is used to scan/value *big.Int directly into strings from/to sql DBs.
  16. // It assumes that *big.Int are inserted/fetched to/from the DB using the BigIntMeddler meddler
  17. // defined at github.com/hermeznetwork/hermez-node/db
  18. type BigIntStr string
  19. // NewBigIntStr creates a *BigIntStr from a *big.Int.
  20. // If the provided bigInt is nil the returned *BigIntStr will also be nil
  21. func NewBigIntStr(bigInt *big.Int) *BigIntStr {
  22. if bigInt == nil {
  23. return nil
  24. }
  25. bigIntStr := BigIntStr(bigInt.String())
  26. return &bigIntStr
  27. }
  28. // Scan implements Scanner for database/sql
  29. func (b *BigIntStr) Scan(src interface{}) error {
  30. // decode base64 src
  31. var decoded []byte
  32. var err error
  33. if srcStr, ok := src.(string); ok {
  34. // src is a string
  35. decoded, err = base64.StdEncoding.DecodeString(srcStr)
  36. } else if srcBytes, ok := src.([]byte); ok {
  37. // src is []byte
  38. decoded, err = base64.StdEncoding.DecodeString(string(srcBytes))
  39. } else {
  40. // unexpected src
  41. return fmt.Errorf("can't scan %T into apitypes.BigIntStr", src)
  42. }
  43. if err != nil {
  44. return err
  45. }
  46. // decoded bytes to *big.Int
  47. bigInt := &big.Int{}
  48. bigInt = bigInt.SetBytes(decoded)
  49. // *big.Int to BigIntStr
  50. bigIntStr := NewBigIntStr(bigInt)
  51. if bigIntStr == nil {
  52. return nil
  53. }
  54. *b = *bigIntStr
  55. return nil
  56. }
  57. // Value implements valuer for database/sql
  58. func (b BigIntStr) Value() (driver.Value, error) {
  59. // string to *big.Int
  60. bigInt := &big.Int{}
  61. bigInt, ok := bigInt.SetString(string(b), 10)
  62. if !ok || bigInt == nil {
  63. return nil, errors.New("invalid representation of a *big.Int")
  64. }
  65. // *big.Int to base64
  66. return base64.StdEncoding.EncodeToString(bigInt.Bytes()), nil
  67. }
  68. // StrBigInt is used to unmarshal BigIntStr directly into an alias of big.Int
  69. type StrBigInt big.Int
  70. // UnmarshalText unmarshals a StrBigInt
  71. func (s *StrBigInt) UnmarshalText(text []byte) error {
  72. bi, ok := (*big.Int)(s).SetString(string(text), 10)
  73. if !ok {
  74. return fmt.Errorf("could not unmarshal %s into a StrBigInt", text)
  75. }
  76. *s = StrBigInt(*bi)
  77. return nil
  78. }
  79. // CollectedFees is used to retrieve common.batch.CollectedFee from the DB
  80. type CollectedFees map[common.TokenID]BigIntStr
  81. // UnmarshalJSON unmarshals a json representation of map[common.TokenID]*big.Int
  82. func (c *CollectedFees) UnmarshalJSON(text []byte) error {
  83. bigIntMap := make(map[common.TokenID]*big.Int)
  84. if err := json.Unmarshal(text, &bigIntMap); err != nil {
  85. return err
  86. }
  87. bStrMap := make(map[common.TokenID]BigIntStr)
  88. for k, v := range bigIntMap {
  89. bStr := NewBigIntStr(v)
  90. bStrMap[k] = *bStr
  91. }
  92. *c = CollectedFees(bStrMap)
  93. return nil
  94. }
  95. // HezEthAddr is used to scan/value Ethereum Address directly into strings that follow the Ethereum address hez fotmat (^hez:0x[a-fA-F0-9]{40}$) from/to sql DBs.
  96. // It assumes that Ethereum Address are inserted/fetched to/from the DB using the default Scan/Value interface
  97. type HezEthAddr string
  98. // NewHezEthAddr creates a HezEthAddr from an Ethereum addr
  99. func NewHezEthAddr(addr ethCommon.Address) HezEthAddr {
  100. return HezEthAddr("hez:" + addr.String())
  101. }
  102. // ToEthAddr returns an Ethereum Address created from HezEthAddr
  103. func (a HezEthAddr) ToEthAddr() (ethCommon.Address, error) {
  104. addrStr := strings.TrimPrefix(string(a), "hez:")
  105. var addr ethCommon.Address
  106. return addr, addr.UnmarshalText([]byte(addrStr))
  107. }
  108. // Scan implements Scanner for database/sql
  109. func (a *HezEthAddr) Scan(src interface{}) error {
  110. ethAddr := &ethCommon.Address{}
  111. if err := ethAddr.Scan(src); err != nil {
  112. return err
  113. }
  114. if ethAddr == nil {
  115. return nil
  116. }
  117. *a = NewHezEthAddr(*ethAddr)
  118. return nil
  119. }
  120. // Value implements valuer for database/sql
  121. func (a HezEthAddr) Value() (driver.Value, error) {
  122. ethAddr, err := a.ToEthAddr()
  123. if err != nil {
  124. return nil, err
  125. }
  126. return ethAddr.Value()
  127. }
  128. // StrHezEthAddr is used to unmarshal HezEthAddr directly into an alias of ethCommon.Address
  129. type StrHezEthAddr ethCommon.Address
  130. // UnmarshalText unmarshals a StrHezEthAddr
  131. func (s *StrHezEthAddr) UnmarshalText(text []byte) error {
  132. withoutHez := strings.TrimPrefix(string(text), "hez:")
  133. var addr ethCommon.Address
  134. if err := addr.UnmarshalText([]byte(withoutHez)); err != nil {
  135. return err
  136. }
  137. *s = StrHezEthAddr(addr)
  138. return nil
  139. }
  140. // HezBJJ is used to scan/value *babyjub.PublicKey directly into strings that follow the BJJ public key hez fotmat (^hez:[A-Za-z0-9_-]{44}$) from/to sql DBs.
  141. // It assumes that *babyjub.PublicKey are inserted/fetched to/from the DB using the default Scan/Value interface
  142. type HezBJJ string
  143. // NewHezBJJ creates a HezBJJ from a *babyjub.PublicKey.
  144. // Calling this method with a nil bjj causes panic
  145. func NewHezBJJ(bjj *babyjub.PublicKey) HezBJJ {
  146. pkComp := [32]byte(bjj.Compress())
  147. sum := pkComp[0]
  148. for i := 1; i < len(pkComp); i++ {
  149. sum += pkComp[i]
  150. }
  151. bjjSum := append(pkComp[:], sum)
  152. return HezBJJ("hez:" + base64.RawURLEncoding.EncodeToString(bjjSum))
  153. }
  154. func hezStrToBJJ(s string) (*babyjub.PublicKey, error) {
  155. const decodedLen = 33
  156. const encodedLen = 44
  157. formatErr := errors.New("invalid BJJ format. Must follow this regex: ^hez:[A-Za-z0-9_-]{44}$")
  158. encoded := strings.TrimPrefix(s, "hez:")
  159. if len(encoded) != encodedLen {
  160. return nil, formatErr
  161. }
  162. decoded, err := base64.RawURLEncoding.DecodeString(encoded)
  163. if err != nil {
  164. return nil, formatErr
  165. }
  166. if len(decoded) != decodedLen {
  167. return nil, formatErr
  168. }
  169. bjjBytes := [decodedLen - 1]byte{}
  170. copy(bjjBytes[:decodedLen-1], decoded[:decodedLen-1])
  171. sum := bjjBytes[0]
  172. for i := 1; i < len(bjjBytes); i++ {
  173. sum += bjjBytes[i]
  174. }
  175. if decoded[decodedLen-1] != sum {
  176. return nil, errors.New("checksum verification failed")
  177. }
  178. bjjComp := babyjub.PublicKeyComp(bjjBytes)
  179. return bjjComp.Decompress()
  180. }
  181. // ToBJJ returns a *babyjub.PublicKey created from HezBJJ
  182. func (b HezBJJ) ToBJJ() (*babyjub.PublicKey, error) {
  183. return hezStrToBJJ(string(b))
  184. }
  185. // Scan implements Scanner for database/sql
  186. func (b *HezBJJ) Scan(src interface{}) error {
  187. bjj := &babyjub.PublicKey{}
  188. if err := bjj.Scan(src); err != nil {
  189. return err
  190. }
  191. if bjj == nil {
  192. return nil
  193. }
  194. *b = NewHezBJJ(bjj)
  195. return nil
  196. }
  197. // Value implements valuer for database/sql
  198. func (b HezBJJ) Value() (driver.Value, error) {
  199. bjj, err := b.ToBJJ()
  200. if err != nil {
  201. return nil, err
  202. }
  203. return bjj.Value()
  204. }
  205. // StrHezBJJ is used to unmarshal HezBJJ directly into an alias of babyjub.PublicKey
  206. type StrHezBJJ babyjub.PublicKey
  207. // UnmarshalText unmarshals a StrHezBJJ
  208. func (s *StrHezBJJ) UnmarshalText(text []byte) error {
  209. bjj, err := hezStrToBJJ(string(text))
  210. if err != nil {
  211. return err
  212. }
  213. *s = StrHezBJJ(*bjj)
  214. return nil
  215. }
  216. // HezIdx is used to value common.Idx directly into strings that follow the Idx key hez fotmat (hez:tokenSymbol:idx) to sql DBs.
  217. // Note that this can only be used to insert to DB since there is no way to automaticaly read from the DB since it needs the tokenSymbol
  218. type HezIdx string
  219. // StrHezIdx is used to unmarshal HezIdx directly into an alias of common.Idx
  220. type StrHezIdx common.Idx
  221. // UnmarshalText unmarshals a StrHezIdx
  222. func (s *StrHezIdx) UnmarshalText(text []byte) error {
  223. withoutHez := strings.TrimPrefix(string(text), "hez:")
  224. splitted := strings.Split(withoutHez, ":")
  225. const expectedLen = 2
  226. if len(splitted) != expectedLen {
  227. return fmt.Errorf("can not unmarshal %s into StrHezIdx", text)
  228. }
  229. idxInt, err := strconv.Atoi(splitted[1])
  230. if err != nil {
  231. return err
  232. }
  233. *s = StrHezIdx(common.Idx(idxInt))
  234. return nil
  235. }