diff --git a/api/account.go b/api/account.go index 154e989..54725c7 100644 --- a/api/account.go +++ b/api/account.go @@ -6,6 +6,8 @@ import ( "github.com/gin-gonic/gin" "github.com/hermeznetwork/hermez-node/apitypes" "github.com/hermeznetwork/hermez-node/db/historydb" + "github.com/hermeznetwork/hermez-node/db/statedb" + "github.com/hermeznetwork/tracerr" ) func (a *API) getAccount(c *gin.Context) { @@ -22,7 +24,7 @@ func (a *API) getAccount(c *gin.Context) { } // Get balance from stateDB - account, err := a.s.GetAccount(*idx) + account, err := a.s.LastGetAccount(*idx) if err != nil { retSQLErr(err, c) return @@ -56,19 +58,23 @@ func (a *API) getAccounts(c *gin.Context) { } // Get balances from stateDB - for x, apiAccount := range apiAccounts { - idx, err := stringToIdx(string(apiAccount.Idx), "Account Idx") - if err != nil { - retSQLErr(err, c) - return + if err := a.s.LastRead(func(sdb *statedb.Last) error { + for x, apiAccount := range apiAccounts { + idx, err := stringToIdx(string(apiAccount.Idx), "Account Idx") + if err != nil { + return tracerr.Wrap(err) + } + account, err := sdb.GetAccount(*idx) + if err != nil { + return tracerr.Wrap(err) + } + apiAccounts[x].Balance = apitypes.NewBigIntStr(account.Balance) + apiAccounts[x].Nonce = account.Nonce } - account, err := a.s.GetAccount(*idx) - if err != nil { - retSQLErr(err, c) - return - } - apiAccounts[x].Balance = apitypes.NewBigIntStr(account.Balance) - apiAccounts[x].Nonce = account.Nonce + return nil + }); err != nil { + retSQLErr(err, c) + return } // Build succesfull response diff --git a/api/account_test.go b/api/account_test.go index 9faa25d..4058091 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -10,6 +10,7 @@ import ( "github.com/hermeznetwork/hermez-node/db/historydb" "github.com/mitchellh/copystructure" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type testAccount struct { @@ -76,40 +77,40 @@ func TestGetAccounts(t *testing.T) { // Filter by BJJ path := fmt.Sprintf("%s?BJJ=%s&limit=%d", endpoint, tc.accounts[0].PublicKey, limit) err := doGoodReqPaginated(path, historydb.OrderAsc, &testAccountsResponse{}, appendIter) - assert.NoError(t, err) + require.NoError(t, err) assert.Greater(t, len(fetchedAccounts), 0) assert.LessOrEqual(t, len(fetchedAccounts), len(tc.accounts)) fetchedAccounts = []testAccount{} // Filter by ethAddr path = fmt.Sprintf("%s?hezEthereumAddress=%s&limit=%d", endpoint, tc.accounts[3].EthAddr, limit) err = doGoodReqPaginated(path, historydb.OrderAsc, &testAccountsResponse{}, appendIter) - assert.NoError(t, err) + require.NoError(t, err) assert.Greater(t, len(fetchedAccounts), 0) assert.LessOrEqual(t, len(fetchedAccounts), len(tc.accounts)) fetchedAccounts = []testAccount{} // both filters (incompatible) path = fmt.Sprintf("%s?hezEthereumAddress=%s&BJJ=%s&limit=%d", endpoint, tc.accounts[0].EthAddr, tc.accounts[0].PublicKey, limit) err = doBadReq("GET", path, nil, 400) - assert.NoError(t, err) + require.NoError(t, err) fetchedAccounts = []testAccount{} // Filter by token IDs path = fmt.Sprintf("%s?tokenIds=%s&limit=%d", endpoint, stringIds, limit) err = doGoodReqPaginated(path, historydb.OrderAsc, &testAccountsResponse{}, appendIter) - assert.NoError(t, err) + require.NoError(t, err) assert.Greater(t, len(fetchedAccounts), 0) assert.LessOrEqual(t, len(fetchedAccounts), len(tc.accounts)) fetchedAccounts = []testAccount{} // Token Ids + bjj path = fmt.Sprintf("%s?tokenIds=%s&BJJ=%s&limit=%d", endpoint, stringIds, tc.accounts[10].PublicKey, limit) err = doGoodReqPaginated(path, historydb.OrderAsc, &testAccountsResponse{}, appendIter) - assert.NoError(t, err) + require.NoError(t, err) assert.Greater(t, len(fetchedAccounts), 0) assert.LessOrEqual(t, len(fetchedAccounts), len(tc.accounts)) fetchedAccounts = []testAccount{} // No filters (checks response content) path = fmt.Sprintf("%s?limit=%d", endpoint, limit) err = doGoodReqPaginated(path, historydb.OrderAsc, &testAccountsResponse{}, appendIter) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, len(tc.accounts), len(fetchedAccounts)) for i := 0; i < len(fetchedAccounts); i++ { fetchedAccounts[i].Token.ItemID = 0 @@ -132,7 +133,7 @@ func TestGetAccounts(t *testing.T) { } } err = doGoodReqPaginated(path, historydb.OrderDesc, &testAccountsResponse{}, appendIter) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, len(reversedAccounts), len(fetchedAccounts)) for i := 0; i < len(fetchedAccounts); i++ { reversedAccounts[i].Token.ItemID = 0 @@ -147,21 +148,21 @@ func TestGetAccounts(t *testing.T) { // 400 path = fmt.Sprintf("%s?hezEthereumAddress=hez:0x123456", endpoint) err = doBadReq("GET", path, nil, 400) - assert.NoError(t, err) + require.NoError(t, err) // Test GetAccount path = fmt.Sprintf("%s/%s", endpoint, fetchedAccounts[2].Idx) account := testAccount{} - assert.NoError(t, doGoodReq("GET", path, nil, &account)) + require.NoError(t, doGoodReq("GET", path, nil, &account)) account.Token.ItemID = 0 assert.Equal(t, fetchedAccounts[2], account) // 400 path = fmt.Sprintf("%s/hez:12345", endpoint) err = doBadReq("GET", path, nil, 400) - assert.NoError(t, err) + require.NoError(t, err) // 404 path = fmt.Sprintf("%s/hez:10:12345", endpoint) err = doBadReq("GET", path, nil, 404) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/api/api_test.go b/api/api_test.go index c84ea6a..97dee27 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -355,6 +355,10 @@ func TestMain(m *testing.M) { panic(err) } } + // Make a checkpoint to make the accounts available in Last + if err := api.s.MakeCheckpoint(); err != nil { + panic(err) + } // Generate Coordinators and add them to HistoryDB const nCoords = 10 diff --git a/api/txspool.go b/api/txspool.go index 8f86d31..9f74121 100644 --- a/api/txspool.go +++ b/api/txspool.go @@ -170,7 +170,7 @@ func (a *API) verifyPoolL2TxWrite(txw *l2db.PoolL2TxWrite) error { return tracerr.Wrap(err) } // Get public key - account, err := a.s.GetAccount(poolTx.FromIdx) + account, err := a.s.LastGetAccount(poolTx.FromIdx) if err != nil { return tracerr.Wrap(err) } diff --git a/coordinator/pipeline_test.go b/coordinator/pipeline_test.go index f6a1e3d..1bd7bcf 100644 --- a/coordinator/pipeline_test.go +++ b/coordinator/pipeline_test.go @@ -150,7 +150,7 @@ func preloadSync(t *testing.T, ethClient *test.Client, sync *synchronizer.Synchr require.Nil(t, err) require.Equal(t, testTokensLen*testUsersLen, len(dbAccounts)) - sdbAccounts, err := stateDB.GetAccounts() + sdbAccounts, err := stateDB.TestGetAccounts() require.Nil(t, err) require.Equal(t, testTokensLen*testUsersLen, len(sdbAccounts)) @@ -200,12 +200,12 @@ PoolTransfer(0) User2-User3: 300 (126) }) require.NoError(t, err) // Sanity check - sdbAccounts, err := pipeline.txSelector.LocalAccountsDB().GetAccounts() + sdbAccounts, err := pipeline.txSelector.LocalAccountsDB().TestGetAccounts() require.Nil(t, err) require.Equal(t, testTokensLen*testUsersLen, len(sdbAccounts)) // Sanity check - sdbAccounts, err = pipeline.batchBuilder.LocalStateDB().GetAccounts() + sdbAccounts, err = pipeline.batchBuilder.LocalStateDB().TestGetAccounts() require.Nil(t, err) require.Equal(t, testTokensLen*testUsersLen, len(sdbAccounts)) diff --git a/db/kvdb/kvdb.go b/db/kvdb/kvdb.go index ce0a71c..5a9825d 100644 --- a/db/kvdb/kvdb.go +++ b/db/kvdb/kvdb.go @@ -24,6 +24,9 @@ const ( // PathCurrent defines the subpath of the current Batch in the subpath // of the KVDB PathCurrent = "current" + // PathLast defines the subpath of the last Batch in the subpath + // of the StateDB + PathLast = "last" ) var ( @@ -42,6 +45,58 @@ type KVDB struct { CurrentBatch common.BatchNum keep int m sync.Mutex + last *Last +} + +// Last is a consistent view to the last batch of the stateDB that can +// be queried concurrently. +type Last struct { + db *pebble.Storage + path string + rw sync.RWMutex +} + +func (k *Last) setNew() error { + k.rw.Lock() + defer k.rw.Unlock() + if k.db != nil { + k.db.Close() + } + lastPath := path.Join(k.path, PathLast) + err := os.RemoveAll(lastPath) + if err != nil { + return tracerr.Wrap(err) + } + db, err := pebble.NewPebbleStorage(path.Join(k.path, lastPath), false) + if err != nil { + return tracerr.Wrap(err) + } + k.db = db + return nil +} + +func (k *Last) set(kvdb *KVDB, batchNum common.BatchNum) error { + k.rw.Lock() + defer k.rw.Unlock() + if k.db != nil { + k.db.Close() + } + lastPath := path.Join(k.path, PathLast) + if err := kvdb.MakeCheckpointFromTo(batchNum, lastPath); err != nil { + return tracerr.Wrap(err) + } + db, err := pebble.NewPebbleStorage(lastPath, false) + if err != nil { + return tracerr.Wrap(err) + } + k.db = db + return nil +} + +func (k *Last) close() { + k.rw.Lock() + defer k.rw.Unlock() + k.db.Close() } // NewKVDB creates a new KVDB, allowing to use an in-memory or in-disk storage. @@ -58,6 +113,9 @@ func NewKVDB(pathDB string, keep int) (*KVDB, error) { path: pathDB, db: sto, keep: keep, + last: &Last{ + path: pathDB, + }, } // load currentBatch kvdb.CurrentBatch, err = kvdb.GetCurrentBatch() @@ -74,6 +132,13 @@ func NewKVDB(pathDB string, keep int) (*KVDB, error) { return kvdb, nil } +// LastRead is a thread-safe method to query the last KVDB +func (kvdb *KVDB) LastRead(fn func(db *pebble.Storage) error) error { + kvdb.last.rw.RLock() + defer kvdb.last.rw.RUnlock() + return fn(kvdb.last.db) +} + // DB returns the *pebble.Storage from the KVDB func (kvdb *KVDB) DB() *pebble.Storage { return kvdb.db @@ -139,14 +204,21 @@ func (kvdb *KVDB) reset(batchNum common.BatchNum, closeCurrent bool) error { kvdb.db = sto kvdb.CurrentIdx = common.RollupConstReservedIDx // 255 kvdb.CurrentBatch = 0 + if err := kvdb.last.setNew(); err != nil { + return tracerr.Wrap(err) + } return nil } - // copy 'BatchNumX' to 'current' + // copy 'batchNum' to 'current' if err := kvdb.MakeCheckpointFromTo(batchNum, currentPath); err != nil { return tracerr.Wrap(err) } + // copy 'batchNum' to 'last' + if err := kvdb.last.set(kvdb, batchNum); err != nil { + return tracerr.Wrap(err) + } // open the new 'current' sto, err := pebble.NewPebbleStorage(currentPath, false) @@ -334,6 +406,10 @@ func (kvdb *KVDB) MakeCheckpoint() error { if err := kvdb.db.Pebble().Checkpoint(checkpointPath); err != nil { return tracerr.Wrap(err) } + // copy 'CurrentBatch' to 'last' + if err := kvdb.last.set(kvdb, kvdb.CurrentBatch); err != nil { + return tracerr.Wrap(err) + } // delete old checkpoints if err := kvdb.deleteOldCheckpoints(); err != nil { return tracerr.Wrap(err) @@ -456,4 +532,5 @@ func pebbleMakeCheckpoint(source, dest string) error { // Close the DB func (kvdb *KVDB) Close() { kvdb.db.Close() + kvdb.last.close() } diff --git a/db/statedb/statedb.go b/db/statedb/statedb.go index dceb144..ab3c973 100644 --- a/db/statedb/statedb.go +++ b/db/statedb/statedb.go @@ -11,6 +11,7 @@ import ( "github.com/hermeznetwork/tracerr" "github.com/iden3/go-merkletree" "github.com/iden3/go-merkletree/db" + "github.com/iden3/go-merkletree/db/pebble" ) var ( @@ -58,11 +59,46 @@ type TypeStateDB string // StateDB represents the StateDB object type StateDB struct { - path string - Typ TypeStateDB - db *kvdb.KVDB - MT *merkletree.MerkleTree - keep int + path string + Typ TypeStateDB + db *kvdb.KVDB + nLevels int + MT *merkletree.MerkleTree + keep int +} + +// Last offers a subset of view methods of the StateDB that can be +// called via the LastRead method of StateDB in a thread-safe manner to obtain +// a consistent view to the last batch of the StateDB. +type Last struct { + db db.Storage +} + +// GetAccount returns the account for the given Idx +func (s *Last) GetAccount(idx common.Idx) (*common.Account, error) { + return GetAccountInTreeDB(s.db, idx) +} + +// GetCurrentBatch returns the current BatchNum stored in Last.db +func (s *Last) GetCurrentBatch() (common.BatchNum, error) { + cbBytes, err := s.db.Get(kvdb.KeyCurrentBatch) + if tracerr.Unwrap(err) == db.ErrNotFound { + return 0, nil + } else if err != nil { + return 0, tracerr.Wrap(err) + } + return common.BatchNumFromBytes(cbBytes) +} + +// DB returns the underlying storage of Last +func (s *Last) DB() db.Storage { + return s.db +} + +// GetAccounts returns all the accounts in the db. Use for debugging pruposes +// only. +func (s *Last) GetAccounts() ([]common.Account, error) { + return getAccounts(s.db) } // NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk @@ -89,14 +125,72 @@ func NewStateDB(pathDB string, keep int, typ TypeStateDB, nLevels int) (*StateDB } return &StateDB{ - path: pathDB, - db: kv, - MT: mt, - Typ: typ, - keep: keep, + path: pathDB, + db: kv, + nLevels: nLevels, + MT: mt, + Typ: typ, + keep: keep, }, nil } +// LastRead is a thread-safe method to query the last checkpoint of the StateDB +// via the Last type methods +func (s *StateDB) LastRead(fn func(sdbLast *Last) error) error { + return s.db.LastRead( + func(db *pebble.Storage) error { + return fn(&Last{ + db: db, + }) + }, + ) +} + +// LastGetAccount is a thread-safe method to query an account in the last +// checkpoint of the StateDB. +func (s *StateDB) LastGetAccount(idx common.Idx) (*common.Account, error) { + var account *common.Account + if err := s.LastRead(func(sdb *Last) error { + var err error + account, err = sdb.GetAccount(idx) + return err + }); err != nil { + return nil, tracerr.Wrap(err) + } + return account, nil +} + +// LastGetCurrentBatch is a thread-safe method to get the current BatchNum in +// the last checkpoint of the StateDB. +func (s *StateDB) LastGetCurrentBatch() (common.BatchNum, error) { + var batchNum common.BatchNum + if err := s.LastRead(func(sdb *Last) error { + var err error + batchNum, err = sdb.GetCurrentBatch() + return err + }); err != nil { + return 0, tracerr.Wrap(err) + } + return batchNum, nil +} + +// LastMTGetRoot returns the root of the underlying Merkle Tree in the last +// checkpoint of the StateDB. +func (s *StateDB) LastMTGetRoot() (*big.Int, error) { + var root *big.Int + if err := s.LastRead(func(sdb *Last) error { + mt, err := merkletree.NewMerkleTree(sdb.DB().WithPrefix(PrefixKeyMT), s.nLevels) + if err != nil { + return tracerr.Wrap(err) + } + root = mt.Root().BigInt() + return nil + }); err != nil { + return nil, tracerr.Wrap(err) + } + return root, nil +} + // MakeCheckpoint does a checkpoint at the given batchNum in the defined path. // Internally this advances & stores the current BatchNum, and then stores a // Checkpoint of the current state of the StateDB. @@ -115,8 +209,8 @@ func (s *StateDB) CurrentIdx() common.Idx { return s.db.CurrentIdx } -// GetCurrentBatch returns the current BatchNum stored in the StateDB.db -func (s *StateDB) GetCurrentBatch() (common.BatchNum, error) { +// getCurrentBatch returns the current BatchNum stored in the StateDB.db +func (s *StateDB) getCurrentBatch() (common.BatchNum, error) { return s.db.GetCurrentBatch() } @@ -157,35 +251,50 @@ func (s *StateDB) GetAccount(idx common.Idx) (*common.Account, error) { return GetAccountInTreeDB(s.db.DB(), idx) } -// GetAccounts returns all the accounts in the db. Use for debugging pruposes -// only. -func (s *StateDB) GetAccounts() ([]common.Account, error) { - idxDB := s.db.StorageWithPrefix(PrefixKeyIdx) - idxs := []common.Idx{} - // NOTE: Current implementation of Iterate in the pebble interface is - // not efficient, as it iterates over all keys. Improve it following - // this example: https://github.com/cockroachdb/pebble/pull/923/files +func accountsIter(db db.Storage, fn func(a *common.Account) (bool, error)) error { + idxDB := db.WithPrefix(PrefixKeyIdx) if err := idxDB.Iterate(func(k []byte, v []byte) (bool, error) { idx, err := common.IdxFromBytes(k) if err != nil { return false, tracerr.Wrap(err) } - idxs = append(idxs, idx) - return true, nil + acc, err := GetAccountInTreeDB(db, idx) + if err != nil { + return false, tracerr.Wrap(err) + } + ok, err := fn(acc) + if err != nil { + return false, tracerr.Wrap(err) + } + return ok, nil }); err != nil { - return nil, tracerr.Wrap(err) + return tracerr.Wrap(err) } + return nil +} + +func getAccounts(db db.Storage) ([]common.Account, error) { accs := []common.Account{} - for i := range idxs { - acc, err := s.GetAccount(idxs[i]) - if err != nil { - return nil, tracerr.Wrap(err) - } - accs = append(accs, *acc) + if err := accountsIter( + db, + func(a *common.Account) (bool, error) { + accs = append(accs, *a) + return true, nil + }, + ); err != nil { + return nil, tracerr.Wrap(err) } return accs, nil } +// TestGetAccounts returns all the accounts in the db. Use only in tests. +// Outside tests getting all the accounts is discouraged because it's an +// expensive operation, but if you must do it, use `LastRead()` method to get a +// thread-safe and consistent view of the stateDB. +func (s *StateDB) TestGetAccounts() ([]common.Account, error) { + return getAccounts(s.db.DB()) +} + // GetAccountInTreeDB is abstracted from StateDB to be used from StateDB and // from ExitTree. GetAccount returns the account for the given Idx func GetAccountInTreeDB(sto db.Storage, idx common.Idx) (*common.Account, error) { @@ -336,11 +445,6 @@ func (s *StateDB) MTGetProof(idx common.Idx) (*merkletree.CircomVerifierProof, e return p, nil } -// MTGetRoot returns the current root of the underlying Merkle Tree -func (s *StateDB) MTGetRoot() *big.Int { - return s.MT.Root().BigInt() -} - // Close the StateDB func (s *StateDB) Close() { s.db.Close() diff --git a/db/statedb/statedb_test.go b/db/statedb/statedb_test.go index da8736f..ad136fe 100644 --- a/db/statedb/statedb_test.go +++ b/db/statedb/statedb_test.go @@ -43,10 +43,10 @@ func newAccount(t *testing.T, i int) *common.Account { func TestNewStateDBIntermediateState(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) sdb, err := NewStateDB(dir, 128, TypeTxSelector, 0) - assert.NoError(t, err) + require.NoError(t, err) // test values k0 := []byte("testkey0") @@ -56,19 +56,26 @@ func TestNewStateDBIntermediateState(t *testing.T) { // store some data tx, err := sdb.db.DB().NewTx() - assert.NoError(t, err) + require.NoError(t, err) err = tx.Put(k0, v0) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Commit() - assert.NoError(t, err) + require.NoError(t, err) v, err := sdb.db.DB().Get(k0) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, v0, v) - // Close PebbleDB before creating a new StateDB - err = sdb.db.DB().Pebble().Close() + // k0 not yet in last + err = sdb.LastRead(func(sdb *Last) error { + _, err := sdb.DB().Get(k0) + assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) + return nil + }) require.NoError(t, err) + // Close PebbleDB before creating a new StateDB + sdb.Close() + // call NewStateDB which should get the db at the last checkpoint state // executing a Reset (discarding the last 'testkey0'&'testvalue0' data) sdb, err = NewStateDB(dir, 128, TypeTxSelector, 0) @@ -78,54 +85,90 @@ func TestNewStateDBIntermediateState(t *testing.T) { assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) assert.Nil(t, v) + // k0 not in last + err = sdb.LastRead(func(sdb *Last) error { + _, err := sdb.DB().Get(k0) + assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) + return nil + }) + require.NoError(t, err) + // store the same data from the beginning that has ben lost since last NewStateDB tx, err = sdb.db.DB().NewTx() - assert.NoError(t, err) + require.NoError(t, err) err = tx.Put(k0, v0) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Commit() - assert.NoError(t, err) + require.NoError(t, err) v, err = sdb.db.DB().Get(k0) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, v0, v) + // k0 yet not in last + err = sdb.LastRead(func(sdb *Last) error { + _, err := sdb.DB().Get(k0) + assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) + return nil + }) + require.NoError(t, err) + // make checkpoints with the current state - bn, err := sdb.db.GetCurrentBatch() - assert.NoError(t, err) + bn, err := sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(0), bn) err = sdb.db.MakeCheckpoint() - assert.NoError(t, err) - bn, err = sdb.db.GetCurrentBatch() - assert.NoError(t, err) + require.NoError(t, err) + bn, err = sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(1), bn) + // k0 in last + err = sdb.LastRead(func(sdb *Last) error { + v, err := sdb.DB().Get(k0) + require.NoError(t, err) + assert.Equal(t, v0, v) + return nil + }) + require.NoError(t, err) + // write more data tx, err = sdb.db.DB().NewTx() - assert.NoError(t, err) + require.NoError(t, err) err = tx.Put(k1, v1) - assert.NoError(t, err) + require.NoError(t, err) + err = tx.Put(k0, v1) // overwrite k0 with v1 + require.NoError(t, err) err = tx.Commit() - assert.NoError(t, err) + require.NoError(t, err) v, err = sdb.db.DB().Get(k1) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, v1, v) - // Close PebbleDB before creating a new StateDB - err = sdb.db.DB().Pebble().Close() + err = sdb.LastRead(func(sdb *Last) error { + v, err := sdb.DB().Get(k0) + require.NoError(t, err) + assert.Equal(t, v0, v) + return nil + }) require.NoError(t, err) + // Close PebbleDB before creating a new StateDB + sdb.Close() + // call NewStateDB which should get the db at the last checkpoint state // executing a Reset (discarding the last 'testkey1'&'testvalue1' data) sdb, err = NewStateDB(dir, 128, TypeTxSelector, 0) require.NoError(t, err) - bn, err = sdb.db.GetCurrentBatch() - assert.NoError(t, err) + bn, err = sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(1), bn) + // we closed the db without doing a checkpoint after overwriting k0, so + // it's back to v0 v, err = sdb.db.DB().Get(k0) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, v0, v) v, err = sdb.db.DB().Get(k1) @@ -137,10 +180,10 @@ func TestNewStateDBIntermediateState(t *testing.T) { func TestStateDBWithoutMT(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) sdb, err := NewStateDB(dir, 128, TypeTxSelector, 0) - assert.NoError(t, err) + require.NoError(t, err) // create test accounts var accounts []*common.Account @@ -157,20 +200,20 @@ func TestStateDBWithoutMT(t *testing.T) { // add test accounts for i := 0; i < len(accounts); i++ { _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i]) - assert.NoError(t, err) + require.NoError(t, err) } for i := 0; i < len(accounts); i++ { existingAccount := accounts[i].Idx accGetted, err := sdb.GetAccount(existingAccount) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, accounts[i], accGetted) } // try already existing idx and get error existingAccount := common.Idx(256) _, err = sdb.GetAccount(existingAccount) // check that exist - assert.NoError(t, err) + require.NoError(t, err) _, err = sdb.CreateAccount(common.Idx(256), accounts[1]) // check that can not be created twice assert.NotNil(t, err) assert.Equal(t, ErrAccountAlreadyExists, tracerr.Unwrap(err)) @@ -180,7 +223,7 @@ func TestStateDBWithoutMT(t *testing.T) { accounts[i].Nonce = accounts[i].Nonce + 1 existingAccount = common.Idx(i) _, err = sdb.UpdateAccount(existingAccount, accounts[i]) - assert.NoError(t, err) + require.NoError(t, err) } _, err = sdb.MTGetProof(common.Idx(1)) @@ -191,10 +234,10 @@ func TestStateDBWithoutMT(t *testing.T) { func TestStateDBWithMT(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) sdb, err := NewStateDB(dir, 128, TypeSynchronizer, 32) - assert.NoError(t, err) + require.NoError(t, err) // create test accounts var accounts []*common.Account @@ -210,33 +253,33 @@ func TestStateDBWithMT(t *testing.T) { // add test accounts for i := 0; i < len(accounts); i++ { _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i]) - assert.NoError(t, err) + require.NoError(t, err) } for i := 0; i < len(accounts); i++ { accGetted, err := sdb.GetAccount(accounts[i].Idx) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, accounts[i], accGetted) } // try already existing idx and get error _, err = sdb.GetAccount(common.Idx(256)) // check that exist - assert.NoError(t, err) + require.NoError(t, err) _, err = sdb.CreateAccount(common.Idx(256), accounts[1]) // check that can not be created twice assert.NotNil(t, err) assert.Equal(t, ErrAccountAlreadyExists, tracerr.Unwrap(err)) _, err = sdb.MTGetProof(common.Idx(256)) - assert.NoError(t, err) + require.NoError(t, err) // update accounts for i := 0; i < len(accounts); i++ { accounts[i].Nonce = accounts[i].Nonce + 1 _, err = sdb.UpdateAccount(accounts[i].Idx, accounts[i]) - assert.NoError(t, err) + require.NoError(t, err) } a, err := sdb.GetAccount(common.Idx(256)) // check that account value has been updated - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, accounts[0].Nonce, a.Nonce) } @@ -245,10 +288,13 @@ func TestStateDBWithMT(t *testing.T) { func TestCheckpoints(t *testing.T) { dir, err := ioutil.TempDir("", "sdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) sdb, err := NewStateDB(dir, 128, TypeSynchronizer, 32) - assert.NoError(t, err) + require.NoError(t, err) + + err = sdb.Reset(0) + require.NoError(t, err) // create test accounts var accounts []*common.Account @@ -259,22 +305,33 @@ func TestCheckpoints(t *testing.T) { // add test accounts for i := 0; i < len(accounts); i++ { _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i]) - assert.NoError(t, err) + require.NoError(t, err) } + // account doesn't exist in Last checkpoint + _, err = sdb.LastGetAccount(accounts[0].Idx) + assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) // do checkpoints and check that currentBatch is correct - err = sdb.db.MakeCheckpoint() - assert.NoError(t, err) - cb, err := sdb.db.GetCurrentBatch() - assert.NoError(t, err) + err = sdb.MakeCheckpoint() + require.NoError(t, err) + cb, err := sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(1), cb) + // account exists in Last checkpoint + accCur, err := sdb.GetAccount(accounts[0].Idx) + require.NoError(t, err) + accLast, err := sdb.LastGetAccount(accounts[0].Idx) + require.NoError(t, err) + assert.Equal(t, accounts[0], accLast) + assert.Equal(t, accCur, accLast) + for i := 1; i < 10; i++ { - err = sdb.db.MakeCheckpoint() - assert.NoError(t, err) + err = sdb.MakeCheckpoint() + require.NoError(t, err) - cb, err = sdb.db.GetCurrentBatch() - assert.NoError(t, err) + cb, err = sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(i+1), cb) } @@ -282,7 +339,7 @@ func TestCheckpoints(t *testing.T) { // reset checkpoint err = sdb.Reset(3) - assert.NoError(t, err) + require.NoError(t, err) // check that reset can be repeated (as there exist the 'current' and // 'BatchNum3', from where the 'current' is a copy) @@ -290,21 +347,21 @@ func TestCheckpoints(t *testing.T) { require.NoError(t, err) // check that currentBatch is as expected after Reset - cb, err = sdb.db.GetCurrentBatch() - assert.NoError(t, err) + cb, err = sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(3), cb) // advance one checkpoint and check that currentBatch is fine - err = sdb.db.MakeCheckpoint() - assert.NoError(t, err) - cb, err = sdb.db.GetCurrentBatch() - assert.NoError(t, err) + err = sdb.MakeCheckpoint() + require.NoError(t, err) + cb, err = sdb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(4), cb) err = sdb.db.DeleteCheckpoint(common.BatchNum(1)) - assert.NoError(t, err) + require.NoError(t, err) err = sdb.db.DeleteCheckpoint(common.BatchNum(2)) - assert.NoError(t, err) + require.NoError(t, err) err = sdb.db.DeleteCheckpoint(common.BatchNum(1)) // does not exist, should return err assert.NotNil(t, err) err = sdb.db.DeleteCheckpoint(common.BatchNum(2)) // does not exist, should return err @@ -313,43 +370,43 @@ func TestCheckpoints(t *testing.T) { // Create a LocalStateDB from the initial StateDB dirLocal, err := ioutil.TempDir("", "ldb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dirLocal)) + defer require.NoError(t, os.RemoveAll(dirLocal)) ldb, err := NewLocalStateDB(dirLocal, 128, sdb, TypeBatchBuilder, 32) - assert.NoError(t, err) + require.NoError(t, err) // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB) err = ldb.Reset(4, true) - assert.NoError(t, err) + require.NoError(t, err) // check that currentBatch is 4 after the Reset - cb, err = ldb.db.GetCurrentBatch() - assert.NoError(t, err) + cb, err = ldb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(4), cb) // advance one checkpoint in ldb - err = ldb.db.MakeCheckpoint() - assert.NoError(t, err) - cb, err = ldb.db.GetCurrentBatch() - assert.NoError(t, err) + err = ldb.MakeCheckpoint() + require.NoError(t, err) + cb, err = ldb.getCurrentBatch() + require.NoError(t, err) assert.Equal(t, common.BatchNum(5), cb) // Create a 2nd LocalStateDB from the initial StateDB dirLocal2, err := ioutil.TempDir("", "ldb2") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dirLocal2)) + defer require.NoError(t, os.RemoveAll(dirLocal2)) ldb2, err := NewLocalStateDB(dirLocal2, 128, sdb, TypeBatchBuilder, 32) - assert.NoError(t, err) + require.NoError(t, err) // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB) err = ldb2.Reset(4, true) - assert.NoError(t, err) + require.NoError(t, err) // check that currentBatch is 4 after the Reset cb, err = ldb2.db.GetCurrentBatch() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, common.BatchNum(4), cb) // advance one checkpoint in ldb2 err = ldb2.db.MakeCheckpoint() - assert.NoError(t, err) + require.NoError(t, err) cb, err = ldb2.db.GetCurrentBatch() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, common.BatchNum(5), cb) debug := false @@ -365,7 +422,7 @@ func TestStateDBGetAccounts(t *testing.T) { require.NoError(t, err) sdb, err := NewStateDB(dir, 128, TypeTxSelector, 0) - assert.NoError(t, err) + require.NoError(t, err) // create test accounts var accounts []common.Account @@ -380,14 +437,14 @@ func TestStateDBGetAccounts(t *testing.T) { require.NoError(t, err) } - dbAccounts, err := sdb.GetAccounts() + dbAccounts, err := sdb.TestGetAccounts() require.NoError(t, err) assert.Equal(t, accounts, dbAccounts) } func printCheckpoints(t *testing.T, path string) { files, err := ioutil.ReadDir(path) - assert.NoError(t, err) + require.NoError(t, err) fmt.Println(path) for _, f := range files { @@ -409,7 +466,7 @@ func bigFromStr(h string, u int) *big.Int { func TestCheckAccountsTreeTestVectors(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) sdb, err := NewStateDB(dir, 128, TypeSynchronizer, 32) require.NoError(t, err) @@ -483,7 +540,7 @@ func TestCheckAccountsTreeTestVectors(t *testing.T) { func TestListCheckpoints(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) sdb, err := NewStateDB(dir, 128, TypeSynchronizer, 32) require.NoError(t, err) @@ -491,7 +548,7 @@ func TestListCheckpoints(t *testing.T) { numCheckpoints := 16 // do checkpoints for i := 0; i < numCheckpoints; i++ { - err = sdb.db.MakeCheckpoint() + err = sdb.MakeCheckpoint() require.NoError(t, err) } list, err := sdb.db.ListCheckpoints() @@ -515,7 +572,7 @@ func TestListCheckpoints(t *testing.T) { func TestDeleteOldCheckpoints(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) - defer assert.NoError(t, os.RemoveAll(dir)) + defer require.NoError(t, os.RemoveAll(dir)) keep := 16 sdb, err := NewStateDB(dir, keep, TypeSynchronizer, 32) @@ -525,7 +582,7 @@ func TestDeleteOldCheckpoints(t *testing.T) { // do checkpoints and check that we never have more than `keep` // checkpoints for i := 0; i < numCheckpoints; i++ { - err = sdb.db.MakeCheckpoint() + err = sdb.MakeCheckpoint() require.NoError(t, err) checkpoints, err := sdb.db.ListCheckpoints() require.NoError(t, err) diff --git a/synchronizer/synchronizer_test.go b/synchronizer/synchronizer_test.go index d07e790..a356f8c 100644 --- a/synchronizer/synchronizer_test.go +++ b/synchronizer/synchronizer_test.go @@ -17,6 +17,7 @@ import ( "github.com/hermeznetwork/hermez-node/db/historydb" "github.com/hermeznetwork/hermez-node/db/statedb" "github.com/hermeznetwork/hermez-node/eth" + "github.com/hermeznetwork/hermez-node/log" "github.com/hermeznetwork/hermez-node/test" "github.com/hermeznetwork/hermez-node/test/til" "github.com/jinzhu/copier" @@ -253,7 +254,7 @@ func checkSyncBlock(t *testing.T, s *Synchronizer, blockNum int, block, syncBloc // Compare accounts from HistoryDB with StateDB (they should match) dbAccounts, err := s.historyDB.GetAllAccounts() require.NoError(t, err) - sdbAccounts, err := s.stateDB.GetAccounts() + sdbAccounts, err := s.stateDB.TestGetAccounts() require.NoError(t, err) assertEqualAccountsHistoryDBStateDB(t, dbAccounts, sdbAccounts) } @@ -338,6 +339,7 @@ func TestSyncGeneral(t *testing.T) { s, err := NewSynchronizer(client, historyDB, stateDB, Config{ StatsRefreshPeriod: 0 * time.Second, }) + log.Error(err) require.NoError(t, err) ctx := context.Background() @@ -651,7 +653,7 @@ func TestSyncGeneral(t *testing.T) { // Accounts in HistoryDB and StateDB must be empty dbAccounts, err := s.historyDB.GetAllAccounts() require.NoError(t, err) - sdbAccounts, err := s.stateDB.GetAccounts() + sdbAccounts, err := s.stateDB.TestGetAccounts() require.NoError(t, err) assert.Equal(t, 0, len(dbAccounts)) assertEqualAccountsHistoryDBStateDB(t, dbAccounts, sdbAccounts) @@ -690,7 +692,7 @@ func TestSyncGeneral(t *testing.T) { // Accounts in HistoryDB and StateDB is only 2 entries dbAccounts, err = s.historyDB.GetAllAccounts() require.NoError(t, err) - sdbAccounts, err = s.stateDB.GetAccounts() + sdbAccounts, err = s.stateDB.TestGetAccounts() require.NoError(t, err) assert.Equal(t, 2, len(dbAccounts)) assertEqualAccountsHistoryDBStateDB(t, dbAccounts, sdbAccounts) diff --git a/test/debugapi/debugapi.go b/test/debugapi/debugapi.go index 2d56c56..9ff170f 100644 --- a/test/debugapi/debugapi.go +++ b/test/debugapi/debugapi.go @@ -55,7 +55,7 @@ func (a *DebugAPI) handleAccount(c *gin.Context) { badReq(err, c) return } - account, err := a.stateDB.GetAccount(common.Idx(uri.Idx)) + account, err := a.stateDB.LastGetAccount(common.Idx(uri.Idx)) if err != nil { badReq(err, c) return @@ -64,8 +64,12 @@ func (a *DebugAPI) handleAccount(c *gin.Context) { } func (a *DebugAPI) handleAccounts(c *gin.Context) { - accounts, err := a.stateDB.GetAccounts() - if err != nil { + var accounts []common.Account + if err := a.stateDB.LastRead(func(sdb *statedb.Last) error { + var err error + accounts, err = sdb.GetAccounts() + return err + }); err != nil { badReq(err, c) return } @@ -73,7 +77,7 @@ func (a *DebugAPI) handleAccounts(c *gin.Context) { } func (a *DebugAPI) handleCurrentBatch(c *gin.Context) { - batchNum, err := a.stateDB.GetCurrentBatch() + batchNum, err := a.stateDB.LastGetCurrentBatch() if err != nil { badReq(err, c) return @@ -82,7 +86,11 @@ func (a *DebugAPI) handleCurrentBatch(c *gin.Context) { } func (a *DebugAPI) handleMTRoot(c *gin.Context) { - root := a.stateDB.MTGetRoot() + root, err := a.stateDB.LastMTGetRoot() + if err != nil { + badReq(err, c) + return + } c.JSON(http.StatusOK, root) } diff --git a/test/debugapi/debugapi_test.go b/test/debugapi/debugapi_test.go index 09498d0..3746860 100644 --- a/test/debugapi/debugapi_test.go +++ b/test/debugapi/debugapi_test.go @@ -66,6 +66,9 @@ func TestDebugAPI(t *testing.T) { _, err = sdb.CreateAccount(account.Idx, account) require.Nil(t, err) } + // Make a checkpoint (batchNum 2) to make the accounts available in Last + err = sdb.MakeCheckpoint() + require.Nil(t, err) url := fmt.Sprintf("http://%v/debug/", addr) @@ -73,7 +76,7 @@ func TestDebugAPI(t *testing.T) { req, err := sling.New().Get(url).Path("sdb/batchnum").ReceiveSuccess(&batchNum) require.Equal(t, http.StatusOK, req.StatusCode) require.Nil(t, err) - assert.Equal(t, common.BatchNum(1), batchNum) + assert.Equal(t, common.BatchNum(2), batchNum) var mtroot *big.Int req, err = sling.New().Get(url).Path("sdb/mtroot").ReceiveSuccess(&mtroot) diff --git a/txprocessor/zkinputsgen_test.go b/txprocessor/zkinputsgen_test.go index 1ca135b..d6a66aa 100644 --- a/txprocessor/zkinputsgen_test.go +++ b/txprocessor/zkinputsgen_test.go @@ -176,7 +176,7 @@ func TestZKInputsEmpty(t *testing.T) { assert.Equal(t, "0", ptOut.ZKInputs.Metadata.NewExitRootRaw.BigInt().String()) // check that there are no accounts - accs, err := sdb.GetAccounts() + accs, err := sdb.TestGetAccounts() require.NoError(t, err) assert.Equal(t, 0, len(accs)) @@ -208,7 +208,7 @@ func TestZKInputsEmpty(t *testing.T) { rootNonZero := sdb.MT.Root() // check that there is 1 account - accs, err = sdb.GetAccounts() + accs, err = sdb.TestGetAccounts() require.NoError(t, err) assert.Equal(t, 1, len(accs)) @@ -234,7 +234,7 @@ func TestZKInputsEmpty(t *testing.T) { assert.Equal(t, "0", ptOut.ZKInputs.Metadata.NewExitRootRaw.BigInt().String()) // check that there is still 1 account - accs, err = sdb.GetAccounts() + accs, err = sdb.TestGetAccounts() require.NoError(t, err) assert.Equal(t, 1, len(accs))