You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

221 lines
6.4 KiB

  1. package db
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "math/big"
  6. "reflect"
  7. "strings"
  8. "github.com/gobuffalo/packr/v2"
  9. "github.com/hermeznetwork/hermez-node/log"
  10. "github.com/hermeznetwork/tracerr"
  11. "github.com/jmoiron/sqlx"
  12. migrate "github.com/rubenv/sql-migrate"
  13. "github.com/russross/meddler"
  14. )
  15. var migrations *migrate.PackrMigrationSource
  16. func init() {
  17. migrations = &migrate.PackrMigrationSource{
  18. Box: packr.New("hermez-db-migrations", "./migrations"),
  19. }
  20. ms, err := migrations.FindMigrations()
  21. if err != nil {
  22. panic(err)
  23. }
  24. if len(ms) == 0 {
  25. panic(fmt.Errorf("no SQL migrations found"))
  26. }
  27. }
  28. // MigrationsUp runs the SQL migrations Up
  29. func MigrationsUp(db *sql.DB) error {
  30. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
  31. if err != nil {
  32. return tracerr.Wrap(err)
  33. }
  34. log.Info("successfully ran ", nMigrations, " migrations Up")
  35. return nil
  36. }
  37. // MigrationsDown runs the SQL migrations Down
  38. func MigrationsDown(db *sql.DB) error {
  39. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Down)
  40. if err != nil {
  41. return tracerr.Wrap(err)
  42. }
  43. log.Info("successfully ran ", nMigrations, " migrations Down")
  44. return nil
  45. }
  46. // ConnectSQLDB connects to the SQL DB
  47. func ConnectSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  48. // Init meddler
  49. initMeddler()
  50. meddler.Default = meddler.PostgreSQL
  51. // Stablish connection
  52. psqlconn := fmt.Sprintf(
  53. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  54. host,
  55. port,
  56. user,
  57. password,
  58. name,
  59. )
  60. db, err := sqlx.Connect("postgres", psqlconn)
  61. if err != nil {
  62. return nil, tracerr.Wrap(err)
  63. }
  64. return db, nil
  65. }
  66. // InitSQLDB runs migrations and registers meddlers
  67. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  68. db, err := ConnectSQLDB(port, host, user, password, name)
  69. if err != nil {
  70. return nil, tracerr.Wrap(err)
  71. }
  72. // Run DB migrations
  73. if err := MigrationsUp(db.DB); err != nil {
  74. return nil, tracerr.Wrap(err)
  75. }
  76. return db, nil
  77. }
  78. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  79. func initMeddler() {
  80. meddler.Register("bigint", BigIntMeddler{})
  81. meddler.Register("bigintnull", BigIntNullMeddler{})
  82. }
  83. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  84. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  85. // Note that all the columns must be specified in the query, and they must be
  86. // in the same order as in the table.
  87. // Note that the fields in the structs need to be defined in the same order as
  88. // in the table columns.
  89. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  90. arrayValue := reflect.ValueOf(args)
  91. arrayLen := arrayValue.Len()
  92. valueStrings := make([]string, 0, arrayLen)
  93. var arglist = make([]interface{}, 0)
  94. for i := 0; i < arrayLen; i++ {
  95. arg := arrayValue.Index(i).Addr().Interface()
  96. elemArglist, err := meddler.Default.Values(arg, true)
  97. if err != nil {
  98. return tracerr.Wrap(err)
  99. }
  100. arglist = append(arglist, elemArglist...)
  101. value := "("
  102. for j := 0; j < len(elemArglist); j++ {
  103. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  104. }
  105. value = value[:len(value)-2] + ")"
  106. valueStrings = append(valueStrings, value)
  107. }
  108. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  109. _, err := db.Exec(stmt, arglist...)
  110. return tracerr.Wrap(err)
  111. }
  112. // BigIntMeddler encodes or decodes the field value to or from JSON
  113. type BigIntMeddler struct{}
  114. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  115. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  116. // give a pointer to a byte buffer to grab the raw data
  117. return new(string), nil
  118. }
  119. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  120. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  121. ptr := scanTarget.(*string)
  122. if ptr == nil {
  123. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  124. }
  125. field := fieldPtr.(**big.Int)
  126. *field = new(big.Int).SetBytes([]byte(*ptr))
  127. return nil
  128. }
  129. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  130. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  131. field := fieldPtr.(*big.Int)
  132. return field.Bytes(), nil
  133. }
  134. // BigIntNullMeddler encodes or decodes the field value to or from JSON
  135. type BigIntNullMeddler struct{}
  136. // PreRead is called before a Scan operation for fields that have the BigIntNullMeddler
  137. func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  138. return &fieldAddr, nil
  139. }
  140. // PostRead is called after a Scan operation for fields that have the BigIntNullMeddler
  141. func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  142. field := fieldPtr.(**big.Int)
  143. ptrPtr := scanTarget.(*interface{})
  144. if *ptrPtr == nil {
  145. // null column, so set target to be zero value
  146. *field = nil
  147. return nil
  148. }
  149. // not null
  150. ptr := (*ptrPtr).([]byte)
  151. if ptr == nil {
  152. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  153. }
  154. *field = new(big.Int).SetBytes(ptr)
  155. return nil
  156. }
  157. // PreWrite is called before an Insert or Update operation for fields that have the BigIntNullMeddler
  158. func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  159. field := fieldPtr.(*big.Int)
  160. if field == nil {
  161. return nil, nil
  162. }
  163. return field.Bytes(), nil
  164. }
  // SliceToSlicePtrs converts any []Foo into a []*Foo of the same length.
  // Each result pointer is the address of the corresponding element of the
  // input slice, so no element is copied and writes through the result are
  // visible in the input.
  func SliceToSlicePtrs(slice interface{}) interface{} {
  	src := reflect.ValueOf(slice)
  	n := src.Len()
  	ptrSliceType := reflect.SliceOf(reflect.PtrTo(src.Type().Elem()))
  	dst := reflect.MakeSlice(ptrSliceType, n, n)
  	for idx := 0; idx < n; idx++ {
  		elemAddr := src.Index(idx).Addr()
  		dst.Index(idx).Set(elemAddr)
  	}
  	return dst.Interface()
  }
  // SlicePtrsToSlice converts any []*Foo into a []Foo of the same length by
  // copying each pointed-to value. Later changes through the original
  // pointers are not reflected in the result. A nil element makes reflect
  // panic.
  func SlicePtrsToSlice(slice interface{}) interface{} {
  	src := reflect.ValueOf(slice)
  	n := src.Len()
  	valueSliceType := reflect.SliceOf(src.Type().Elem().Elem())
  	dst := reflect.MakeSlice(valueSliceType, n, n)
  	for idx := 0; idx < n; idx++ {
  		pointee := src.Index(idx).Elem()
  		dst.Index(idx).Set(pointee)
  	}
  	return dst.Interface()
  }
  187. // Rollback an sql transaction, and log the error if it's not nil
  188. func Rollback(txn *sqlx.Tx) {
  189. if err := txn.Rollback(); err != nil {
  190. log.Errorw("Rollback", "err", err)
  191. }
  192. }
  193. // RowsClose close the rows of an sql query, and log the errir if it's not nil
  194. func RowsClose(rows *sql.Rows) {
  195. if err := rows.Close(); err != nil {
  196. log.Errorw("rows.Close", "err", err)
  197. }
  198. }