Replace all []*Foo by []Foo in sql db return values

- Implement SlicePtrsToSlice and use it in all `meddler.QueryAll` sql db functions to always return []Foo instead of []*Foo
This commit is contained in:
Eduard S
2020-10-07 16:39:48 +02:00
parent 0277210c39
commit b14495cfcc
14 changed files with 124 additions and 54 deletions

View File

@@ -106,14 +106,14 @@ func (hdb *HistoryDB) GetBlock(blockNum int64) (*common.Block, error) {
}
// GetBlocks retrieve blocks from the DB, given a range of block numbers defined by from and to
func (hdb *HistoryDB) GetBlocks(from, to int64) ([]*common.Block, error) {
func (hdb *HistoryDB) GetBlocks(from, to int64) ([]common.Block, error) {
var blocks []*common.Block
err := meddler.QueryAll(
hdb.db, &blocks,
"SELECT * FROM block WHERE $1 <= eth_block_num AND eth_block_num < $2;",
from, to,
)
return blocks, err
return db.SlicePtrsToSlice(blocks).([]common.Block), err
}
// GetLastBlock retrieve the block with the highest block number from the DB
@@ -155,14 +155,14 @@ func (hdb *HistoryDB) addBatches(d meddler.DB, batches []common.Batch) error {
}
// GetBatches retrieve batches from the DB, given a range of batch numbers defined by from and to
func (hdb *HistoryDB) GetBatches(from, to common.BatchNum) ([]*common.Batch, error) {
func (hdb *HistoryDB) GetBatches(from, to common.BatchNum) ([]common.Batch, error) {
var batches []*common.Batch
err := meddler.QueryAll(
hdb.db, &batches,
"SELECT * FROM batch WHERE $1 <= batch_num AND batch_num < $2;",
from, to,
)
return batches, err
return db.SlicePtrsToSlice(batches).([]common.Batch), err
}
// GetLastBatchNum returns the BatchNum of the latest forged batch
@@ -215,13 +215,13 @@ func (hdb *HistoryDB) addBids(d meddler.DB, bids []common.Bid) error {
}
// GetBids return the bids
func (hdb *HistoryDB) GetBids() ([]*common.Bid, error) {
func (hdb *HistoryDB) GetBids() ([]common.Bid, error) {
var bids []*common.Bid
err := meddler.QueryAll(
hdb.db, &bids,
"SELECT * FROM bid;",
)
return bids, err
return db.SlicePtrsToSlice(bids).([]common.Bid), err
}
// AddCoordinators insert Coordinators into the DB
@@ -283,13 +283,13 @@ func (hdb *HistoryDB) UpdateTokenValue(tokenSymbol string, value float64) error
}
// GetTokens returns a list of tokens from the DB
func (hdb *HistoryDB) GetTokens() ([]*common.Token, error) {
func (hdb *HistoryDB) GetTokens() ([]common.Token, error) {
var tokens []*common.Token
err := meddler.QueryAll(
hdb.db, &tokens,
"SELECT * FROM token ORDER BY token_id;",
)
return tokens, err
return db.SlicePtrsToSlice(tokens).([]common.Token), err
}
// GetTokenSymbols returns all the token symbols from the DB
@@ -329,13 +329,13 @@ func (hdb *HistoryDB) addAccounts(d meddler.DB, accounts []common.Account) error
}
// GetAccounts returns a list of accounts from the DB
func (hdb *HistoryDB) GetAccounts() ([]*common.Account, error) {
func (hdb *HistoryDB) GetAccounts() ([]common.Account, error) {
var accs []*common.Account
err := meddler.QueryAll(
hdb.db, &accs,
"SELECT * FROM account ORDER BY idx;",
)
return accs, err
return db.SlicePtrsToSlice(accs).([]common.Account), err
}
// AddL1Txs inserts L1 txs to the DB. USD and LoadAmountUSD will be set automatically before storing the tx.
@@ -398,14 +398,14 @@ func (hdb *HistoryDB) addTxs(d meddler.DB, txs []common.Tx) error {
}
// GetTxs returns a list of txs from the DB
func (hdb *HistoryDB) GetTxs() ([]*common.Tx, error) {
func (hdb *HistoryDB) GetTxs() ([]common.Tx, error) {
var txs []*common.Tx
err := meddler.QueryAll(
hdb.db, &txs,
`SELECT * FROM tx
ORDER BY (batch_num, position) ASC`,
)
return txs, err
return db.SlicePtrsToSlice(txs).([]common.Tx), err
}
// GetHistoryTxs returns a list of txs from the DB using the HistoryTx struct
@@ -413,7 +413,7 @@ func (hdb *HistoryDB) GetHistoryTxs(
ethAddr *ethCommon.Address, bjj *babyjub.PublicKey,
tokenID, idx, batchNum *uint, txType *common.TxType,
offset, limit *uint, last bool,
) ([]*HistoryTx, int, error) {
) ([]HistoryTx, int, error) {
if ethAddr != nil && bjj != nil {
return nil, 0, errors.New("ethAddr and bjj are incompatible")
}
@@ -495,14 +495,15 @@ func (hdb *HistoryDB) GetHistoryTxs(
queryStr += fmt.Sprintf("LIMIT %d;", *limit)
query = hdb.db.Rebind(queryStr)
// log.Debug(query)
txs := []*HistoryTx{}
if err := meddler.QueryAll(hdb.db, &txs, query, args...); err != nil {
txsPtrs := []*HistoryTx{}
if err := meddler.QueryAll(hdb.db, &txsPtrs, query, args...); err != nil {
return nil, 0, err
}
txs := db.SlicePtrsToSlice(txsPtrs).([]HistoryTx)
if len(txs) == 0 {
return nil, 0, sql.ErrNoRows
} else if last {
tmp := []*HistoryTx{}
tmp := []HistoryTx{}
for i := len(txs) - 1; i >= 0; i-- {
tmp = append(tmp, txs[i])
}

View File

@@ -60,8 +60,8 @@ func TestBlocks(t *testing.T) {
assert.Equal(t, len(blocks), len(fetchedBlocks))
// Compare generated vs getted blocks
assert.NoError(t, err)
for i, fetchedBlock := range fetchedBlocks {
assertEqualBlock(t, &blocks[i], fetchedBlock)
for i := range fetchedBlocks {
assertEqualBlock(t, &blocks[i], &fetchedBlocks[i])
}
// Get blocks from the DB one by one
for i := fromBlock; i < toBlock; i++ {
@@ -100,7 +100,7 @@ func TestBatches(t *testing.T) {
fetchedBatches, err := historyDB.GetBatches(0, common.BatchNum(nBatches))
assert.NoError(t, err)
for i, fetchedBatch := range fetchedBatches {
assert.Equal(t, batches[i], *fetchedBatch)
assert.Equal(t, batches[i], fetchedBatch)
}
// Test GetLastBatchNum
fetchedLastBatchNum, err := historyDB.GetLastBatchNum()
@@ -132,7 +132,7 @@ func TestBids(t *testing.T) {
assert.NoError(t, err)
// Compare fetched bids vs generated bids
for i, bid := range fetchedBids {
assert.Equal(t, bids[i], *bid)
assert.Equal(t, bids[i], bid)
}
}
@@ -191,7 +191,7 @@ func TestAccounts(t *testing.T) {
assert.NoError(t, err)
// Compare fetched accounts vs generated accounts
for i, acc := range fetchedAccs {
assert.Equal(t, accs[i], *acc)
assert.Equal(t, accs[i], acc)
}
}

View File

@@ -7,6 +7,7 @@ import (
ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/hermeznetwork/hermez-node/common"
"github.com/hermeznetwork/hermez-node/db"
"github.com/hermeznetwork/hermez-node/log"
"github.com/iden3/go-iden3-crypto/babyjub"
"github.com/jmoiron/sqlx"
@@ -134,14 +135,14 @@ func (l2db *L2DB) GetTx(txID common.TxID) (*common.PoolL2Tx, error) {
}
// GetPendingTxs return all the pending txs of the L2DB, that have a non NULL AbsoluteFee
func (l2db *L2DB) GetPendingTxs() ([]*common.PoolL2Tx, error) {
func (l2db *L2DB) GetPendingTxs() ([]common.PoolL2Tx, error) {
var txs []*common.PoolL2Tx
err := meddler.QueryAll(
l2db.db, &txs,
selectPoolTx+"WHERE state = $1 AND token.usd IS NOT NULL",
common.PoolL2TxStatePending,
)
return txs, err
return db.SlicePtrsToSlice(txs).([]common.PoolL2Tx), err
}
// StartForging updates the state of the transactions that will begin the forging process.

View File

@@ -111,8 +111,8 @@ func TestGetPending(t *testing.T) {
fetchedTxs, err := l2DB.GetPendingTxs()
assert.NoError(t, err)
assert.Equal(t, len(pendingTxs), len(fetchedTxs))
for i, fetchedTx := range fetchedTxs {
assertTx(t, pendingTxs[i], fetchedTx)
for i := range fetchedTxs {
assertTx(t, pendingTxs[i], &fetchedTxs[i])
}
}

View File

@@ -68,8 +68,9 @@ func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs []
}
// assumption: l1usertx are sorted by L1Tx.Position
for _, tx := range l1usertxs {
exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, &tx)
for i := range l1usertxs {
tx := &l1usertxs[i]
exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, tx)
if err != nil {
return nil, nil, err
}
@@ -85,8 +86,9 @@ func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs []
s.i++
}
}
for _, tx := range l1coordinatortxs {
exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, &tx)
for i := range l1coordinatortxs {
tx := &l1coordinatortxs[i]
exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, tx)
if err != nil {
return nil, nil, err
}
@@ -105,8 +107,9 @@ func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs []
s.i++
}
}
for _, tx := range l2txs {
exitIdx, exitAccount, newExit, err := s.processL2Tx(exitTree, &tx)
for i := range l2txs {
tx := &l2txs[i]
exitIdx, exitAccount, newExit, err := s.processL2Tx(exitTree, tx)
if err != nil {
return nil, nil, err
}

View File

@@ -52,7 +52,10 @@ func initMeddler() {
// BulkInsert performs a bulk insert with a single statement into the specified table. Example:
// `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])`
// Note that all the columns must be specified in the query, and they must be in the same order as in the table.
// Note that all the columns must be specified in the query, and they must be
// in the same order as in the table.
// Note that the fields in the structs need to be defined in the same order as
// in the table columns.
func BulkInsert(db meddler.DB, q string, args interface{}) error {
arrayValue := reflect.ValueOf(args)
arrayLen := arrayValue.Len()
@@ -150,3 +153,27 @@ func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{}
}
return base64.StdEncoding.EncodeToString(field.Bytes()), nil
}
// SliceToSlicePtrs converts any []Foo to []*Foo
func SliceToSlicePtrs(slice interface{}) interface{} {
v := reflect.ValueOf(slice)
vLen := v.Len()
typ := v.Type().Elem()
res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen)
for i := 0; i < vLen; i++ {
res.Index(i).Set(v.Index(i).Addr())
}
return res.Interface()
}
// SlicePtrsToSlice converts any []*Foo to []Foo
func SlicePtrsToSlice(slice interface{}) interface{} {
v := reflect.ValueOf(slice)
vLen := v.Len()
typ := v.Type().Elem().Elem()
res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen)
for i := 0; i < vLen; i++ {
res.Index(i).Set(v.Index(i).Elem())
}
return res.Interface()
}

35
db/utils_test.go Normal file
View File

@@ -0,0 +1,35 @@
package db
import (
"testing"
"github.com/stretchr/testify/assert"
)
type foo struct {
V int
}
func TestSliceToSlicePtrs(t *testing.T) {
n := 16
a := make([]foo, n)
for i := 0; i < n; i++ {
a[i] = foo{V: i}
}
b := SliceToSlicePtrs(a).([]*foo)
for i := 0; i < len(a); i++ {
assert.Equal(t, a[i], *b[i])
}
}
func TestSlicePtrsToSlice(t *testing.T) {
n := 16
a := make([]*foo, n)
for i := 0; i < n; i++ {
a[i] = &foo{V: i}
}
b := SlicePtrsToSlice(a).([]foo)
for i := 0; i < len(a); i++ {
assert.Equal(t, *a[i], b[i])
}
}