You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

174 lines
5.2 KiB

  1. package db
  2. import (
  3. "fmt"
  4. "math/big"
  5. "reflect"
  6. "strings"
  7. "github.com/gobuffalo/packr/v2"
  8. "github.com/hermeznetwork/hermez-node/log"
  9. "github.com/jmoiron/sqlx"
  10. migrate "github.com/rubenv/sql-migrate"
  11. "github.com/russross/meddler"
  12. )
  13. // InitSQLDB runs migrations and registers meddlers
  14. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  15. // Init meddler
  16. initMeddler()
  17. meddler.Default = meddler.PostgreSQL
  18. // Stablish connection
  19. psqlconn := fmt.Sprintf(
  20. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  21. host,
  22. port,
  23. user,
  24. password,
  25. name,
  26. )
  27. db, err := sqlx.Connect("postgres", psqlconn)
  28. if err != nil {
  29. return nil, err
  30. }
  31. // Run DB migrations
  32. migrations := &migrate.PackrMigrationSource{
  33. Box: packr.New("hermez-db-migrations", "./migrations"),
  34. }
  35. nMigrations, err := migrate.Exec(db.DB, "postgres", migrations, migrate.Up)
  36. if err != nil {
  37. return nil, err
  38. }
  39. log.Info("successfully ran ", nMigrations, " migrations")
  40. return db, nil
  41. }
  42. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  43. func initMeddler() {
  44. meddler.Register("bigint", BigIntMeddler{})
  45. meddler.Register("bigintnull", BigIntNullMeddler{})
  46. }
  47. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  48. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  49. // Note that all the columns must be specified in the query, and they must be
  50. // in the same order as in the table.
  51. // Note that the fields in the structs need to be defined in the same order as
  52. // in the table columns.
  53. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  54. arrayValue := reflect.ValueOf(args)
  55. arrayLen := arrayValue.Len()
  56. valueStrings := make([]string, 0, arrayLen)
  57. var arglist = make([]interface{}, 0)
  58. for i := 0; i < arrayLen; i++ {
  59. arg := arrayValue.Index(i).Addr().Interface()
  60. elemArglist, err := meddler.Default.Values(arg, true)
  61. if err != nil {
  62. return err
  63. }
  64. arglist = append(arglist, elemArglist...)
  65. value := "("
  66. for j := 0; j < len(elemArglist); j++ {
  67. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  68. }
  69. value = value[:len(value)-2] + ")"
  70. valueStrings = append(valueStrings, value)
  71. }
  72. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  73. _, err := db.Exec(stmt, arglist...)
  74. return err
  75. }
  // BigIntMeddler stores a *big.Int field as the big-endian bytes of its
// absolute value (big.Int.Bytes) and decodes it back with big.Int.SetBytes.
// The sign is not stored, so a negative value reads back as its absolute
// value. Nil pointers are rejected; use BigIntNullMeddler for nullable columns.
type BigIntMeddler struct{}

// PreRead is called before a Scan operation for fields that have the
// BigIntMeddler. It returns a fresh *string scan target that receives the raw
// column data.
func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
	// give a pointer to a string to grab the raw data
	return new(string), nil
}

// PostRead is called after a Scan operation for fields that have the
// BigIntMeddler. It decodes the raw bytes held in scanTarget (the *string
// returned by PreRead) into the *big.Int that fieldPtr (a **big.Int) points
// to. It returns an error instead of panicking on unexpected argument types.
func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
	ptr, ok := scanTarget.(*string)
	if !ok {
		return fmt.Errorf("BigIntMeddler.PostRead: scan target must be a *string, got %T", scanTarget)
	}
	if ptr == nil {
		return fmt.Errorf("BigIntMeddler.PostRead: nil pointer")
	}
	field, ok := fieldPtr.(**big.Int)
	if !ok || field == nil {
		return fmt.Errorf("BigIntMeddler.PostRead: field must be a non-nil **big.Int, got %T", fieldPtr)
	}
	*field = new(big.Int).SetBytes([]byte(*ptr))
	return nil
}

// PreWrite is called before an Insert or Update operation for fields that have
// the BigIntMeddler. It returns the big-endian bytes of the field's absolute
// value. A nil *big.Int, or a value that is not a *big.Int, yields an error.
func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
	field, ok := fieldPtr.(*big.Int)
	if !ok {
		return nil, fmt.Errorf("BigIntMeddler.PreWrite: field must be a *big.Int, got %T", fieldPtr)
	}
	if field == nil {
		// Bytes() would dereference the nil pointer and panic.
		return nil, fmt.Errorf("BigIntMeddler.PreWrite: nil *big.Int (use the bigintnull tag for nullable columns)")
	}
	return field.Bytes(), nil
}
  // BigIntNullMeddler stores a nullable *big.Int field as the big-endian bytes
