You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

262 lines
7.7 KiB

  1. package db
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "math/big"
  7. "reflect"
  8. "strings"
  9. "time"
  10. "github.com/gobuffalo/packr/v2"
  11. "github.com/hermeznetwork/hermez-node/log"
  12. "github.com/hermeznetwork/tracerr"
  13. "github.com/jmoiron/sqlx"
  14. //nolint:errcheck // driver for postgres DB
  15. _ "github.com/lib/pq"
  16. migrate "github.com/rubenv/sql-migrate"
  17. "github.com/russross/meddler"
  18. "golang.org/x/sync/semaphore"
  19. )
  20. var migrations *migrate.PackrMigrationSource
  21. func init() {
  22. migrations = &migrate.PackrMigrationSource{
  23. Box: packr.New("hermez-db-migrations", "./migrations"),
  24. }
  25. ms, err := migrations.FindMigrations()
  26. if err != nil {
  27. panic(err)
  28. }
  29. if len(ms) == 0 {
  30. panic(fmt.Errorf("no SQL migrations found"))
  31. }
  32. }
  33. // MigrationsUp runs the SQL migrations Up
  34. func MigrationsUp(db *sql.DB) error {
  35. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
  36. if err != nil {
  37. return tracerr.Wrap(err)
  38. }
  39. log.Info("successfully ran ", nMigrations, " migrations Up")
  40. return nil
  41. }
  42. // MigrationsDown runs the SQL migrations Down
  43. func MigrationsDown(db *sql.DB) error {
  44. nMigrations, err := migrate.Exec(db, "postgres", migrations, migrate.Down)
  45. if err != nil {
  46. return tracerr.Wrap(err)
  47. }
  48. log.Info("successfully ran ", nMigrations, " migrations Down")
  49. return nil
  50. }
  51. // ConnectSQLDB connects to the SQL DB
  52. func ConnectSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  53. // Init meddler
  54. initMeddler()
  55. meddler.Default = meddler.PostgreSQL
  56. // Stablish connection
  57. psqlconn := fmt.Sprintf(
  58. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  59. host,
  60. port,
  61. user,
  62. password,
  63. name,
  64. )
  65. db, err := sqlx.Connect("postgres", psqlconn)
  66. if err != nil {
  67. return nil, tracerr.Wrap(err)
  68. }
  69. return db, nil
  70. }
  71. // InitSQLDB runs migrations and registers meddlers
  72. func InitSQLDB(port int, host, user, password, name string) (*sqlx.DB, error) {
  73. db, err := ConnectSQLDB(port, host, user, password, name)
  74. if err != nil {
  75. return nil, tracerr.Wrap(err)
  76. }
  77. // Run DB migrations
  78. if err := MigrationsUp(db.DB); err != nil {
  79. return nil, tracerr.Wrap(err)
  80. }
  81. return db, nil
  82. }
  83. // APIConnectionController is used to limit the SQL open connections used by the API
  84. type APIConnectionController struct {
  85. smphr *semaphore.Weighted
  86. timeout time.Duration
  87. }
  88. // NewAPIConnectionController initialize APIConnectionController
  89. func NewAPIConnectionController(maxConnections int, timeout time.Duration) *APIConnectionController {
  90. return &APIConnectionController{
  91. smphr: semaphore.NewWeighted(int64(maxConnections)),
  92. timeout: timeout,
  93. }
  94. }
  95. // Acquire reserves a SQL connection. If the connection is not acquired
  96. // within the timeout, the function will return an error
  97. func (acc *APIConnectionController) Acquire() (context.CancelFunc, error) {
  98. ctx, cancel := context.WithTimeout(context.Background(), acc.timeout) //nolint:govet
  99. return cancel, acc.smphr.Acquire(ctx, 1)
  100. }
  101. // Release frees a SQL connection
  102. func (acc *APIConnectionController) Release() {
  103. acc.smphr.Release(1)
  104. }
  105. // initMeddler registers tags to be used to read/write from SQL DBs using meddler
  106. func initMeddler() {
  107. meddler.Register("bigint", BigIntMeddler{})
  108. meddler.Register("bigintnull", BigIntNullMeddler{})
  109. }
  110. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  111. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  112. // Note that all the columns must be specified in the query, and they must be
  113. // in the same order as in the table.
  114. // Note that the fields in the structs need to be defined in the same order as
  115. // in the table columns.
  116. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  117. arrayValue := reflect.ValueOf(args)
  118. arrayLen := arrayValue.Len()
  119. valueStrings := make([]string, 0, arrayLen)
  120. var arglist = make([]interface{}, 0)
  121. for i := 0; i < arrayLen; i++ {
  122. arg := arrayValue.Index(i).Addr().Interface()
  123. elemArglist, err := meddler.Default.Values(arg, true)
  124. if err != nil {
  125. return tracerr.Wrap(err)
  126. }
  127. arglist = append(arglist, elemArglist...)
  128. value := "("
  129. for j := 0; j < len(elemArglist); j++ {
  130. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  131. }
  132. value = value[:len(value)-2] + ")"
  133. valueStrings = append(valueStrings, value)
  134. }
  135. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  136. _, err := db.Exec(stmt, arglist...)
  137. return tracerr.Wrap(err)
  138. }
  139. // BigIntMeddler encodes or decodes the field value to or from JSON
  140. type BigIntMeddler struct{}
  141. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  142. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  143. // give a pointer to a byte buffer to grab the raw data
  144. return new(string), nil
  145. }
  146. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  147. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  148. ptr := scanTarget.(*string)
  149. if ptr == nil {
  150. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  151. }
  152. field := fieldPtr.(**big.Int)
  153. var ok bool
  154. *field, ok = new(big.Int).SetString(*ptr, 10)
  155. if !ok {
  156. return tracerr.Wrap(fmt.Errorf("big.Int.SetString failed on \"%v\"", *ptr))
  157. }
  158. return nil
  159. }
  160. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  161. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  162. field := fieldPtr.(*big.Int)
  163. return field.String(), nil
  164. }
  165. // BigIntNullMeddler encodes or decodes the field value to or from JSON
  166. type BigIntNullMeddler struct{}
  167. // PreRead is called before a Scan operation for fields that have the BigIntNullMeddler
  168. func (b BigIntNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  169. return &fieldAddr, nil
  170. }
  171. // PostRead is called after a Scan operation for fields that have the BigIntNullMeddler
  172. func (b BigIntNullMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  173. field := fieldPtr.(**big.Int)
  174. ptrPtr := scanTarget.(*interface{})
  175. if *ptrPtr == nil {
  176. // null column, so set target to be zero value
  177. *field = nil
  178. return nil
  179. }
  180. // not null
  181. ptr := (*ptrPtr).([]byte)
  182. if ptr == nil {
  183. return tracerr.Wrap(fmt.Errorf("BigIntMeddler.PostRead: nil pointer"))
  184. }
  185. var ok bool
  186. *field, ok = new(big.Int).SetString(string(ptr), 10)
  187. if !ok {
  188. return tracerr.Wrap(fmt.Errorf("big.Int.SetString failed on \"%v\"", string(ptr)))
  189. }
  190. return nil
  191. }
  192. // PreWrite is called before an Insert or Update operation for fields that have the BigIntNullMeddler
  193. func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  194. field := fieldPtr.(*big.Int)
  195. if field == nil {
  196. return nil, nil
  197. }
  198. return field.String(), nil
  199. }
  200. // SliceToSlicePtrs converts any []Foo to []*Foo
  201. func SliceToSlicePtrs(slice interface{}) interface{} {
  202. v := reflect.ValueOf(slice)
  203. vLen := v.Len()
  204. typ := v.Type().Elem()
  205. res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen)
  206. for i := 0; i < vLen; i++ {
  207. res.Index(i).Set(v.Index(i).Addr())
  208. }
  209. return res.Interface()
  210. }
  211. // SlicePtrsToSlice converts any []*Foo to []Foo
  212. func SlicePtrsToSlice(slice interface{}) interface{} {
  213. v := reflect.ValueOf(slice)
  214. vLen := v.Len()
  215. typ := v.Type().Elem().Elem()
  216. res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen)
  217. for i := 0; i < vLen; i++ {
  218. res.Index(i).Set(v.Index(i).Elem())
  219. }
  220. return res.Interface()
  221. }
  222. // Rollback an sql transaction, and log the error if it's not nil
  223. func Rollback(txn *sqlx.Tx) {
  224. if err := txn.Rollback(); err != nil {
  225. log.Errorw("Rollback", "err", err)
  226. }
  227. }
  228. // RowsClose close the rows of an sql query, and log the errir if it's not nil
  229. func RowsClose(rows *sql.Rows) {
  230. if err := rows.Close(); err != nil {
  231. log.Errorw("rows.Close", "err", err)
  232. }
  233. }