You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

310 lines
8.9 KiB

  1. package apitypes
  2. import (
  3. "database/sql/driver"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "math/big"
  10. "strconv"
  11. "strings"
  12. ethCommon "github.com/ethereum/go-ethereum/common"
  13. "github.com/hermeznetwork/hermez-node/common"
  14. "github.com/iden3/go-iden3-crypto/babyjub"
  15. )
  16. // BigIntStr is used to scan/value *big.Int directly into strings from/to sql DBs.
  17. // It assumes that *big.Int are inserted/fetched to/from the DB using the BigIntMeddler meddler
  18. // defined at github.com/hermeznetwork/hermez-node/db
  19. type BigIntStr string
  20. // NewBigIntStr creates a *BigIntStr from a *big.Int.
  21. // If the provided bigInt is nil the returned *BigIntStr will also be nil
  22. func NewBigIntStr(bigInt *big.Int) *BigIntStr {
  23. if bigInt == nil {
  24. return nil
  25. }
  26. bigIntStr := BigIntStr(bigInt.String())
  27. return &bigIntStr
  28. }
  29. // Scan implements Scanner for database/sql
  30. func (b *BigIntStr) Scan(src interface{}) error {
  31. // decode base64 src
  32. var decoded []byte
  33. var err error
  34. if srcStr, ok := src.(string); ok {
  35. // src is a string
  36. decoded, err = base64.StdEncoding.DecodeString(srcStr)
  37. } else if srcBytes, ok := src.([]byte); ok {
  38. // src is []byte
  39. decoded, err = base64.StdEncoding.DecodeString(string(srcBytes))
  40. } else {
  41. // unexpected src
  42. return fmt.Errorf("can't scan %T into apitypes.BigIntStr", src)
  43. }
  44. if err != nil {
  45. return err
  46. }
  47. // decoded bytes to *big.Int
  48. bigInt := &big.Int{}
  49. bigInt = bigInt.SetBytes(decoded)
  50. // *big.Int to BigIntStr
  51. bigIntStr := NewBigIntStr(bigInt)
  52. if bigIntStr == nil {
  53. return nil
  54. }
  55. *b = *bigIntStr
  56. return nil
  57. }
  58. // Value implements valuer for database/sql
  59. func (b BigIntStr) Value() (driver.Value, error) {
  60. // string to *big.Int
  61. bigInt := &big.Int{}
  62. bigInt, ok := bigInt.SetString(string(b), 10)
  63. if !ok || bigInt == nil {
  64. return nil, errors.New("invalid representation of a *big.Int")
  65. }
  66. // *big.Int to base64
  67. return base64.StdEncoding.EncodeToString(bigInt.Bytes()), nil
  68. }
  69. // StrBigInt is used to unmarshal BigIntStr directly into an alias of big.Int
  70. type StrBigInt big.Int
  71. // UnmarshalText unmarshals a StrBigInt
  72. func (s *StrBigInt) UnmarshalText(text []byte) error {
  73. bi, ok := (*big.Int)(s).SetString(string(text), 10)
  74. if !ok {
  75. return fmt.Errorf("could not unmarshal %s into a StrBigInt", text)
  76. }
  77. *s = StrBigInt(*bi)
  78. return nil
  79. }
  80. // CollectedFees is used to retrieve common.batch.CollectedFee from the DB
  81. type CollectedFees map[common.TokenID]BigIntStr
  82. // UnmarshalJSON unmarshals a json representation of map[common.TokenID]*big.Int
  83. func (c *CollectedFees) UnmarshalJSON(text []byte) error {
  84. bigIntMap := make(map[common.TokenID]*big.Int)
  85. if err := json.Unmarshal(text, &bigIntMap); err != nil {
  86. return err
  87. }
  88. bStrMap := make(map[common.TokenID]BigIntStr)
  89. for k, v := range bigIntMap {
  90. bStr := NewBigIntStr(v)
  91. bStrMap[k] = *bStr
  92. }
  93. *c = CollectedFees(bStrMap)
  94. return nil
  95. }
  96. // HezEthAddr is used to scan/value Ethereum Address directly into strings that follow the Ethereum address hez fotmat (^hez:0x[a-fA-F0-9]{40}$) from/to sql DBs.
  97. // It assumes that Ethereum Address are inserted/fetched to/from the DB using the default Scan/Value interface
  98. type HezEthAddr string
  99. // NewHezEthAddr creates a HezEthAddr from an Ethereum addr
  100. func NewHezEthAddr(addr ethCommon.Address) HezEthAddr {
  101. return HezEthAddr("hez:" + addr.String())
  102. }
  103. // ToEthAddr returns an Ethereum Address created from HezEthAddr
  104. func (a HezEthAddr) ToEthAddr() (ethCommon.Address, error) {
  105. addrStr := strings.TrimPrefix(string(a), "hez:")
  106. var addr ethCommon.Address
  107. return addr, addr.UnmarshalText([]byte(addrStr))
  108. }
  109. // Scan implements Scanner for database/sql
  110. func (a *HezEthAddr) Scan(src interface{}) error {
  111. ethAddr := &ethCommon.Address{}
  112. if err := ethAddr.Scan(src); err != nil {
  113. return err
  114. }
  115. if ethAddr == nil {
  116. return nil
  117. }
  118. *a = NewHezEthAddr(*ethAddr)
  119. return nil
  120. }
  121. // Value implements valuer for database/sql
  122. func (a HezEthAddr) Value() (driver.Value, error) {
  123. ethAddr, err := a.ToEthAddr()
  124. if err != nil {
  125. return nil, err
  126. }
  127. return ethAddr.Value()
  128. }
  129. // StrHezEthAddr is used to unmarshal HezEthAddr directly into an alias of ethCommon.Address
  130. type StrHezEthAddr ethCommon.Address
  131. // UnmarshalText unmarshals a StrHezEthAddr
  132. func (s *StrHezEthAddr) UnmarshalText(text []byte) error {
  133. withoutHez := strings.TrimPrefix(string(text), "hez:")
  134. var addr ethCommon.Address
  135. if err := addr.UnmarshalText([]byte(withoutHez)); err != nil {
  136. return err
  137. }
  138. *s = StrHezEthAddr(addr)
  139. return nil
  140. }
  141. // HezBJJ is used to scan/value *babyjub.PublicKey directly into strings that follow the BJJ public key hez fotmat (^hez:[A-Za-z0-9_-]{44}$) from/to sql DBs.
  142. // It assumes that *babyjub.PublicKey are inserted/fetched to/from the DB using the default Scan/Value interface
  143. type HezBJJ string
  144. // NewHezBJJ creates a HezBJJ from a *babyjub.PublicKey.
  145. // Calling this method with a nil bjj causes panic
  146. func NewHezBJJ(bjj *babyjub.PublicKey) HezBJJ {
  147. pkComp := [32]byte(bjj.Compress())
  148. sum := pkComp[0]
  149. for i := 1; i < len(pkComp); i++ {
  150. sum += pkComp[i]
  151. }
  152. bjjSum := append(pkComp[:], sum)
  153. return HezBJJ("hez:" + base64.RawURLEncoding.EncodeToString(bjjSum))
  154. }
  155. func hezStrToBJJ(s string) (*babyjub.PublicKey, error) {
  156. const decodedLen = 33
  157. const encodedLen = 44
  158. formatErr := errors.New("invalid BJJ format. Must follow this regex: ^hez:[A-Za-z0-9_-]{44}$")
  159. encoded := strings.TrimPrefix(s, "hez:")
  160. if len(encoded) != encodedLen {
  161. return nil, formatErr
  162. }
  163. decoded, err := base64.RawURLEncoding.DecodeString(encoded)
  164. if err != nil {
  165. return nil, formatErr
  166. }
  167. if len(decoded) != decodedLen {
  168. return nil, formatErr
  169. }
  170. bjjBytes := [decodedLen - 1]byte{}
  171. copy(bjjBytes[:decodedLen-1], decoded[:decodedLen-1])
  172. sum := bjjBytes[0]
  173. for i := 1; i < len(bjjBytes); i++ {
  174. sum += bjjBytes[i]
  175. }
  176. if decoded[decodedLen-1] != sum {
  177. return nil, errors.New("checksum verification failed")
  178. }
  179. bjjComp := babyjub.PublicKeyComp(bjjBytes)
  180. return bjjComp.Decompress()
  181. }
  182. // ToBJJ returns a *babyjub.PublicKey created from HezBJJ
  183. func (b HezBJJ) ToBJJ() (*babyjub.PublicKey, error) {
  184. return hezStrToBJJ(string(b))
  185. }
  186. // Scan implements Scanner for database/sql
  187. func (b *HezBJJ) Scan(src interface{}) error {
  188. bjj := &babyjub.PublicKey{}
  189. if err := bjj.Scan(src); err != nil {
  190. return err
  191. }
  192. if bjj == nil {
  193. return nil
  194. }
  195. *b = NewHezBJJ(bjj)
  196. return nil
  197. }
  198. // Value implements valuer for database/sql
  199. func (b HezBJJ) Value() (driver.Value, error) {
  200. bjj, err := b.ToBJJ()
  201. if err != nil {
  202. return nil, err
  203. }
  204. return bjj.Value()
  205. }
  206. // StrHezBJJ is used to unmarshal HezBJJ directly into an alias of babyjub.PublicKey
  207. type StrHezBJJ babyjub.PublicKey
  208. // UnmarshalText unmarshals a StrHezBJJ
  209. func (s *StrHezBJJ) UnmarshalText(text []byte) error {
  210. bjj, err := hezStrToBJJ(string(text))
  211. if err != nil {
  212. return err
  213. }
  214. *s = StrHezBJJ(*bjj)
  215. return nil
  216. }
  217. // HezIdx is used to value common.Idx directly into strings that follow the Idx key hez fotmat (hez:tokenSymbol:idx) to sql DBs.
  218. // Note that this can only be used to insert to DB since there is no way to automaticaly read from the DB since it needs the tokenSymbol
  219. type HezIdx string
  220. // StrHezIdx is used to unmarshal HezIdx directly into an alias of common.Idx
  221. type StrHezIdx common.Idx
  222. // UnmarshalText unmarshals a StrHezIdx
  223. func (s *StrHezIdx) UnmarshalText(text []byte) error {
  224. withoutHez := strings.TrimPrefix(string(text), "hez:")
  225. splitted := strings.Split(withoutHez, ":")
  226. const expectedLen = 2
  227. if len(splitted) != expectedLen {
  228. return fmt.Errorf("can not unmarshal %s into StrHezIdx", text)
  229. }
  230. idxInt, err := strconv.Atoi(splitted[1])
  231. if err != nil {
  232. return err
  233. }
  234. *s = StrHezIdx(common.Idx(idxInt))
  235. return nil
  236. }
  237. // EthSignature is used to scan/value []byte representing an Ethereum signature directly into strings from/to sql DBs.
  238. type EthSignature string
  239. // NewEthSignature creates a *EthSignature from []byte
  240. // If the provided signature is nil the returned *EthSignature will also be nil
  241. func NewEthSignature(signature []byte) *EthSignature {
  242. if signature == nil {
  243. return nil
  244. }
  245. ethSignature := EthSignature("0x" + hex.EncodeToString(signature))
  246. return &ethSignature
  247. }
  248. // Scan implements Scanner for database/sql
  249. func (e *EthSignature) Scan(src interface{}) error {
  250. if srcStr, ok := src.(string); ok {
  251. // src is a string
  252. *e = *(NewEthSignature([]byte(srcStr)))
  253. return nil
  254. } else if srcBytes, ok := src.([]byte); ok {
  255. // src is []byte
  256. *e = *(NewEthSignature(srcBytes))
  257. return nil
  258. } else {
  259. // unexpected src
  260. return fmt.Errorf("can't scan %T into apitypes.EthSignature", src)
  261. }
  262. }
  263. // Value implements valuer for database/sql
  264. func (e EthSignature) Value() (driver.Value, error) {
  265. without0x := strings.TrimPrefix(string(e), "0x")
  266. return hex.DecodeString(without0x)
  267. }
  268. // UnmarshalText unmarshals a StrEthSignature
  269. func (e *EthSignature) UnmarshalText(text []byte) error {
  270. without0x := strings.TrimPrefix(string(text), "0x")
  271. signature, err := hex.DecodeString(without0x)
  272. if err != nil {
  273. return err
  274. }
  275. *e = EthSignature([]byte(signature))
  276. return nil
  277. }