You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

188 lines
5.6 KiB

  1. package db
  2. import (
  3. "database/sql"
  4. "encoding/base64"
  5. "fmt"
  6. "math/big"
  7. "reflect"
  8. "strings"
  9. "github.com/gobuffalo/packr/v2"
  10. "github.com/hermeznetwork/hermez-node/log"
  11. "github.com/jmoiron/sqlx"
  12. migrate "github.com/rubenv/sql-migrate"
  13. "github.com/russross/meddler"
  14. )
  15. // InitSQLDB runs migrations and registers meddlers
  16. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  17. // Init meddler
  18. initMeddler()
  19. meddler.Default = meddler.PostgreSQL
  20. // Stablish connection
  21. psqlconn := fmt.Sprintf(
  22. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  23. host,
  24. port,
  25. user,
  26. password,
  27. name,
  28. )
  29. db, err := sqlx.Connect("postgres", psqlconn)
  30. if err != nil {
  31. return nil, err
  32. }
  33. // Run DB migrations
  34. migrations := &migrate.PackrMigrationSource{
  35. Box: packr.New("hermez-db-migrations", "./migrations"),
  36. }
  37. nMigrations, err := migrate.Exec(db.DB, "postgres", migrations, migrate.Up)
  38. if err != nil {
  39. return nil, err
  40. }
  41. log.Info("successfully ran ", nMigrations, " migrations")
  42. return db, nil
  43. }
  44. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  45. func initMeddler() {
  46. meddler.Register("bigint", BigIntMeddler{})
  47. meddler.Register("bigintnull", BigIntNullMeddler{})
  48. }
  49. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  50. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  51. // Note that all the columns must be specified in the query, and they must be
  52. // in the same order as in the table.
  53. // Note that the fields in the structs need to be defined in the same order as
  54. // in the table columns.
  55. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  56. arrayValue := reflect.ValueOf(args)
  57. arrayLen := arrayValue.Len()
  58. valueStrings := make([]string, 0, arrayLen)
  59. var arglist = make([]interface{}, 0)
  60. for i := 0; i < arrayLen; i++ {
  61. arg := arrayValue.Index(i).Addr().Interface()
  62. elemArglist, err := meddler.Default.Values(arg, true)
  63. if err != nil {
  64. return err
  65. }
  66. arglist = append(arglist, elemArglist...)
  67. value := "("
  68. for j := 0; j < len(elemArglist); j++ {
  69. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  70. }
  71. value = value[:len(value)-2] + ")"
  72. valueStrings = append(valueStrings, value)
  73. }
  74. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  75. _, err := db.Exec(stmt, arglist...)
  76. return err
  77. }
  78. // BigIntMeddler encodes or decodes the field value to or from JSON
  79. type BigIntMeddler struct{}
  80. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  81. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  82. // give a pointer to a byte buffer to grab the raw data
  83. return new(string), nil
  84. }
  85. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  86. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  87. ptr := scanTarget.(*string)
  88. if ptr == nil {
  89. return fmt.Errorf("BigIntMeddler.PostRead: nil pointer")
  90. }
  91. data, err := base64.StdEncoding.DecodeString(*ptr)
  92. if err != nil {
  93. return fmt.Errorf("big.Int decode error: %v", err)
  94. }
  95. field := fieldPtr.(**big.Int)
  96. *field = new(big.Int).SetBytes(data)
  97. return nil
  98. }
  99. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  100. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  101. field := fieldPtr.(*big.Int)
  102. str := base64.StdEncoding.EncodeToString(field.Bytes())
  103. return str, nil
  104. }
  105. // BigIntNullMeddler encodes or decodes the field value to or from JSON
  106. type BigIntNullMeddler struct{}
  107. // PreRead is called before a Scan operation for fields that have the BigIntNullMeddler
  108. func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  109. return &fieldAddr, nil
  110. }
  111. // PostRead is called after a Scan operation for fields that have the BigIntNullMeddler
  112. func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  113. sv := reflect.ValueOf(scanTarget)
  114. if sv.Elem().IsNil() {
  115. // null column, so set target to be zero value
  116. fv := reflect.ValueOf(fieldPtr)
  117. fv.Elem().Set(reflect.Zero(fv.Elem().Type()))
  118. return nil
  119. }
  120. // not null
  121. encoded := new([]byte)
  122. refEnc := reflect.ValueOf(encoded)
  123. refEnc.Elem().Set(sv.Elem().Elem())
  124. data, err := base64.StdEncoding.DecodeString(string(*encoded))
  125. if err != nil {
  126. return fmt.Errorf("big.Int decode error: %v", err)
  127. }
  128. field := fieldPtr.(**big.Int)
  129. *field = new(big.Int).SetBytes(data)
  130. return nil
  131. }
  132. // PreWrite is called before an Insert or Update operation for fields that have the BigIntNullMeddler
  133. func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  134. field := fieldPtr.(*big.Int)
  135. if field == nil {
  136. return nil, nil
  137. }
  138. return base64.StdEncoding.EncodeToString(field.Bytes()), nil
  139. }
  140. // SliceToSlicePtrs converts any []Foo to []*Foo
  141. func SliceToSlicePtrs(slice interface{}) interface{} {
  142. v := reflect.ValueOf(slice)
  143. vLen := v.Len()
  144. typ := v.Type().Elem()
  145. res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen)
  146. for i := 0; i < vLen; i++ {
  147. res.Index(i).Set(v.Index(i).Addr())
  148. }
  149. return res.Interface()
  150. }
  151. // SlicePtrsToSlice converts any []*Foo to []Foo
  152. func SlicePtrsToSlice(slice interface{}) interface{} {
  153. v := reflect.ValueOf(slice)
  154. vLen := v.Len()
  155. typ := v.Type().Elem().Elem()
  156. res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen)
  157. for i := 0; i < vLen; i++ {
  158. res.Index(i).Set(v.Index(i).Elem())
  159. }
  160. return res.Interface()
  161. }
  162. // Rollback an sql transaction, and log the error if it's not nil
  163. func Rollback(txn *sql.Tx) {
  164. err := txn.Rollback()
  165. if err != nil {
  166. log.Errorw("Rollback", "err", err)
  167. }
  168. }