// of its absolute value (big.Int.Bytes), writing NULL for a nil pointer and
// reading NULL back as nil. The sign is not stored, so a negative value reads
// back as its absolute value.
type BigIntNullMeddler struct{}

// PreRead is called before a Scan operation for fields that have the
// BigIntNullMeddler. It returns a fresh *interface{} scan target so that the
// raw driver value, including NULL, can be inspected in PostRead.
func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
	return new(interface{}), nil
}

// PostRead is called after a Scan operation for fields that have the
// BigIntNullMeddler. fieldPtr must be a **big.Int and scanTarget the
// *interface{} returned by PreRead. A NULL column sets the field to nil.
// Otherwise the raw value ([]byte, or string for drivers that return text) is
// decoded with big.Int.SetBytes. Unexpected types yield an error, not a panic.
func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
	field, ok := fieldPtr.(**big.Int)
	if !ok || field == nil {
		return fmt.Errorf("BigIntNullMeddler.PostRead: field must be a non-nil **big.Int, got %T", fieldPtr)
	}
	ptrPtr, ok := scanTarget.(*interface{})
	if !ok || ptrPtr == nil {
		return fmt.Errorf("BigIntNullMeddler.PostRead: scan target must be a non-nil *interface{}, got %T", scanTarget)
	}
	switch raw := (*ptrPtr).(type) {
	case nil:
		// null column, so set target to be zero value
		*field = nil
	case []byte:
		if raw == nil {
			return fmt.Errorf("BigIntNullMeddler.PostRead: nil pointer")
		}
		*field = new(big.Int).SetBytes(raw)
	case string:
		*field = new(big.Int).SetBytes([]byte(raw))
	default:
		return fmt.Errorf("BigIntNullMeddler.PostRead: unsupported column value of type %T", raw)
	}
	return nil
}

// PreWrite is called before an Insert or Update operation for fields that have
// the BigIntNullMeddler. A nil *big.Int is written as NULL. Any other
// *big.Int is written as the big-endian bytes of its absolute value. A value
// that is not a *big.Int yields an error.
func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
	field, ok := fieldPtr.(*big.Int)
	if !ok {
		return nil, fmt.Errorf("BigIntNullMeddler.PreWrite: field must be a *big.Int, got %T", fieldPtr)
	}
	if field == nil {
		return nil, nil
	}
	return field.Bytes(), nil
}
  // SliceToSlicePtrs converts any []Foo to []*Foo.
// The result has the same length as the input, and its i-th entry points at
// the i-th element of the input slice, so writes through the pointers are
// visible in the original slice.
func SliceToSlicePtrs(slice interface{}) interface{} {
	values := reflect.ValueOf(slice)
	count := values.Len()
	ptrType := reflect.PtrTo(values.Type().Elem())
	ptrs := reflect.MakeSlice(reflect.SliceOf(ptrType), count, count)
	for idx := 0; idx < count; idx++ {
		elemAddr := values.Index(idx).Addr()
		ptrs.Index(idx).Set(elemAddr)
	}
	return ptrs.Interface()
}
  // SlicePtrsToSlice converts any []*Foo to []Foo.
// Each pointed-to value is copied into a newly allocated slice, so later
// changes through the original pointers are not reflected in the result.
// Elements must be non-nil: reflect panics when setting from the zero Value
// that a nil pointer dereferences to.
func SlicePtrsToSlice(slice interface{}) interface{} {
	ptrs := reflect.ValueOf(slice)
	count := ptrs.Len()
	valueType := ptrs.Type().Elem().Elem()
	values := reflect.MakeSlice(reflect.SliceOf(valueType), count, count)
	for idx := 0; idx < count; idx++ {
		pointee := ptrs.Index(idx).Elem()
		values.Index(idx).Set(pointee)
	}
	return values.Interface()
}
  151. // Rollback an sql transaction, and log the error if it's not nil
  152. func Rollback(txn *sqlx.Tx) {
  153. err := txn.Rollback()
  154. if err != nil {
  155. log.Errorw("Rollback", "err", err)
  156. }
  157. }