From b14495cfccbe15d8c18c2c21db739f01d6e0435c Mon Sep 17 00:00:00 2001 From: Eduard S Date: Wed, 7 Oct 2020 16:39:48 +0200 Subject: [PATCH] Replace all []*Foo by []Foo in sql db return values - Implement SlicePtrsToSlice and use it in all `meddler.QueryAll` sql db functions to always return []Foo instead of []*Foo --- api/api_test.go | 4 ++-- api/dbtoapistructs.go | 2 +- common/l1tx.go | 15 ++++++------- db/historydb/historydb.go | 33 +++++++++++++++-------------- db/historydb/historydb_test.go | 10 ++++----- db/l2db/l2db.go | 5 +++-- db/l2db/l2db_test.go | 4 ++-- db/statedb/txprocessors.go | 15 +++++++------ db/utils.go | 29 ++++++++++++++++++++++++- db/utils_test.go | 35 +++++++++++++++++++++++++++++++ go.mod | 2 ++ synchronizer/synchronizer_test.go | 4 ++-- test/ethclient.go | 2 +- txselector/txselector.go | 18 ++++++++-------- 14 files changed, 124 insertions(+), 54 deletions(-) create mode 100644 db/utils_test.go diff --git a/api/api_test.go b/api/api_test.go index 6ffd044..c04d783 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -197,7 +197,7 @@ func TestMain(m *testing.M) { genericTxs = append(genericTxs, l2tx.Tx()) } // Transform generic Txs to HistoryTx - historyTxs := []*historydb.HistoryTx{} + historyTxs := []historydb.HistoryTx{} for _, genericTx := range genericTxs { // find timestamp var timestamp time.Time @@ -238,7 +238,7 @@ func TestMain(m *testing.M) { *feeUSD = *usd * genericTx.Fee.Percentage() } } - historyTxs = append(historyTxs, &historydb.HistoryTx{ + historyTxs = append(historyTxs, historydb.HistoryTx{ IsL1: genericTx.IsL1, TxID: genericTx.TxID, Type: genericTx.Type, diff --git a/api/dbtoapistructs.go b/api/dbtoapistructs.go index 51d5f13..f248411 100644 --- a/api/dbtoapistructs.go +++ b/api/dbtoapistructs.go @@ -69,7 +69,7 @@ type historyTxAPI struct { Token common.Token `json:"token"` } -func historyTxsToAPI(dbTxs []*historydb.HistoryTx) []historyTxAPI { +func historyTxsToAPI(dbTxs []historydb.HistoryTx) []historyTxAPI { apiTxs := []historyTxAPI{} for i := 0; i < len(dbTxs); i++ { apiTx := historyTxAPI{ diff --git a/common/l1tx.go b/common/l1tx.go index 480c7a6..045211f 100644 --- a/common/l1tx.go +++ b/common/l1tx.go @@ -90,27 +90,28 @@ func NewL1Tx(l1Tx *L1Tx) (*L1Tx, error) { return l1Tx, nil } -func (l1Tx *L1Tx) CalcTxID() (*TxID, error) { +// CalcTxID calculates the TxId of the L1Tx +func (tx *L1Tx) CalcTxID() (*TxID, error) { var txID TxID - if l1Tx.UserOrigin { - if l1Tx.ToForgeL1TxsNum == nil { + if tx.UserOrigin { + if tx.ToForgeL1TxsNum == nil { return nil, fmt.Errorf("L1Tx.UserOrigin == true && L1Tx.ToForgeL1TxsNum == nil") } txID[0] = TxIDPrefixL1UserTx var toForgeL1TxsNumBytes [8]byte - binary.BigEndian.PutUint64(toForgeL1TxsNumBytes[:], uint64(*l1Tx.ToForgeL1TxsNum)) + binary.BigEndian.PutUint64(toForgeL1TxsNumBytes[:], uint64(*tx.ToForgeL1TxsNum)) copy(txID[1:9], toForgeL1TxsNumBytes[:]) } else { - if l1Tx.BatchNum == nil { + if tx.BatchNum == nil { return nil, fmt.Errorf("L1Tx.UserOrigin == false && L1Tx.BatchNum == nil") } txID[0] = TxIDPrefixL1CoordTx var batchNumBytes [8]byte - binary.BigEndian.PutUint64(batchNumBytes[:], uint64(*l1Tx.BatchNum)) + binary.BigEndian.PutUint64(batchNumBytes[:], uint64(*tx.BatchNum)) copy(txID[1:9], batchNumBytes[:]) } var positionBytes [2]byte - binary.BigEndian.PutUint16(positionBytes[:], uint16(l1Tx.Position)) + binary.BigEndian.PutUint16(positionBytes[:], uint16(tx.Position)) copy(txID[9:11], positionBytes[:]) return &txID, nil diff --git a/db/historydb/historydb.go b/db/historydb/historydb.go index fe287c5..5950135 100644 --- a/db/historydb/historydb.go +++ b/db/historydb/historydb.go @@ -106,14 +106,14 @@ func (hdb *HistoryDB) GetBlock(blockNum int64) (*common.Block, error) { } // GetBlocks retrieve blocks from the DB, given a range of block numbers defined by from and to -func (hdb *HistoryDB) GetBlocks(from, to int64) ([]*common.Block, error) { +func (hdb *HistoryDB) GetBlocks(from, to int64) ([]common.Block, error) { var blocks []*common.Block err := meddler.QueryAll( hdb.db, &blocks, "SELECT * FROM block WHERE $1 <= eth_block_num AND eth_block_num < $2;", from, to, ) - return blocks, err + return db.SlicePtrsToSlice(blocks).([]common.Block), err } // GetLastBlock retrieve the block with the highest block number from the DB @@ -155,14 +155,14 @@ func (hdb *HistoryDB) addBatches(d meddler.DB, batches []common.Batch) error { } // GetBatches retrieve batches from the DB, given a range of batch numbers defined by from and to -func (hdb *HistoryDB) GetBatches(from, to common.BatchNum) ([]*common.Batch, error) { +func (hdb *HistoryDB) GetBatches(from, to common.BatchNum) ([]common.Batch, error) { var batches []*common.Batch err := meddler.QueryAll( hdb.db, &batches, "SELECT * FROM batch WHERE $1 <= batch_num AND batch_num < $2;", from, to, ) - return batches, err + return db.SlicePtrsToSlice(batches).([]common.Batch), err } // GetLastBatchNum returns the BatchNum of the latest forged batch @@ -215,13 +215,13 @@ func (hdb *HistoryDB) addBids(d meddler.DB, bids []common.Bid) error { } // GetBids return the bids -func (hdb *HistoryDB) GetBids() ([]*common.Bid, error) { +func (hdb *HistoryDB) GetBids() ([]common.Bid, error) { var bids []*common.Bid err := meddler.QueryAll( hdb.db, &bids, "SELECT * FROM bid;", ) - return bids, err + return db.SlicePtrsToSlice(bids).([]common.Bid), err } // AddCoordinators insert Coordinators into the DB @@ -283,13 +283,13 @@ func (hdb *HistoryDB) UpdateTokenValue(tokenSymbol string, value float64) error } // GetTokens returns a list of tokens from the DB -func (hdb *HistoryDB) GetTokens() ([]*common.Token, error) { +func (hdb *HistoryDB) GetTokens() ([]common.Token, error) { var tokens []*common.Token err := meddler.QueryAll( hdb.db, &tokens, "SELECT * FROM token ORDER BY token_id;", ) - return tokens, err + return db.SlicePtrsToSlice(tokens).([]common.Token), err } // GetTokenSymbols returns all the token symbols from the DB @@ -329,13 +329,13 @@ func (hdb *HistoryDB) addAccounts(d meddler.DB, accounts []common.Account) error } // GetAccounts returns a list of accounts from the DB -func (hdb *HistoryDB) GetAccounts() ([]*common.Account, error) { +func (hdb *HistoryDB) GetAccounts() ([]common.Account, error) { var accs []*common.Account err := meddler.QueryAll( hdb.db, &accs, "SELECT * FROM account ORDER BY idx;", ) - return accs, err + return db.SlicePtrsToSlice(accs).([]common.Account), err } // AddL1Txs inserts L1 txs to the DB. USD and LoadAmountUSD will be set automatically before storing the tx. @@ -398,14 +398,14 @@ func (hdb *HistoryDB) addTxs(d meddler.DB, txs []common.Tx) error { } // GetTxs returns a list of txs from the DB -func (hdb *HistoryDB) GetTxs() ([]*common.Tx, error) { +func (hdb *HistoryDB) GetTxs() ([]common.Tx, error) { var txs []*common.Tx err := meddler.QueryAll( hdb.db, &txs, `SELECT * FROM tx ORDER BY (batch_num, position) ASC`, ) - return txs, err + return db.SlicePtrsToSlice(txs).([]common.Tx), err } // GetHistoryTxs returns a list of txs from the DB using the HistoryTx struct @@ -413,7 +413,7 @@ func (hdb *HistoryDB) GetHistoryTxs( ethAddr *ethCommon.Address, bjj *babyjub.PublicKey, tokenID, idx, batchNum *uint, txType *common.TxType, offset, limit *uint, last bool, -) ([]*HistoryTx, int, error) { +) ([]HistoryTx, int, error) { if ethAddr != nil && bjj != nil { return nil, 0, errors.New("ethAddr and bjj are incompatible") } @@ -495,14 +495,15 @@ func (hdb *HistoryDB) GetHistoryTxs( queryStr += fmt.Sprintf("LIMIT %d;", *limit) query = hdb.db.Rebind(queryStr) // log.Debug(query) - txs := []*HistoryTx{} - if err := meddler.QueryAll(hdb.db, &txs, query, args...); err != nil { + txsPtrs := []*HistoryTx{} + if err := meddler.QueryAll(hdb.db, &txsPtrs, query, args...); err != nil { return nil, 0, err } + txs := db.SlicePtrsToSlice(txsPtrs).([]HistoryTx) if len(txs) == 0 { return nil, 0, sql.ErrNoRows } else if last { - tmp := []*HistoryTx{} + tmp := []HistoryTx{} for i := len(txs) - 1; i >= 0; i-- { tmp = append(tmp, txs[i]) } diff --git a/db/historydb/historydb_test.go b/db/historydb/historydb_test.go index 693a280..b2401d5 100644 --- a/db/historydb/historydb_test.go +++ b/db/historydb/historydb_test.go @@ -60,8 +60,8 @@ func TestBlocks(t *testing.T) { assert.Equal(t, len(blocks), len(fetchedBlocks)) // Compare generated vs getted blocks assert.NoError(t, err) - for i, fetchedBlock := range fetchedBlocks { - assertEqualBlock(t, &blocks[i], fetchedBlock) + for i := range fetchedBlocks { + assertEqualBlock(t, &blocks[i], &fetchedBlocks[i]) } // Get blocks from the DB one by one for i := fromBlock; i < toBlock; i++ { @@ -100,7 +100,7 @@ func TestBatches(t *testing.T) { fetchedBatches, err := historyDB.GetBatches(0, common.BatchNum(nBatches)) assert.NoError(t, err) for i, fetchedBatch := range fetchedBatches { - assert.Equal(t, batches[i], *fetchedBatch) + assert.Equal(t, batches[i], fetchedBatch) } // Test GetLastBatchNum fetchedLastBatchNum, err := historyDB.GetLastBatchNum() @@ -132,7 +132,7 @@ func TestBids(t *testing.T) { assert.NoError(t, err) // Compare fetched bids vs generated bids for i, bid := range fetchedBids { - assert.Equal(t, bids[i], *bid) + assert.Equal(t, bids[i], bid) } } @@ -191,7 +191,7 @@ func TestAccounts(t *testing.T) { assert.NoError(t, err) // Compare fetched accounts vs generated accounts for i, acc := range fetchedAccs { - assert.Equal(t, accs[i], *acc) + assert.Equal(t, accs[i], acc) } } diff --git a/db/l2db/l2db.go b/db/l2db/l2db.go index 19839be..10f21f8 100644 --- a/db/l2db/l2db.go +++ b/db/l2db/l2db.go @@ -7,6 +7,7 @@ import ( ethCommon "github.com/ethereum/go-ethereum/common" "github.com/hermeznetwork/hermez-node/common" + "github.com/hermeznetwork/hermez-node/db" "github.com/hermeznetwork/hermez-node/log" "github.com/iden3/go-iden3-crypto/babyjub" "github.com/jmoiron/sqlx" @@ -134,14 +135,14 @@ func (l2db *L2DB) GetTx(txID common.TxID) (*common.PoolL2Tx, error) { } // GetPendingTxs return all the pending txs of the L2DB, that have a non NULL AbsoluteFee -func (l2db *L2DB) GetPendingTxs() ([]*common.PoolL2Tx, error) { +func (l2db *L2DB) GetPendingTxs() ([]common.PoolL2Tx, error) { var txs []*common.PoolL2Tx err := meddler.QueryAll( l2db.db, &txs, selectPoolTx+"WHERE state = $1 AND token.usd IS NOT NULL", common.PoolL2TxStatePending, ) - return txs, err + return db.SlicePtrsToSlice(txs).([]common.PoolL2Tx), err } // StartForging updates the state of the transactions that will begin the forging process. diff --git a/db/l2db/l2db_test.go b/db/l2db/l2db_test.go index 48e3426..9eb68dd 100644 --- a/db/l2db/l2db_test.go +++ b/db/l2db/l2db_test.go @@ -111,8 +111,8 @@ func TestGetPending(t *testing.T) { fetchedTxs, err := l2DB.GetPendingTxs() assert.NoError(t, err) assert.Equal(t, len(pendingTxs), len(fetchedTxs)) - for i, fetchedTx := range fetchedTxs { - assertTx(t, pendingTxs[i], fetchedTx) + for i := range fetchedTxs { + assertTx(t, pendingTxs[i], &fetchedTxs[i]) } } diff --git a/db/statedb/txprocessors.go b/db/statedb/txprocessors.go index 9fdebc8..7dec441 100644 --- a/db/statedb/txprocessors.go +++ b/db/statedb/txprocessors.go @@ -68,8 +68,9 @@ func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs [] } // assumption: l1usertx are sorted by L1Tx.Position - for _, tx := range l1usertxs { - exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, &tx) + for i := range l1usertxs { + tx := &l1usertxs[i] + exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, tx) if err != nil { return nil, nil, err } @@ -85,8 +86,9 @@ func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs [] s.i++ } } - for _, tx := range l1coordinatortxs { - exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, &tx) + for i := range l1coordinatortxs { + tx := &l1coordinatortxs[i] + exitIdx, exitAccount, newExit, err := s.processL1Tx(exitTree, tx) if err != nil { return nil, nil, err } @@ -105,8 +107,9 @@ func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs [] s.i++ } } - for _, tx := range l2txs { - exitIdx, exitAccount, newExit, err := s.processL2Tx(exitTree, &tx) + for i := range l2txs { + tx := &l2txs[i] + exitIdx, exitAccount, newExit, err := s.processL2Tx(exitTree, tx) if err != nil { return nil, nil, err } diff --git a/db/utils.go b/db/utils.go index af8772e..4e341c1 100644 --- a/db/utils.go +++ b/db/utils.go @@ -52,7 +52,10 @@ func initMeddler() { // BulkInsert performs a bulk insert with a single statement into the specified table. Example: // `db.BulkInsert(myDB, "INSERT INTO block (eth_block_num, timestamp, hash) VALUES %s", blocks[:])` -// Note that all the columns must be specified in the query, and they must be in the same order as in the table. +// Note that all the columns must be specified in the query, and they must be +// in the same order as in the table. +// Note that the fields in the structs need to be defined in the same order as +// in the table columns. func BulkInsert(db meddler.DB, q string, args interface{}) error { arrayValue := reflect.ValueOf(args) arrayLen := arrayValue.Len() @@ -150,3 +153,27 @@ func (b BigIntNullMeddler) PreWrite(fieldPtr interface{}) (saveValue interface{} } return base64.StdEncoding.EncodeToString(field.Bytes()), nil } + +// SliceToSlicePtrs converts any []Foo to []*Foo +func SliceToSlicePtrs(slice interface{}) interface{} { + v := reflect.ValueOf(slice) + vLen := v.Len() + typ := v.Type().Elem() + res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(typ)), vLen, vLen) + for i := 0; i < vLen; i++ { + res.Index(i).Set(v.Index(i).Addr()) + } + return res.Interface() +} + +// SlicePtrsToSlice converts any []*Foo to []Foo +func SlicePtrsToSlice(slice interface{}) interface{} { + v := reflect.ValueOf(slice) + vLen := v.Len() + typ := v.Type().Elem().Elem() + res := reflect.MakeSlice(reflect.SliceOf(typ), vLen, vLen) + for i := 0; i < vLen; i++ { + res.Index(i).Set(v.Index(i).Elem()) + } + return res.Interface() +} diff --git a/db/utils_test.go b/db/utils_test.go new file mode 100644 index 0000000..a5c83b2 --- /dev/null +++ b/db/utils_test.go @@ -0,0 +1,35 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type foo struct { + V int +} + +func TestSliceToSlicePtrs(t *testing.T) { + n := 16 + a := make([]foo, n) + for i := 0; i < n; i++ { + a[i] = foo{V: i} + } + b := SliceToSlicePtrs(a).([]*foo) + for i := 0; i < len(a); i++ { + assert.Equal(t, a[i], *b[i]) + } +} + +func TestSlicePtrsToSlice(t *testing.T) { + n := 16 + a := make([]*foo, n) + for i := 0; i < n; i++ { + a[i] = &foo{V: i} + } + b := SlicePtrsToSlice(a).([]foo) + for i := 0; i < len(a); i++ { + assert.Equal(t, *a[i], b[i]) + } +} diff --git a/go.mod b/go.mod index 06d0650..c5582be 100644 --- a/go.mod +++ b/go.mod @@ -27,3 +27,5 @@ require ( golang.org/x/tools/gopls v0.5.0 // indirect gopkg.in/go-playground/validator.v9 v9.29.1 ) + +// replace github.com/russross/meddler => /home/dev/git/iden3/hermez/meddler diff --git a/synchronizer/synchronizer_test.go b/synchronizer/synchronizer_test.go index e8217d8..db6b8eb 100644 --- a/synchronizer/synchronizer_test.go +++ b/synchronizer/synchronizer_test.go @@ -78,8 +78,8 @@ D (3): 15 require.Nil(t, err) } - for _, l1UserTx := range l1UserTxs[0] { - client.CtlAddL1TxUser(&l1UserTx) + for i := range l1UserTxs[0] { + client.CtlAddL1TxUser(&l1UserTxs[0][i]) } client.CtlMineBlock() diff --git a/test/ethclient.go b/test/ethclient.go index 45b108b..3c32d0f 100644 --- a/test/ethclient.go +++ b/test/ethclient.go @@ -333,7 +333,7 @@ func NewClient(l bool, timer Timer, addr *ethCommon.Address, setup *ClientSetup) ExitRoots: make([]*big.Int, 0), ExitNullifierMap: make(map[[256 / 8]byte]bool), // TokenID = 0 is ETH. Set first entry in TokenList with 0x0 address for ETH. - TokenList: []ethCommon.Address{ethCommon.Address{}}, + TokenList: []ethCommon.Address{{}}, TokenMap: make(map[ethCommon.Address]bool), MapL1TxQueue: mapL1TxQueue, LastL1L2Batch: 0, diff --git a/txselector/txselector.go b/txselector/txselector.go index 12700d1..eafb710 100644 --- a/txselector/txselector.go +++ b/txselector/txselector.go @@ -81,7 +81,7 @@ func (txsel *TxSelector) GetL2TxSelection(batchNum common.BatchNum) ([]common.Po _, err = txsel.localAccountsDB.GetAccount(&tx.FromIdx) if err == nil { // if FromIdx has an account into the AccountsDB - validTxs = append(validTxs, *tx) + validTxs = append(validTxs, tx) } } @@ -127,7 +127,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co // a L1CoordinatorTx of this type, in the DB there still seem // that needs to create a new L1CoordinatorTx, but as is already // created, the tx is valid - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) continue } @@ -143,7 +143,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co // account for ToEthAddr&ToBJJ already exist, // there is no need to create a new one. // tx valid, StateDB will use the ToIdx==0 to define the AuxToIdx - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) continue } // if not, check if AccountCreationAuth exist for that ToEthAddr&BJJ @@ -159,7 +159,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co log.Debugw("invalid L2Tx: ToIdx not found in StateDB, neither ToEthAddr & ToBJJ found in AccountCreationAuths L2DB", "ToIdx", l2TxsRaw[i].ToIdx, "ToEthAddr", l2TxsRaw[i].ToEthAddr, "ToBJJ", l2TxsRaw[i].ToBJJ) continue } - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) } else { // case: ToBJJ==0: // if idx exist for EthAddr use it @@ -168,7 +168,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co // account for ToEthAddr already exist, // there is no need to create a new one. // tx valid, StateDB will use the ToIdx==0 to define the AuxToIdx - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) continue } // if not, check if AccountCreationAuth exist for that ToEthAddr @@ -178,7 +178,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co log.Debugw("invalid L2Tx: ToIdx not found in StateDB, neither ToEthAddr found in AccountCreationAuths L2DB", "ToIdx", l2TxsRaw[i].ToIdx, "ToEthAddr", l2TxsRaw[i].ToEthAddr) continue } - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) } // create L1CoordinatorTx for the accountCreation l1CoordinatorTx := common.L1Tx{ @@ -199,7 +199,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co // account for ToEthAddr&ToBJJ already exist, (where ToEthAddr==0xff) // there is no need to create a new one. // tx valid, StateDB will use the ToIdx==0 to define the AuxToIdx - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) continue } // if idx don't exist for EthAddr&BJJ, @@ -230,10 +230,10 @@ func (txsel *TxSelector) GetL1L2TxSelection(batchNum common.BatchNum, l1Txs []co continue } // Account found in the DB, include the l2Tx in the selection - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) } else if *l2TxsRaw[i].ToIdx == common.Idx(1) { // nil already checked before // valid txs (of Exit type) - validTxs = append(validTxs, *l2TxsRaw[i]) + validTxs = append(validTxs, l2TxsRaw[i]) } }