You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
2.4 KiB

  1. package db
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "math/big"
  6. "reflect"
  7. "strings"
  8. "github.com/russross/meddler"
  9. )
  10. // InitMeddler registers tags to be used to read/write from SQL DBs using meddler
  11. func InitMeddler() {
  12. meddler.Register("bigint", BigIntMeddler{})
  13. }
  14. // BulkInsert performs a bulk insert with a single statement into the specified table. Example:
  15. // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
  16. // Note that all the columns must be specified in the query, and they must be in the same order as in the table.
  17. func BulkInsert(db meddler.DB, q string, args interface{}) error {
  18. arrayValue := reflect.ValueOf(args)
  19. arrayLen := arrayValue.Len()
  20. valueStrings := make([]string, 0, arrayLen)
  21. var arglist = make([]interface{}, 0)
  22. for i := 0; i < arrayLen; i++ {
  23. arg := arrayValue.Index(i).Addr().Interface()
  24. elemArglist, err := meddler.Default.Values(arg, true)
  25. if err != nil {
  26. return err
  27. }
  28. arglist = append(arglist, elemArglist...)
  29. value := "("
  30. for j := 0; j < len(elemArglist); j++ {
  31. value += fmt.Sprintf("$%d, ", i*len(elemArglist)+j+1)
  32. }
  33. value = value[:len(value)-2] + ")"
  34. valueStrings = append(valueStrings, value)
  35. }
  36. stmt := fmt.Sprintf(q, strings.Join(valueStrings, ","))
  37. _, err := db.Exec(stmt, arglist...)
  38. return err
  39. }
  40. // BigIntMeddler encodes or decodes the field value to or from JSON
  41. type BigIntMeddler struct{}
  42. // PreRead is called before a Scan operation for fields that have the BigIntMeddler
  43. func (b BigIntMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
  44. // give a pointer to a byte buffer to grab the raw data
  45. return new(string), nil
  46. }
  47. // PostRead is called after a Scan operation for fields that have the BigIntMeddler
  48. func (b BigIntMeddler) PostRead(fieldPtr, scanTarget interface{}) error {
  49. ptr := scanTarget.(*string)
  50. if ptr == nil {
  51. return fmt.Errorf("BigIntMeddler.PostRead: nil pointer")
  52. }
  53. data, err := base64.StdEncoding.DecodeString(*ptr)
  54. if err != nil {
  55. return fmt.Errorf("big.Int decode error: %v", err)
  56. }
  57. field := fieldPtr.(**big.Int)
  58. *field = new(big.Int).SetBytes(data)
  59. return nil
  60. }
  61. // PreWrite is called before an Insert or Update operation for fields that have the BigIntMeddler
  62. func (b BigIntMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}, err error) {
  63. field := fieldPtr.(*big.Int)
  64. str := base64.StdEncoding.EncodeToString(field.Bytes())
  65. return str, nil
  66. }