You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

250 lines
7.4 KiB

  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "math/big"
  7. "reflect"
  8. "strings"
  9. "time"
  10. "github.com/gobuffalo/packr/v2"
  11. "github.com/hermeznetwork/hermez-node/log"
  12. "github.com/hermeznetwork/tracerr"
  13. "github.com/jmoiron/sqlx"
  14. migrate "github.com/rubenv/sql-migrate"
  15. "github.com/russross/meddler"
  16. "golang.org/x/sync/semaphore"
  17. )
  18. var migrations *migrate.PackrMigrationSource
  19. func init() {
  20. migrations = &migrate.PackrMigrationSource{
  21. Box: packr.New("hermez-db-migrations", "./migrations"),
  22. }
  23. ms, err := migrations.FindMigrations()
  24. if err != nil {
  25. panic(err)
  26. }
  27. if len(ms) == 0 {
  28. panic(fmt.Errorf("no SQL migrations found"))
  29. }
  30. }
  31. // MigrationsUp runs the SQL migrations Up
  32. func MigrationsUp(db *sql.DB) error {
  33. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
  34. if err != nil {
  35. return tracerr.Wrap(err)
  36. }
  37. log.Info("successfully ran ", nMigrations, " migrations Up")
  38. return nil
  39. }
  40. // MigrationsDown runs the SQL migrations Down
  41. func MigrationsDown(db *sql.DB) error {
  42. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Down)
  43. if err != nil {
  44. return tracerr.Wrap(err)
  45. }
  46. log.Info("successfully ran ", nMigrations, " migrations Down")
  47. return nil
  48. }
  49. // ConnectSQLDB connects to the SQL DB
  50. func ConnectSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  51. // Init meddler
  52. initMeddler()
  53. meddler.Default = meddler.PostgreSQL
  54. // Stablish connection
  55. psqlconn := fmt.Sprintf(
  56. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  57. host,
  58. port,
  59. user,
  60. password,
  61. name,
  62. )
  63. db, err := sqlx.Connect("postgres", psqlconn)
  64. if err != nil {
  65. return nil, tracerr.Wrap(err)
  66. }
  67. return db, nil
  68. }
  69. // InitSQLDB runs migrations and registers meddlers
  70. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  71. db, err := ConnectSQLDB(port, host, user, password, name)
  72. if err != nil {
  73. return nil, tracerr.Wrap(err)
  74. }
  75. // Run DB migrations
  76. if err := MigrationsUp(db.DB); err != nil {
  77. return nil, tracerr.Wrap(err)
  78. }
  79. return db, nil
  80. }
  81. // APIConnectionController is used to limit the SQL open connections used by the API
  82. type APIConnectionController struct {
  83. smphr *semaphore.Weighted
  84. timeout time.Duration
  85. }
  86. // NewAPIConnectionController initialize APIConnectionController
  87. func NewAPIConnectionController(maxConnections int, timeout time.Duration) *APIConnectionController {
  88. return &APIConnectionController{
  89. smphr: semaphore.NewWeighted(int64(maxConnections)),
  90. timeout: timeout,
  91. }
  92. }
  93. // Acquire reserves a SQL connection. If the connection is not acquired
  94. // within the timeout, the function will return an error
  95. func (acc *APIConnectionController) Acquire() (context.CancelFunc, error) {
  96. ctx, cancel := context.WithTimeout(context.Background(), acc.timeout) //nolint:govet
  97. return cancel, acc.smphr.Acquire(ctx, 1)
  98. }
  99. // Release frees a SQL connection
  100. func (acc *APIConnectionController) Release() {
  101. acc.smphr.Release(1)
  102. }
  103. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  104. func initMeddler() {
  105. meddler.Register("bigint", BigIntMeddler{})
  106. meddler.Register("bigintnull", BigIntNullMeddler{})
  107. }
  108. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  109. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  110. // Note that all the columns must be specified in the query, and they must be
  111. // in the same order as in the table.
  112. // Note that the fields in the structs need to be defined in the same order as
  113. // in the table columns.
  114. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  115. arrayValue := reflect.ValueOf(args)
  116. arrayLen := arrayValue.Len()
  117. valueStrings := make([]string, 0, arrayLen)
  118. var arglist = make([]interface{}, 0)
  119. for i := 0; i < arrayLen; i++ {
  120. arg := arrayValue.Index(i).Addr().Interface()
  121. elemArglist, err := meddler.Default.Values(arg, true)
  122. if err != nil {
  123. return tracerr.Wrap(err)
  124. }
  125. arglist = append(arglist, elemArglist...)
  126. value := "("
  127. for j := 0; j < len(elemArglist); j++ {
  128. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  129. }
  130. value = value[:len(value)-2] + ")"
  131. valueStrings = append(valueStrings, value)
  132. }
  133. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  134. _, err := db.Exec(stmt, arglist...)
  135. return tracerr.Wrap(err)
  136. }
  137. // BigIntMeddler encodes or decodes the field value to or from JSON
  138. type BigIntMeddler struct{}
  139. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  140. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  141. // give a pointer to a byte buffer to grab the raw data
  142. return new(string), nil
  143. }
  144. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  145. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  146. ptr := scanTarget.(*string)
  147. if ptr == nil {
  148. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  149. }
  150. field := fieldPtr.(**big.Int)
  151. *field = new(big.Int).SetBytes([]byte(*ptr))
  152. return nil
  153. }
  154. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  155. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  156. field := fieldPtr.(*big.Int)
  157. return field.Bytes(), nil
  158. }
  159. // BigIntNullMeddler encodes or decodes the field value to or from JSON
  160. type BigIntNullMeddler struct{}
  161. // PreRead is called before a Scan operation for fields that have the BigIntNullMeddler
  162. func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  163. return &fieldAddr, nil
  164. }
  165. // PostRead is called after a Scan operation for fields that have the BigIntNullMeddler
  166. func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  167. field := fieldPtr.(**big.Int)
  168. ptrPtr := scanTarget.(*interface{})
  169. if *ptrPtr == nil {
  170. // null column, so set target to be zero value
  171. *field = nil
  172. return nil
  173. }
  174. // not null
  175. ptr := (*ptrPtr).([]byte)
  176. if ptr == nil {
  177. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  178. }
  179. *field = new(big.Int).SetBytes(ptr)
  180. return nil
  181. }
  182. // PreWrite is called before an Insert or Update operation for fields that have the BigIntNullMeddler
  183. func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  184. field := fieldPtr.(*big.Int)
  185. if field == nil {
  186. return nil, nil
  187. }
  188. return field.Bytes(), nil
  189. }
  190. // SliceToSlicePtrs converts any []Foo to []*Foo
  191. func SliceToSlicePtrs(slice interface{}) interface{} {
  192. v := reflect.ValueOf(slice)
  193. vLen := v.Len()
  194. typ := v.Type().Elem()
  195. res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen)
  196. for i := 0; i < vLen; i++ {
  197. res.Index(i).Set(v.Index(i).Addr())
  198. }
  199. return res.Interface()
  200. }
  201. // SlicePtrsToSlice converts any []*Foo to []Foo
  202. func SlicePtrsToSlice(slice interface{}) interface{} {
  203. v := reflect.ValueOf(slice)
  204. vLen := v.Len()
  205. typ := v.Type().Elem().Elem()
  206. res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen)
  207. for i := 0; i < vLen; i++ {
  208. res.Index(i).Set(v.Index(i).Elem())
  209. }
  210. return res.Interface()
  211. }
  212. // Rollback an sql transaction, and log the error if it's not nil
  213. func Rollback(txn *sqlx.Tx) {
  214. if err := txn.Rollback(); err != nil {
  215. log.Errorw("Rollback", "err", err)
  216. }
  217. }
  218. // RowsClose close the rows of an sql query, and log the errir if it's not nil
  219. func RowsClose(rows *sql.Rows) {
  220. if err := rows.Close(); err != nil {
  221. log.Errorw("rows.Close", "err", err)
  222. }
  223. }