You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

269 lines
8.0 KiB

/*
Package db has some common utilities shared by db/l2db and db/historydb. The most relevant ones are:
- SQL connection utilities
- Managing the SQL schema: this is done using migration files placed under db/migrations. The files are executed in
order of their file names.
- Custom meddlers: used to easily transform struct <==> table
*/
  8. package db
  9. import (
  10. "context"
  11. "database/sql"
  12. "fmt"
  13. "math/big"
  14. "reflect"
  15. "strings"
  16. "time"
  17. "github.com/gobuffalo/packr/v2"
  18. "github.com/hermeznetwork/hermez-node/log"
  19. "github.com/hermeznetwork/tracerr"
  20. "github.com/jmoiron/sqlx"
  21. //nolint:errcheck // driver for postgres DB
  22. _ "github.com/lib/pq"
  23. migrate "github.com/rubenv/sql-migrate"
  24. "github.com/russross/meddler"
  25. "golang.org/x/sync/semaphore"
  26. )
  27. var migrations *migrate.PackrMigrationSource
  28. func init() {
  29. migrations = &migrate.PackrMigrationSource{
  30. Box: packr.New("hermez-db-migrations", "./migrations"),
  31. }
  32. ms, err := migrations.FindMigrations()
  33. if err != nil {
  34. panic(err)
  35. }
  36. if len(ms) == 0 {
  37. panic(fmt.Errorf("no SQL migrations found"))
  38. }
  39. }
  40. // MigrationsUp runs the SQL migrations Up
  41. func MigrationsUp(db *sql.DB) error {
  42. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
  43. if err != nil {
  44. return tracerr.Wrap(err)
  45. }
  46. log.Info("successfully ran ", nMigrations, " migrations Up")
  47. return nil
  48. }
  49. // MigrationsDown runs the SQL migrations Down
  50. func MigrationsDown(db *sql.DB) error {
  51. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Down)
  52. if err != nil {
  53. return tracerr.Wrap(err)
  54. }
  55. log.Info("successfully ran ", nMigrations, " migrations Down")
  56. return nil
  57. }
  58. // ConnectSQLDB connects to the SQL DB
  59. func ConnectSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  60. // Init meddler
  61. initMeddler()
  62. meddler.Default = meddler.PostgreSQL
  63. // Stablish connection
  64. psqlconn := fmt.Sprintf(
  65. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  66. host,
  67. port,
  68. user,
  69. password,
  70. name,
  71. )
  72. db, err := sqlx.Connect("postgres", psqlconn)
  73. if err != nil {
  74. return nil, tracerr.Wrap(err)
  75. }
  76. return db, nil
  77. }
  78. // InitSQLDB runs migrations and registers meddlers
  79. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  80. db, err := ConnectSQLDB(port, host, user, password, name)
  81. if err != nil {
  82. return nil, tracerr.Wrap(err)
  83. }
  84. // Run DB migrations
  85. if err := MigrationsUp(db.DB); err != nil {
  86. return nil, tracerr.Wrap(err)
  87. }
  88. return db, nil
  89. }
  90. // APIConnectionController is used to limit the SQL open connections used by the API
  91. type APIConnectionController struct {
  92. smphr *semaphore.Weighted
  93. timeout time.Duration
  94. }
  95. // NewAPIConnectionController initialize APIConnectionController
  96. func NewAPIConnectionController(maxConnections int, timeout time.Duration) *APIConnectionController {
  97. return &APIConnectionController{
  98. smphr: semaphore.NewWeighted(int64(maxConnections)),
  99. timeout: timeout,
  100. }
  101. }
  102. // Acquire reserves a SQL connection. If the connection is not acquired
  103. // within the timeout, the function will return an error
  104. func (acc *APIConnectionController) Acquire() (context.CancelFunc, error) {
  105. ctx, cancel := context.WithTimeout(context.Background(), acc.timeout) //nolint:govet
  106. return cancel, acc.smphr.Acquire(ctx, 1)
  107. }
  108. // Release frees a SQL connection
  109. func (acc *APIConnectionController) Release() {
  110. acc.smphr.Release(1)
  111. }
  112. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  113. func initMeddler() {
  114. meddler.Register("bigint", BigIntMeddler{})
  115. meddler.Register("bigintnull", BigIntNullMeddler{})
  116. }
  117. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  118. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  119. // Note that all the columns must be specified in the query, and they must be
  120. // in the same order as in the table.
  121. // Note that the fields in the structs need to be defined in the same order as
  122. // in the table columns.
  123. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  124. arrayValue := reflect.ValueOf(args)
  125. arrayLen := arrayValue.Len()
  126. valueStrings := make([]string, 0, arrayLen)
  127. var arglist = make([]interface{}, 0)
  128. for i := 0; i < arrayLen; i++ {
  129. arg := arrayValue.Index(i).Addr().Interface()
  130. elemArglist, err := meddler.Default.Values(arg, true)
  131. if err != nil {
  132. return tracerr.Wrap(err)
  133. }
  134. arglist = append(arglist, elemArglist...)
  135. value := "("
  136. for j := 0; j < len(elemArglist); j++ {
  137. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  138. }
  139. value = value[:len(value)-2] + ")"
  140. valueStrings = append(valueStrings, value)
  141. }
  142. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  143. _, err := db.Exec(stmt, arglist...)
  144. return tracerr.Wrap(err)
  145. }
  146. // BigIntMeddler encodes or decodes the field value to or from JSON
  147. type BigIntMeddler struct{}
  148. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  149. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  150. // give a pointer to a byte buffer to grab the raw data
  151. return new(string), nil
  152. }
  153. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  154. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  155. ptr := scanTarget.(*string)
  156. if ptr == nil {
  157. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  158. }
  159. field := fieldPtr.(**big.Int)
  160. var ok bool
  161. *field, ok = new(big.Int).SetString(*ptr, 10)
  162. if !ok {
  163. return tracerr.Wrap(fmt.Errorf("big.Int.SetString failed on \"%v\"", *ptr))
  164. }
  165. return nil
  166. }
  167. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  168. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  169. field := fieldPtr.(*big.Int)
  170. return field.String(), nil
  171. }
  172. // BigIntNullMeddler encodes or decodes the field value to or from JSON
  173. type BigIntNullMeddler struct{}
  174. // PreRead is called before a Scan operation for fields that have the BigIntNullMeddler
  175. func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  176. return &fieldAddr, nil
  177. }
  178. // PostRead is called after a Scan operation for fields that have the BigIntNullMeddler
  179. func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  180. field := fieldPtr.(**big.Int)
  181. ptrPtr := scanTarget.(*interface{})
  182. if *ptrPtr == nil {
  183. // null column, so set target to be zero value
  184. *field = nil
  185. return nil
  186. }
  187. // not null
  188. ptr := (*ptrPtr).([]byte)
  189. if ptr == nil {
  190. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  191. }
  192. var ok bool
  193. *field, ok = new(big.Int).SetString(string(ptr), 10)
  194. if !ok {
  195. return tracerr.Wrap(fmt.Errorf("big.Int.SetString failed on \"%v\"", string(ptr)))
  196. }
  197. return nil
  198. }
  199. // PreWrite is called before an Insert or Update operation for fields that have the BigIntNullMeddler
  200. func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  201. field := fieldPtr.(*big.Int)
  202. if field == nil {
  203. return nil, nil
  204. }
  205. return field.String(), nil
  206. }
  207. // SliceToSlicePtrs converts any []Foo to []*Foo
  208. func SliceToSlicePtrs(slice interface{}) interface{} {
  209. v := reflect.ValueOf(slice)
  210. vLen := v.Len()
  211. typ := v.Type().Elem()
  212. res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen)
  213. for i := 0; i < vLen; i++ {
  214. res.Index(i).Set(v.Index(i).Addr())
  215. }
  216. return res.Interface()
  217. }
  218. // SlicePtrsToSlice converts any []*Foo to []Foo
  219. func SlicePtrsToSlice(slice interface{}) interface{} {
  220. v := reflect.ValueOf(slice)
  221. vLen := v.Len()
  222. typ := v.Type().Elem().Elem()
  223. res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen)
  224. for i := 0; i < vLen; i++ {
  225. res.Index(i).Set(v.Index(i).Elem())
  226. }
  227. return res.Interface()
  228. }
  229. // Rollback an sql transaction, and log the error if it's not nil
  230. func Rollback(txn *sqlx.Tx) {
  231. if err := txn.Rollback(); err != nil {
  232. log.Errorw("Rollback", "err", err)
  233. }
  234. }
  235. // RowsClose close the rows of an sql query, and log the errir if it's not nil
  236. func RowsClose(rows *sql.Rows) {
  237. if err := rows.Close(); err != nil {
  238. log.Errorw("rows.Close", "err", err)
  239. }
  240. }