You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

195 lines
5.5 KiB

  1. package apitypes
  2. import (
  3. "database/sql/driver"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "math/big"
  9. "strings"
  10. ethCommon "github.com/ethereum/go-ethereum/common"
  11. "github.com/hermeznetwork/hermez-node/common"
  12. "github.com/iden3/go-iden3-crypto/babyjub"
  13. )
  14. // BigIntStr is used to scan/value *big.Int directly into strings from/to sql DBs.
  15. // It assumes that *big.Int are inserted/fetched to/from the DB using the BigIntMeddler meddler
  16. // defined at github.com/hermeznetwork/hermez-node/db
  17. type BigIntStr string
  18. // NewBigIntStr creates a *BigIntStr from a *big.Int.
  19. // If the provided bigInt is nil the returned *BigIntStr will also be nil
  20. func NewBigIntStr(bigInt *big.Int) *BigIntStr {
  21. if bigInt == nil {
  22. return nil
  23. }
  24. bigIntStr := BigIntStr(bigInt.String())
  25. return &bigIntStr
  26. }
  27. // Scan implements Scanner for database/sql
  28. func (b *BigIntStr) Scan(src interface{}) error {
  29. // decode base64 src
  30. var decoded []byte
  31. var err error
  32. if srcStr, ok := src.(string); ok {
  33. // src is a string
  34. decoded, err = base64.StdEncoding.DecodeString(srcStr)
  35. } else if srcBytes, ok := src.([]byte); ok {
  36. // src is []byte
  37. decoded, err = base64.StdEncoding.DecodeString(string(srcBytes))
  38. } else {
  39. // unexpected src
  40. return fmt.Errorf("can't scan %T into apitypes.BigIntStr", src)
  41. }
  42. if err != nil {
  43. return err
  44. }
  45. // decoded bytes to *big.Int
  46. bigInt := &big.Int{}
  47. bigInt = bigInt.SetBytes(decoded)
  48. // *big.Int to BigIntStr
  49. bigIntStr := NewBigIntStr(bigInt)
  50. if bigIntStr == nil {
  51. return nil
  52. }
  53. *b = *bigIntStr
  54. return nil
  55. }
  56. // Value implements valuer for database/sql
  57. func (b BigIntStr) Value() (driver.Value, error) {
  58. // string to *big.Int
  59. bigInt := &big.Int{}
  60. bigInt, ok := bigInt.SetString(string(b), 10)
  61. if !ok || bigInt == nil {
  62. return nil, errors.New("invalid representation of a *big.Int")
  63. }
  64. // *big.Int to base64
  65. return base64.StdEncoding.EncodeToString(bigInt.Bytes()), nil
  66. }
  67. // CollectedFees is used to retrieve common.batch.CollectedFee from the DB
  68. type CollectedFees map[common.TokenID]BigIntStr
  69. // UnmarshalJSON unmarshals a json representation of map[common.TokenID]*big.Int
  70. func (c *CollectedFees) UnmarshalJSON(text []byte) error {
  71. bigIntMap := make(map[common.TokenID]*big.Int)
  72. if err := json.Unmarshal(text, &bigIntMap); err != nil {
  73. return err
  74. }
  75. bStrMap := make(map[common.TokenID]BigIntStr)
  76. for k, v := range bigIntMap {
  77. bStr := NewBigIntStr(v)
  78. bStrMap[k] = *bStr
  79. }
  80. *c = CollectedFees(bStrMap)
  81. return nil
  82. }
  83. // HezEthAddr is used to scan/value Ethereum Address directly into strings that follow the Ethereum address hez fotmat (^hez:0x[a-fA-F0-9]{40}$) from/to sql DBs.
  84. // It assumes that Ethereum Address are inserted/fetched to/from the DB using the default Scan/Value interface
  85. type HezEthAddr string
  86. // NewHezEthAddr creates a HezEthAddr from an Ethereum addr
  87. func NewHezEthAddr(addr ethCommon.Address) HezEthAddr {
  88. return HezEthAddr("hez:" + addr.String())
  89. }
  90. // ToEthAddr returns an Ethereum Address created from HezEthAddr
  91. func (a HezEthAddr) ToEthAddr() (ethCommon.Address, error) {
  92. addrStr := strings.TrimPrefix(string(a), "hez:")
  93. var addr ethCommon.Address
  94. return addr, addr.UnmarshalText([]byte(addrStr))
  95. }
  96. // Scan implements Scanner for database/sql
  97. func (a *HezEthAddr) Scan(src interface{}) error {
  98. ethAddr := &ethCommon.Address{}
  99. if err := ethAddr.Scan(src); err != nil {
  100. return err
  101. }
  102. if ethAddr == nil {
  103. return nil
  104. }
  105. *a = NewHezEthAddr(*ethAddr)
  106. return nil
  107. }
  108. // Value implements valuer for database/sql
  109. func (a HezEthAddr) Value() (driver.Value, error) {
  110. ethAddr, err := a.ToEthAddr()
  111. if err != nil {
  112. return nil, err
  113. }
  114. return ethAddr.Value()
  115. }
  116. // HezBJJ is used to scan/value *babyjub.PublicKey directly into strings that follow the BJJ public key hez fotmat (^hez:[A-Za-z0-9_-]{44}$) from/to sql DBs.
  117. // It assumes that *babyjub.PublicKey are inserted/fetched to/from the DB using the default Scan/Value interface
  118. type HezBJJ string
  119. // NewHezBJJ creates a HezBJJ from a *babyjub.PublicKey.
  120. // Calling this method with a nil bjj causes panic
  121. func NewHezBJJ(bjj *babyjub.PublicKey) HezBJJ {
  122. pkComp := [32]byte(bjj.Compress())
  123. sum := pkComp[0]
  124. for i := 1; i < len(pkComp); i++ {
  125. sum += pkComp[i]
  126. }
  127. bjjSum := append(pkComp[:], sum)
  128. return HezBJJ("hez:" + base64.RawURLEncoding.EncodeToString(bjjSum))
  129. }
  130. // ToBJJ returns a *babyjub.PublicKey created from HezBJJ
  131. func (b HezBJJ) ToBJJ() (*babyjub.PublicKey, error) {
  132. const decodedLen = 33
  133. const encodedLen = 44
  134. formatErr := errors.New("invalid BJJ format. Must follow this regex: ^hez:[A-Za-z0-9_-]{44}$")
  135. encoded := strings.TrimPrefix(string(b), "hez:")
  136. if len(encoded) != encodedLen {
  137. return nil, formatErr
  138. }
  139. decoded, err := base64.RawURLEncoding.DecodeString(encoded)
  140. if err != nil {
  141. return nil, formatErr
  142. }
  143. if len(decoded) != decodedLen {
  144. return nil, formatErr
  145. }
  146. bjjBytes := [decodedLen - 1]byte{}
  147. copy(bjjBytes[:decodedLen-1], decoded[:decodedLen-1])
  148. sum := bjjBytes[0]
  149. for i := 1; i < len(bjjBytes); i++ {
  150. sum += bjjBytes[i]
  151. }
  152. if decoded[decodedLen-1] != sum {
  153. return nil, errors.New("checksum verification failed")
  154. }
  155. bjjComp := babyjub.PublicKeyComp(bjjBytes)
  156. return bjjComp.Decompress()
  157. }
  158. // Scan implements Scanner for database/sql
  159. func (b *HezBJJ) Scan(src interface{}) error {
  160. bjj := &babyjub.PublicKey{}
  161. if err := bjj.Scan(src); err != nil {
  162. return err
  163. }
  164. if bjj == nil {
  165. return nil
  166. }
  167. *b = NewHezBJJ(bjj)
  168. return nil
  169. }
  170. // Value implements valuer for database/sql
  171. func (b HezBJJ) Value() (driver.Value, error) {
  172. bjj, err := b.ToBJJ()
  173. if err != nil {
  174. return nil, err
  175. }
  176. return bjj.Value()
  177. }