You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

179 lines
5.4 KiB

  1. package db
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "math/big"
  6. "reflect"
  7. "strings"
  8. "github.com/gobuffalo/packr/v2"
  9. "github.com/hermeznetwork/hermez-node/log"
  10. "github.com/jmoiron/sqlx"
  11. migrate "github.com/rubenv/sql-migrate"
  12. "github.com/russross/meddler"
  13. )
  14. // InitSQLDB runs migrations and registers meddlers
  15. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  16. // Init meddler
  17. initMeddler()
  18. meddler.Default = meddler.PostgreSQL
  19. // Stablish connection
  20. psqlconn := fmt.Sprintf(
  21. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  22. host,
  23. port,
  24. user,
  25. password,
  26. name,
  27. )
  28. db, err := sqlx.Connect("postgres", psqlconn)
  29. if err != nil {
  30. return nil, err
  31. }
  32. // Run DB migrations
  33. migrations := &migrate.PackrMigrationSource{
  34. Box: packr.New("hermez-db-migrations", "./migrations"),
  35. }
  36. nMigrations, err := migrate.Exec(db.DB, "postgres", migrations, migrate.Up)
  37. if err != nil {
  38. return nil, err
  39. }
  40. log.Info("successfully ran ", nMigrations, " migrations")
  41. return db, nil
  42. }
  43. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  44. func initMeddler() {
  45. meddler.Register("bigint", BigIntMeddler{})
  46. meddler.Register("bigintnull", BigIntNullMeddler{})
  47. }
  48. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  49. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  50. // Note that all the columns must be specified in the query, and they must be
  51. // in the same order as in the table.
  52. // Note that the fields in the structs need to be defined in the same order as
  53. // in the table columns.
  54. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  55. arrayValue := reflect.ValueOf(args)
  56. arrayLen := arrayValue.Len()
  57. valueStrings := make([]string, 0, arrayLen)
  58. var arglist = make([]interface{}, 0)
  59. for i := 0; i < arrayLen; i++ {
  60. arg := arrayValue.Index(i).Addr().Interface()
  61. elemArglist, err := meddler.Default.Values(arg, true)
  62. if err != nil {
  63. return err
  64. }
  65. arglist = append(arglist, elemArglist...)
  66. value := "("
  67. for j := 0; j < len(elemArglist); j++ {
  68. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  69. }
  70. value = value[:len(value)-2] + ")"
  71. valueStrings = append(valueStrings, value)
  72. }
  73. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  74. _, err := db.Exec(stmt, arglist...)
  75. return err
  76. }
  77. // BigIntMeddler encodes or decodes the field value to or from JSON
  78. type BigIntMeddler struct{}
  79. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  80. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  81. // give a pointer to a byte buffer to grab the raw data
  82. return new(string), nil
  83. }
  84. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  85. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  86. ptr := scanTarget.(*string)
  87. if ptr == nil {
  88. return fmt.Errorf("BigIntMeddler.PostRead: nil pointer")
  89. }
  90. data, err := base64.StdEncoding.DecodeString(*ptr)
  91. if err != nil {
  92. return fmt.Errorf("big.Int decode error: %v", err)
  93. }
  94. field := fieldPtr.(**big.Int)
  95. *field = new(big.Int).SetBytes(data)
  96. return nil
  97. }
  98. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  99. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  100. field := fieldPtr.(*big.Int)
  101. str := base64.StdEncoding.EncodeToString(field.Bytes())
  102. return str, nil
  103. }
  104. // BigIntNullMeddler encodes or decodes the field value to or from JSON
  105. type BigIntNullMeddler struct{}
  106. // PreRead is called before a Scan operation for fields that have the BigIntNullMeddler
  107. func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  108. return &fieldAddr, nil
  109. }
  110. // PostRead is called after a Scan operation for fields that have the BigIntNullMeddler
  111. func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  112. sv := reflect.ValueOf(scanTarget)
  113. if sv.Elem().IsNil() {
  114. // null column, so set target to be zero value
  115. fv := reflect.ValueOf(fieldPtr)
  116. fv.Elem().Set(reflect.Zero(fv.Elem().Type()))
  117. return nil
  118. }
  119. // not null
  120. encoded := new([]byte)
  121. refEnc := reflect.ValueOf(encoded)
  122. refEnc.Elem().Set(sv.Elem().Elem())
  123. data, err := base64.StdEncoding.DecodeString(string(*encoded))
  124. if err != nil {
  125. return fmt.Errorf("big.Int decode error: %v", err)
  126. }
  127. field := fieldPtr.(**big.Int)
  128. *field = new(big.Int).SetBytes(data)
  129. return nil
  130. }
  131. // PreWrite is called before an Insert or Update operation for fields that have the BigIntNullMeddler
  132. func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  133. field := fieldPtr.(*big.Int)
  134. if field == nil {
  135. return nil, nil
  136. }
  137. return base64.StdEncoding.EncodeToString(field.Bytes()), nil
  138. }
  139. // SliceToSlicePtrs converts any []Foo to []*Foo
  140. func SliceToSlicePtrs(slice interface{}) interface{} {
  141. v := reflect.ValueOf(slice)
  142. vLen := v.Len()
  143. typ := v.Type().Elem()
  144. res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen)
  145. for i := 0; i < vLen; i++ {
  146. res.Index(i).Set(v.Index(i).Addr())
  147. }
  148. return res.Interface()
  149. }
  150. // SlicePtrsToSlice converts any []*Foo to []Foo
  151. func SlicePtrsToSlice(slice interface{}) interface{} {
  152. v := reflect.ValueOf(slice)
  153. vLen := v.Len()
  154. typ := v.Type().Elem().Elem()
  155. res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen)
  156. for i := 0; i < vLen; i++ {
  157. res.Index(i).Set(v.Index(i).Elem())
  158. }
  159. return res.Interface()
  160. }