Browse Source

Abstract KVDB from StateDB

- KVDB contains the Checkpoint & Resets system
- StateDB uses KVDB and adds all the StateDB related methods
feature/sql-semaphore1
arnaucube 3 years ago
parent
commit
68bfbff269
7 changed files with 704 additions and 405 deletions
  1. +1
    -1
      coordinator/purger.go
  2. +405
    -0
      db/kvdb/kvdb.go
  3. +195
    -0
      db/kvdb/kvdb_test.go
  4. +44
    -351
      db/statedb/statedb.go
  5. +40
    -34
      db/statedb/statedb_test.go
  6. +3
    -3
      db/statedb/utils.go
  7. +16
    -16
      txprocessor/txprocessor.go

+ 1
- 1
coordinator/purger.go

@ -123,7 +123,7 @@ func poolMarkInvalidOldNonces(l2DB *l2db.L2DB, stateDB *statedb.LocalStateDB,
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
idxsNonce := make([]common.IdxNonce, len(idxs)) idxsNonce := make([]common.IdxNonce, len(idxs))
lastIdx, err := stateDB.GetIdx()
lastIdx, err := stateDB.GetCurrentIdx()
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }

+ 405
- 0
db/kvdb/kvdb.go

@ -0,0 +1,405 @@
// Package kvdb provides a key-value database with Checkpoints & Resets system
package kvdb
import (
"fmt"
"io/ioutil"
"os"
"path"
"sort"
"strings"
"github.com/hermeznetwork/hermez-node/common"
"github.com/hermeznetwork/hermez-node/log"
"github.com/hermeznetwork/tracerr"
"github.com/iden3/go-merkletree/db"
"github.com/iden3/go-merkletree/db/pebble"
)
const (
// PathBatchNum defines the subpath of the Batch Checkpoint in the
// subpath of the KVDB
PathBatchNum = "BatchNum"
// PathCurrent defines the subpath of the current Batch in the subpath
// of the KVDB
PathCurrent = "current"
)
var (
// KeyCurrentBatch is used as key in the db to store the current BatchNum
KeyCurrentBatch = []byte("k:currentbatch")
// keyCurrentIdx is used as key in the db to store the CurrentIdx
keyCurrentIdx = []byte("k:idx")
)
// KVDB represents the Key-Value DB object
type KVDB struct {
path string
db *pebble.Storage
// CurrentIdx holds the current Idx that the BatchBuilder is using
CurrentIdx common.Idx
CurrentBatch common.BatchNum
keep int
}
// NewKVDB creates a new KVDB, allowing to use an in-memory or in-disk storage.
// Checkpoints older than the value defined by `keep` will be deleted.
func NewKVDB(pathDB string, keep int) (*KVDB, error) {
var sto *pebble.Storage
var err error
sto, err = pebble.NewPebbleStorage(path.Join(pathDB, PathCurrent), false)
if err != nil {
return nil, tracerr.Wrap(err)
}
kvdb := &KVDB{
path: pathDB,
db: sto,
keep: keep,
}
// load currentBatch
kvdb.CurrentBatch, err = kvdb.GetCurrentBatch()
if err != nil {
return nil, tracerr.Wrap(err)
}
// make reset (get checkpoint) at currentBatch
err = kvdb.reset(kvdb.CurrentBatch, false)
if err != nil {
return nil, tracerr.Wrap(err)
}
return kvdb, nil
}
// DB returns the *pebble.Storage from the KVDB
func (kvdb *KVDB) DB() *pebble.Storage {
return kvdb.db
}
// StorageWithPrefix returns the db.Storage with the given prefix from the
// current KVDB
func (kvdb *KVDB) StorageWithPrefix(prefix []byte) db.Storage {
return kvdb.db.WithPrefix(prefix)
}
// Reset resets the KVDB to the checkpoint at the given batchNum. Reset does
// not delete the checkpoints between old current and the new current, those
// checkpoints will remain in the storage, and eventually will be deleted when
// MakeCheckpoint overwrites them.
func (kvdb *KVDB) Reset(batchNum common.BatchNum) error {
return kvdb.reset(batchNum, true)
}
// reset resets the KVDB to the checkpoint at the given batchNum. Reset does
// not delete the checkpoints between old current and the new current, those
// checkpoints will remain in the storage, and eventually will be deleted when
// MakeCheckpoint overwrites them. `closeCurrent` will close the currently
// opened db before doing the reset.
func (kvdb *KVDB) reset(batchNum common.BatchNum, closeCurrent bool) error {
currentPath := path.Join(kvdb.path, PathCurrent)
if closeCurrent {
if err := kvdb.db.Pebble().Close(); err != nil {
return tracerr.Wrap(err)
}
}
// remove 'current'
err := os.RemoveAll(currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// remove all checkpoints > batchNum
for i := batchNum + 1; i <= kvdb.CurrentBatch; i++ {
if err := kvdb.DeleteCheckpoint(i); err != nil {
return tracerr.Wrap(err)
}
}
if batchNum == 0 {
// if batchNum == 0, open the new fresh 'current'
sto, err := pebble.NewPebbleStorage(currentPath, false)
if err != nil {
return tracerr.Wrap(err)
}
kvdb.db = sto
kvdb.CurrentIdx = 255
kvdb.CurrentBatch = batchNum
return nil
}
checkpointPath := path.Join(kvdb.path, fmt.Sprintf("%s%d", PathBatchNum, batchNum))
// copy 'BatchNumX' to 'current'
err = pebbleMakeCheckpoint(checkpointPath, currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// open the new 'current'
sto, err := pebble.NewPebbleStorage(currentPath, false)
if err != nil {
return tracerr.Wrap(err)
}
kvdb.db = sto
// get currentBatch num
kvdb.CurrentBatch, err = kvdb.GetCurrentBatch()
if err != nil {
return tracerr.Wrap(err)
}
// idx is obtained from the statedb reset
kvdb.CurrentIdx, err = kvdb.GetCurrentIdx()
if err != nil {
return tracerr.Wrap(err)
}
return nil
}
// ResetFromSynchronizer performs a reset in the KVDB getting the state from
// synchronizerKVDB for the given batchNum.
func (kvdb *KVDB) ResetFromSynchronizer(batchNum common.BatchNum, synchronizerKVDB *KVDB) error {
if synchronizerKVDB == nil {
return tracerr.Wrap(fmt.Errorf("synchronizerKVDB can not be nil"))
}
if batchNum == 0 {
kvdb.CurrentIdx = 0
return nil
}
synchronizerCheckpointPath := path.Join(synchronizerKVDB.path,
fmt.Sprintf("%s%d", PathBatchNum, batchNum))
checkpointPath := path.Join(kvdb.path, fmt.Sprintf("%s%d", PathBatchNum, batchNum))
currentPath := path.Join(kvdb.path, PathCurrent)
// use checkpoint from synchronizerKVDB
if _, err := os.Stat(synchronizerCheckpointPath); os.IsNotExist(err) {
// if synchronizerKVDB does not have checkpoint at batchNum, return err
return tracerr.Wrap(fmt.Errorf("Checkpoint \"%v\" not exist in Synchronizer",
synchronizerCheckpointPath))
}
if err := kvdb.db.Pebble().Close(); err != nil {
return tracerr.Wrap(err)
}
// remove 'current'
err := os.RemoveAll(currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// copy synchronizer'BatchNumX' to 'current'
err = pebbleMakeCheckpoint(synchronizerCheckpointPath, currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// copy synchronizer'BatchNumX' to 'BatchNumX'
err = pebbleMakeCheckpoint(synchronizerCheckpointPath, checkpointPath)
if err != nil {
return tracerr.Wrap(err)
}
// open the new 'current'
sto, err := pebble.NewPebbleStorage(currentPath, false)
if err != nil {
return tracerr.Wrap(err)
}
kvdb.db = sto
// get currentBatch num
kvdb.CurrentBatch, err = kvdb.GetCurrentBatch()
if err != nil {
return tracerr.Wrap(err)
}
return nil
}
// GetCurrentBatch returns the current BatchNum stored in the KVDB
func (kvdb *KVDB) GetCurrentBatch() (common.BatchNum, error) {
cbBytes, err := kvdb.db.Get(KeyCurrentBatch)
if tracerr.Unwrap(err) == db.ErrNotFound {
return 0, nil
}
if err != nil {
return 0, tracerr.Wrap(err)
}
return common.BatchNumFromBytes(cbBytes)
}
// setCurrentBatch stores the current BatchNum in the KVDB
func (kvdb *KVDB) setCurrentBatch() error {
tx, err := kvdb.db.NewTx()
if err != nil {
return tracerr.Wrap(err)
}
err = tx.Put(KeyCurrentBatch, kvdb.CurrentBatch.Bytes())
if err != nil {
return tracerr.Wrap(err)
}
if err := tx.Commit(); err != nil {
return tracerr.Wrap(err)
}
return nil
}
// GetCurrentIdx returns the stored Idx from the KVDB, which is the last Idx
// used for an Account in the KVDB.
func (kvdb *KVDB) GetCurrentIdx() (common.Idx, error) {
idxBytes, err := kvdb.db.Get(keyCurrentIdx)
if tracerr.Unwrap(err) == db.ErrNotFound {
return 0, nil
}
if err != nil {
return 0, tracerr.Wrap(err)
}
return common.IdxFromBytes(idxBytes[:])
}
// SetCurrentIdx stores Idx in the KVDB
func (kvdb *KVDB) SetCurrentIdx(idx common.Idx) error {
kvdb.CurrentIdx = idx
tx, err := kvdb.db.NewTx()
if err != nil {
return tracerr.Wrap(err)
}
idxBytes, err := idx.Bytes()
if err != nil {
return tracerr.Wrap(err)
}
err = tx.Put(keyCurrentIdx, idxBytes[:])
if err != nil {
return tracerr.Wrap(err)
}
if err := tx.Commit(); err != nil {
return tracerr.Wrap(err)
}
return nil
}
// MakeCheckpoint does a checkpoint at the given batchNum in the defined path.
// Internally this advances & stores the current BatchNum, and then stores a
// Checkpoint of the current state of the KVDB.
func (kvdb *KVDB) MakeCheckpoint() error {
// advance currentBatch
kvdb.CurrentBatch++
log.Debugw("Making KVDB checkpoint", "batch", kvdb.CurrentBatch)
checkpointPath := path.Join(kvdb.path, fmt.Sprintf("%s%d", PathBatchNum, kvdb.CurrentBatch))
if err := kvdb.setCurrentBatch(); err != nil {
return tracerr.Wrap(err)
}
// if checkpoint BatchNum already exist in disk, delete it
if _, err := os.Stat(checkpointPath); !os.IsNotExist(err) {
err := os.RemoveAll(checkpointPath)
if err != nil {
return tracerr.Wrap(err)
}
} else if err != nil && !os.IsNotExist(err) {
return tracerr.Wrap(err)
}
// execute Checkpoint
if err := kvdb.db.Pebble().Checkpoint(checkpointPath); err != nil {
return tracerr.Wrap(err)
}
// delete old checkpoints
if err := kvdb.deleteOldCheckpoints(); err != nil {
return tracerr.Wrap(err)
}
return nil
}
// DeleteCheckpoint removes if exist the checkpoint of the given batchNum
func (kvdb *KVDB) DeleteCheckpoint(batchNum common.BatchNum) error {
checkpointPath := path.Join(kvdb.path, fmt.Sprintf("%s%d", PathBatchNum, batchNum))
if _, err := os.Stat(checkpointPath); os.IsNotExist(err) {
return tracerr.Wrap(fmt.Errorf("Checkpoint with batchNum %d does not exist in DB", batchNum))
}
return os.RemoveAll(checkpointPath)
}
// ListCheckpoints returns the list of batchNums of the checkpoints, sorted.
// If there's a gap between the list of checkpoints, an error is returned.
func (kvdb *KVDB) ListCheckpoints() ([]int, error) {
files, err := ioutil.ReadDir(kvdb.path)
if err != nil {
return nil, tracerr.Wrap(err)
}
checkpoints := []int{}
var checkpoint int
pattern := fmt.Sprintf("%s%%d", PathBatchNum)
for _, file := range files {
fileName := file.Name()
if file.IsDir() && strings.HasPrefix(fileName, PathBatchNum) {
if _, err := fmt.Sscanf(fileName, pattern, &checkpoint); err != nil {
return nil, tracerr.Wrap(err)
}
checkpoints = append(checkpoints, checkpoint)
}
}
sort.Ints(checkpoints)
if len(checkpoints) > 0 {
first := checkpoints[0]
for _, checkpoint := range checkpoints[1:] {
first++
if checkpoint != first {
return nil, tracerr.Wrap(fmt.Errorf("checkpoint gap at %v", checkpoint))
}
}
}
return checkpoints, nil
}
// deleteOldCheckpoints deletes old checkpoints when there are more than
// `s.keep` checkpoints
func (kvdb *KVDB) deleteOldCheckpoints() error {
list, err := kvdb.ListCheckpoints()
if err != nil {
return tracerr.Wrap(err)
}
if len(list) > kvdb.keep {
for _, checkpoint := range list[:len(list)-kvdb.keep] {
if err := kvdb.DeleteCheckpoint(common.BatchNum(checkpoint)); err != nil {
return tracerr.Wrap(err)
}
}
}
return nil
}
func pebbleMakeCheckpoint(source, dest string) error {
// Remove dest folder (if it exists) before doing the checkpoint
if _, err := os.Stat(dest); !os.IsNotExist(err) {
err := os.RemoveAll(dest)
if err != nil {
return tracerr.Wrap(err)
}
} else if err != nil && !os.IsNotExist(err) {
return tracerr.Wrap(err)
}
sto, err := pebble.NewPebbleStorage(source, false)
if err != nil {
return tracerr.Wrap(err)
}
defer func() {
errClose := sto.Pebble().Close()
if errClose != nil {
log.Errorw("Pebble.Close", "err", errClose)
}
}()
// execute Checkpoint
err = sto.Pebble().Checkpoint(dest)
if err != nil {
return tracerr.Wrap(err)
}
return nil
}

+ 195
- 0
db/kvdb/kvdb_test.go

@ -0,0 +1,195 @@
package kvdb
import (
"fmt"
"io/ioutil"
"os"
"testing"
"github.com/hermeznetwork/hermez-node/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func addTestKV(t *testing.T, db *KVDB, k, v []byte) {
tx, err := db.db.NewTx()
require.NoError(t, err)
err = tx.Put(k, v)
require.NoError(t, err)
err = tx.Commit()
require.NoError(t, err)
}
func printCheckpoints(t *testing.T, path string) {
files, err := ioutil.ReadDir(path)
assert.NoError(t, err)
fmt.Println(path)
for _, f := range files {
fmt.Println(" " + f.Name())
}
}
func TestCheckpoints(t *testing.T) {
dir, err := ioutil.TempDir("", "sdb")
require.NoError(t, err)
defer assert.NoError(t, os.RemoveAll(dir))
db, err := NewKVDB(dir, 128)
assert.NoError(t, err)
// add test key-values
for i := 0; i < 10; i++ {
addTestKV(t, db, []byte{byte(i), byte(i)}, []byte{byte(i * 2), byte(i * 2)})
}
// do checkpoints and check that currentBatch is correct
err = db.MakeCheckpoint()
assert.NoError(t, err)
cb, err := db.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(1), cb)
for i := 1; i < 10; i++ {
err = db.MakeCheckpoint()
assert.NoError(t, err)
cb, err = db.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(i+1), cb)
}
// printCheckpoints(t, sdb.path)
// reset checkpoint
err = db.Reset(3)
assert.NoError(t, err)
// check that reset can be repeated (as there exist the 'current' and
// 'BatchNum3', from where the 'current' is a copy)
err = db.Reset(3)
require.NoError(t, err)
// check that currentBatch is as expected after Reset
cb, err = db.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(3), cb)
// advance one checkpoint and check that currentBatch is fine
err = db.MakeCheckpoint()
assert.NoError(t, err)
cb, err = db.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(4), cb)
err = db.DeleteCheckpoint(common.BatchNum(1))
assert.NoError(t, err)
err = db.DeleteCheckpoint(common.BatchNum(2))
assert.NoError(t, err)
err = db.DeleteCheckpoint(common.BatchNum(1)) // does not exist, should return err
assert.NotNil(t, err)
err = db.DeleteCheckpoint(common.BatchNum(2)) // does not exist, should return err
assert.NotNil(t, err)
// Create a new KVDB which will get Reset from the initial KVDB
dirLocal, err := ioutil.TempDir("", "ldb")
require.NoError(t, err)
defer assert.NoError(t, os.RemoveAll(dirLocal))
ldb, err := NewKVDB(dirLocal, 128)
assert.NoError(t, err)
// get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
err = ldb.ResetFromSynchronizer(4, db)
assert.NoError(t, err)
// check that currentBatch is 4 after the Reset
cb, err = ldb.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(4), cb)
// advance one checkpoint in ldb
err = ldb.MakeCheckpoint()
assert.NoError(t, err)
cb, err = ldb.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(5), cb)
// Create a 3rd KVDB which will get Reset from the initial KVDB
dirLocal2, err := ioutil.TempDir("", "ldb2")
require.NoError(t, err)
defer assert.NoError(t, os.RemoveAll(dirLocal2))
ldb2, err := NewKVDB(dirLocal2, 128)
assert.NoError(t, err)
// get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
err = ldb2.ResetFromSynchronizer(4, db)
assert.NoError(t, err)
// check that currentBatch is 4 after the Reset
cb, err = ldb2.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(4), cb)
// advance one checkpoint in ldb2
err = ldb2.MakeCheckpoint()
assert.NoError(t, err)
cb, err = ldb2.GetCurrentBatch()
assert.NoError(t, err)
assert.Equal(t, common.BatchNum(5), cb)
debug := false
if debug {
printCheckpoints(t, db.path)
printCheckpoints(t, ldb.path)
printCheckpoints(t, ldb2.path)
}
}
func TestListCheckpoints(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb")
require.NoError(t, err)
defer assert.NoError(t, os.RemoveAll(dir))
db, err := NewKVDB(dir, 128)
require.NoError(t, err)
numCheckpoints := 16
// do checkpoints
for i := 0; i < numCheckpoints; i++ {
err = db.MakeCheckpoint()
require.NoError(t, err)
}
list, err := db.ListCheckpoints()
require.NoError(t, err)
assert.Equal(t, numCheckpoints, len(list))
assert.Equal(t, 1, list[0])
assert.Equal(t, numCheckpoints, list[len(list)-1])
numReset := 10
err = db.Reset(common.BatchNum(numReset))
require.NoError(t, err)
list, err = db.ListCheckpoints()
require.NoError(t, err)
assert.Equal(t, numReset, len(list))
assert.Equal(t, 1, list[0])
assert.Equal(t, numReset, list[len(list)-1])
}
func TestDeleteOldCheckpoints(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb")
require.NoError(t, err)
defer assert.NoError(t, os.RemoveAll(dir))
keep := 16
db, err := NewKVDB(dir, keep)
require.NoError(t, err)
numCheckpoints := 32
// do checkpoints and check that we never have more than `keep`
// checkpoints
for i := 0; i < numCheckpoints; i++ {
err = db.MakeCheckpoint()
require.NoError(t, err)
checkpoints, err := db.ListCheckpoints()
require.NoError(t, err)
assert.LessOrEqual(t, len(checkpoints), keep)
}
}

+ 44
- 351
db/statedb/statedb.go

@ -3,23 +3,16 @@ package statedb
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"math/big" "math/big"
"os"
"path"
"sort"
"strings"
"github.com/hermeznetwork/hermez-node/common" "github.com/hermeznetwork/hermez-node/common"
"github.com/hermeznetwork/hermez-node/db/kvdb"
"github.com/hermeznetwork/hermez-node/log" "github.com/hermeznetwork/hermez-node/log"
"github.com/hermeznetwork/tracerr" "github.com/hermeznetwork/tracerr"
"github.com/iden3/go-merkletree" "github.com/iden3/go-merkletree"
"github.com/iden3/go-merkletree/db" "github.com/iden3/go-merkletree/db"
"github.com/iden3/go-merkletree/db/pebble"
) )
// TODO(Edu): Document here how StateDB is kept consistent
var ( var (
// ErrStateDBWithoutMT is used when a method that requires a MerkleTree // ErrStateDBWithoutMT is used when a method that requires a MerkleTree
// is called in a StateDB that does not have a MerkleTree defined // is called in a StateDB that does not have a MerkleTree defined
@ -36,9 +29,6 @@ var (
// BJJ with not compatible combination // BJJ with not compatible combination
ErrGetIdxNoCase = errors.New("Can not get Idx due unexpected combination of ethereum Address & BabyJubJub PublicKey") ErrGetIdxNoCase = errors.New("Can not get Idx due unexpected combination of ethereum Address & BabyJubJub PublicKey")
// KeyCurrentBatch is used as key in the db to store the current BatchNum
KeyCurrentBatch = []byte("k:currentbatch")
// PrefixKeyIdx is the key prefix for idx in the db // PrefixKeyIdx is the key prefix for idx in the db
PrefixKeyIdx = []byte("i:") PrefixKeyIdx = []byte("i:")
// PrefixKeyAccHash is the key prefix for account hash in the db // PrefixKeyAccHash is the key prefix for account hash in the db
@ -49,17 +39,9 @@ var (
PrefixKeyAddr = []byte("a:") PrefixKeyAddr = []byte("a:")
// PrefixKeyAddrBJJ is the key prefix for address-babyjubjub in the db // PrefixKeyAddrBJJ is the key prefix for address-babyjubjub in the db
PrefixKeyAddrBJJ = []byte("ab:") PrefixKeyAddrBJJ = []byte("ab:")
// keyidx is used as key in the db to store the current Idx
keyidx = []byte("k:idx")
) )
const ( const (
// PathBatchNum defines the subpath of the Batch Checkpoint in the
// subpath of the StateDB
PathBatchNum = "BatchNum"
// PathCurrent defines the subpath of the current Batch in the subpath
// of the StateDB
PathCurrent = "current"
// TypeSynchronizer defines a StateDB used by the Synchronizer, that // TypeSynchronizer defines a StateDB used by the Synchronizer, that
// generates the ExitTree when processing the txs // generates the ExitTree when processing the txs
TypeSynchronizer = "synchronizer" TypeSynchronizer = "synchronizer"
@ -78,28 +60,26 @@ type TypeStateDB string
type StateDB struct { type StateDB struct {
path string path string
Typ TypeStateDB Typ TypeStateDB
// CurrentIdx holds the current Idx that the BatchBuilder is using
CurrentIdx common.Idx
CurrentBatch common.BatchNum
db *pebble.Storage
MT *merkletree.MerkleTree
keep int
db *kvdb.KVDB
MT *merkletree.MerkleTree
keep int
} }
// NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk // NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk
// storage. Checkpoints older than the value defined by `keep` will be // storage. Checkpoints older than the value defined by `keep` will be
// deleted. // deleted.
func NewStateDB(pathDB string, keep int, typ TypeStateDB, nLevels int) (*StateDB, error) { func NewStateDB(pathDB string, keep int, typ TypeStateDB, nLevels int) (*StateDB, error) {
var sto *pebble.Storage
var kv *kvdb.KVDB
var err error var err error
sto, err = pebble.NewPebbleStorage(path.Join(pathDB, PathCurrent), false)
kv, err = kvdb.NewKVDB(pathDB, keep)
if err != nil { if err != nil {
return nil, tracerr.Wrap(err) return nil, tracerr.Wrap(err)
} }
var mt *merkletree.MerkleTree = nil var mt *merkletree.MerkleTree = nil
if typ == TypeSynchronizer || typ == TypeBatchBuilder { if typ == TypeSynchronizer || typ == TypeBatchBuilder {
mt, err = merkletree.NewMerkleTree(sto.WithPrefix(PrefixKeyMT), nLevels)
mt, err = merkletree.NewMerkleTree(kv.StorageWithPrefix(PrefixKeyMT), nLevels)
if err != nil { if err != nil {
return nil, tracerr.Wrap(err) return nil, tracerr.Wrap(err)
} }
@ -108,185 +88,47 @@ func NewStateDB(pathDB string, keep int, typ TypeStateDB, nLevels int) (*StateDB
return nil, tracerr.Wrap(fmt.Errorf("invalid StateDB parameters: StateDB type==TypeStateDB can not have nLevels!=0")) return nil, tracerr.Wrap(fmt.Errorf("invalid StateDB parameters: StateDB type==TypeStateDB can not have nLevels!=0"))
} }
sdb := &StateDB{
return &StateDB{
path: pathDB, path: pathDB,
db: sto,
db: kv,
MT: mt, MT: mt,
Typ: typ, Typ: typ,
keep: keep, keep: keep,
}
// load currentBatch
sdb.CurrentBatch, err = sdb.GetCurrentBatch()
if err != nil {
return nil, tracerr.Wrap(err)
}
// make reset (get checkpoint) at currentBatch
err = sdb.reset(sdb.CurrentBatch, false)
if err != nil {
return nil, tracerr.Wrap(err)
}
return sdb, nil
}
// DB returns the *pebble.Storage from the StateDB
func (s *StateDB) DB() *pebble.Storage {
return s.db
}
// GetCurrentBatch returns the current BatchNum stored in the StateDB
func (s *StateDB) GetCurrentBatch() (common.BatchNum, error) {
cbBytes, err := s.db.Get(KeyCurrentBatch)
if tracerr.Unwrap(err) == db.ErrNotFound {
return 0, nil
}
if err != nil {
return 0, tracerr.Wrap(err)
}
return common.BatchNumFromBytes(cbBytes)
}
// setCurrentBatch stores the current BatchNum in the StateDB
func (s *StateDB) setCurrentBatch() error {
tx, err := s.db.NewTx()
if err != nil {
return tracerr.Wrap(err)
}
err = tx.Put(KeyCurrentBatch, s.CurrentBatch.Bytes())
if err != nil {
return tracerr.Wrap(err)
}
if err := tx.Commit(); err != nil {
return tracerr.Wrap(err)
}
return nil
}, nil
} }
// MakeCheckpoint does a checkpoint at the given batchNum in the defined path. Internally this advances & stores the current BatchNum, and then stores a Checkpoint of the current state of the StateDB.
// MakeCheckpoint does a checkpoint at the given batchNum in the defined path.
// Internally this advances & stores the current BatchNum, and then stores a
// Checkpoint of the current state of the StateDB.
func (s *StateDB) MakeCheckpoint() error { func (s *StateDB) MakeCheckpoint() error {
// advance currentBatch
s.CurrentBatch++
log.Debugw("Making StateDB checkpoint", "batch", s.CurrentBatch, "type", s.Typ)
checkpointPath := path.Join(s.path, fmt.Sprintf("%s%d", PathBatchNum, s.CurrentBatch))
if err := s.setCurrentBatch(); err != nil {
return tracerr.Wrap(err)
}
// if checkpoint BatchNum already exist in disk, delete it
if _, err := os.Stat(checkpointPath); !os.IsNotExist(err) {
err := os.RemoveAll(checkpointPath)
if err != nil {
return tracerr.Wrap(err)
}
} else if err != nil && !os.IsNotExist(err) {
return tracerr.Wrap(err)
}
// execute Checkpoint
if err := s.db.Pebble().Checkpoint(checkpointPath); err != nil {
return tracerr.Wrap(err)
}
// delete old checkpoints
if err := s.deleteOldCheckpoints(); err != nil {
return tracerr.Wrap(err)
}
return nil
log.Debugw("Making StateDB checkpoint", "batch", s.CurrentBatch())
return s.db.MakeCheckpoint()
} }
// DeleteCheckpoint removes if exist the checkpoint of the given batchNum
func (s *StateDB) DeleteCheckpoint(batchNum common.BatchNum) error {
checkpointPath := path.Join(s.path, fmt.Sprintf("%s%d", PathBatchNum, batchNum))
if _, err := os.Stat(checkpointPath); os.IsNotExist(err) {
return tracerr.Wrap(fmt.Errorf("Checkpoint with batchNum %d does not exist in DB", batchNum))
}
return os.RemoveAll(checkpointPath)
// CurrentBatch returns the current in-memory CurrentBatch of the StateDB.db
func (s *StateDB) CurrentBatch() common.BatchNum {
return s.db.CurrentBatch
} }
// listCheckpoints returns the list of batchNums of the checkpoints, sorted.
// If there's a gap between the list of checkpoints, an error is returned.
func (s *StateDB) listCheckpoints() ([]int, error) {
files, err := ioutil.ReadDir(s.path)
if err != nil {
return nil, tracerr.Wrap(err)
}
checkpoints := []int{}
var checkpoint int
pattern := fmt.Sprintf("%s%%d", PathBatchNum)
for _, file := range files {
fileName := file.Name()
if file.IsDir() && strings.HasPrefix(fileName, PathBatchNum) {
if _, err := fmt.Sscanf(fileName, pattern, &checkpoint); err != nil {
return nil, tracerr.Wrap(err)
}
checkpoints = append(checkpoints, checkpoint)
}
}
sort.Ints(checkpoints)
if len(checkpoints) > 0 {
first := checkpoints[0]
for _, checkpoint := range checkpoints[1:] {
first++
if checkpoint != first {
return nil, tracerr.Wrap(fmt.Errorf("checkpoint gap at %v", checkpoint))
}
}
}
return checkpoints, nil
// CurrentIdx returns the current in-memory CurrentIdx of the StateDB.db
func (s *StateDB) CurrentIdx() common.Idx {
return s.db.CurrentIdx
} }
// deleteOldCheckpoints deletes old checkpoints when there are more than
// `s.keep` checkpoints
func (s *StateDB) deleteOldCheckpoints() error {
list, err := s.listCheckpoints()
if err != nil {
return tracerr.Wrap(err)
}
if len(list) > s.keep {
for _, checkpoint := range list[:len(list)-s.keep] {
if err := s.DeleteCheckpoint(common.BatchNum(checkpoint)); err != nil {
return tracerr.Wrap(err)
}
}
}
return nil
// GetCurrentBatch returns the current BatchNum stored in the StateDB.db
func (s *StateDB) GetCurrentBatch() (common.BatchNum, error) {
return s.db.GetCurrentBatch()
} }
func pebbleMakeCheckpoint(source, dest string) error {
// Remove dest folder (if it exists) before doing the checkpoint
if _, err := os.Stat(dest); !os.IsNotExist(err) {
err := os.RemoveAll(dest)
if err != nil {
return tracerr.Wrap(err)
}
} else if err != nil && !os.IsNotExist(err) {
return tracerr.Wrap(err)
}
sto, err := pebble.NewPebbleStorage(source, false)
if err != nil {
return tracerr.Wrap(err)
}
defer func() {
errClose := sto.Pebble().Close()
if errClose != nil {
log.Errorw("Pebble.Close", "err", errClose)
}
}()
// execute Checkpoint
err = sto.Pebble().Checkpoint(dest)
if err != nil {
return tracerr.Wrap(err)
}
// GetCurrentIdx returns the stored Idx from the localStateDB, which is the
// last Idx used for an Account in the localStateDB.
func (s *StateDB) GetCurrentIdx() (common.Idx, error) {
return s.db.GetCurrentIdx()
}
return nil
// SetCurrentIdx stores Idx in the StateDB
func (s *StateDB) SetCurrentIdx(idx common.Idx) error {
return s.db.SetCurrentIdx(idx)
} }
// Reset resets the StateDB to the checkpoint at the given batchNum. Reset // Reset resets the StateDB to the checkpoint at the given batchNum. Reset
@ -294,135 +136,30 @@ func pebbleMakeCheckpoint(source, dest string) error {
// those checkpoints will remain in the storage, and eventually will be // those checkpoints will remain in the storage, and eventually will be
// deleted when MakeCheckpoint overwrites them. // deleted when MakeCheckpoint overwrites them.
func (s *StateDB) Reset(batchNum common.BatchNum) error { func (s *StateDB) Reset(batchNum common.BatchNum) error {
return s.reset(batchNum, true)
}
// reset resets the StateDB to the checkpoint at the given batchNum. Reset
// does not delete the checkpoints between old current and the new current,
// those checkpoints will remain in the storage, and eventually will be
// deleted when MakeCheckpoint overwrites them. `closeCurrent` will close the
// currently opened db before doing the reset.
func (s *StateDB) reset(batchNum common.BatchNum, closeCurrent bool) error {
currentPath := path.Join(s.path, PathCurrent)
if closeCurrent {
if err := s.db.Pebble().Close(); err != nil {
return tracerr.Wrap(err)
}
}
// remove 'current'
err := os.RemoveAll(currentPath)
err := s.db.Reset(batchNum)
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
// remove all checkpoints > batchNum
for i := batchNum + 1; i <= s.CurrentBatch; i++ {
if err := s.DeleteCheckpoint(i); err != nil {
return tracerr.Wrap(err)
}
}
if batchNum == 0 {
// if batchNum == 0, open the new fresh 'current'
sto, err := pebble.NewPebbleStorage(currentPath, false)
if err != nil {
return tracerr.Wrap(err)
}
s.db = sto
s.CurrentIdx = 255
s.CurrentBatch = batchNum
if s.MT != nil {
// open the MT for the current s.db
mt, err := merkletree.NewMerkleTree(s.db.WithPrefix(PrefixKeyMT), s.MT.MaxLevels())
if err != nil {
return tracerr.Wrap(err)
}
s.MT = mt
}
return nil
}
checkpointPath := path.Join(s.path, fmt.Sprintf("%s%d", PathBatchNum, batchNum))
// copy 'BatchNumX' to 'current'
err = pebbleMakeCheckpoint(checkpointPath, currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// open the new 'current'
sto, err := pebble.NewPebbleStorage(currentPath, false)
if err != nil {
return tracerr.Wrap(err)
}
s.db = sto
// get currentBatch num
s.CurrentBatch, err = s.GetCurrentBatch()
if err != nil {
return tracerr.Wrap(err)
}
// idx is obtained from the statedb reset
s.CurrentIdx, err = s.GetIdx()
if err != nil {
return tracerr.Wrap(err)
}
if s.MT != nil { if s.MT != nil {
// open the MT for the current s.db // open the MT for the current s.db
mt, err := merkletree.NewMerkleTree(s.db.WithPrefix(PrefixKeyMT), s.MT.MaxLevels())
mt, err := merkletree.NewMerkleTree(s.db.StorageWithPrefix(PrefixKeyMT), s.MT.MaxLevels())
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
s.MT = mt s.MT = mt
} }
return nil
}
// GetIdx returns the stored Idx from the localStateDB, which is the last Idx
// used for an Account in the localStateDB.
func (s *StateDB) GetIdx() (common.Idx, error) {
idxBytes, err := s.DB().Get(keyidx)
if tracerr.Unwrap(err) == db.ErrNotFound {
return 0, nil
}
if err != nil {
return 0, tracerr.Wrap(err)
}
return common.IdxFromBytes(idxBytes[:])
}
// SetIdx stores Idx in the localStateDB
func (s *StateDB) SetIdx(idx common.Idx) error {
s.CurrentIdx = idx
tx, err := s.DB().NewTx()
if err != nil {
return tracerr.Wrap(err)
}
idxBytes, err := idx.Bytes()
if err != nil {
return tracerr.Wrap(err)
}
err = tx.Put(keyidx, idxBytes[:])
if err != nil {
return tracerr.Wrap(err)
}
if err := tx.Commit(); err != nil {
return tracerr.Wrap(err)
}
return nil return nil
} }
// GetAccount returns the account for the given Idx // GetAccount returns the account for the given Idx
func (s *StateDB) GetAccount(idx common.Idx) (*common.Account, error) { func (s *StateDB) GetAccount(idx common.Idx) (*common.Account, error) {
return GetAccountInTreeDB(s.db, idx)
return GetAccountInTreeDB(s.db.DB(), idx)
} }
// GetAccounts returns all the accounts in the db. Use for debugging pruposes // GetAccounts returns all the accounts in the db. Use for debugging pruposes
// only. // only.
func (s *StateDB) GetAccounts() ([]common.Account, error) { func (s *StateDB) GetAccounts() ([]common.Account, error) {
idxDB := s.db.WithPrefix(PrefixKeyIdx)
idxDB := s.db.StorageWithPrefix(PrefixKeyIdx)
idxs := []common.Idx{} idxs := []common.Idx{}
// NOTE: Current implementation of Iterate in the pebble interface is // NOTE: Current implementation of Iterate in the pebble interface is
// not efficient, as it iterates over all keys. Improve it following // not efficient, as it iterates over all keys. Improve it following
@ -477,7 +214,7 @@ func GetAccountInTreeDB(sto db.Storage, idx common.Idx) (*common.Account, error)
// StateDB.MT==nil, MerkleTree is not affected, otherwise updates the // StateDB.MT==nil, MerkleTree is not affected, otherwise updates the
// MerkleTree, returning a CircomProcessorProof. // MerkleTree, returning a CircomProcessorProof.
func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) { func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) {
cpp, err := CreateAccountInTreeDB(s.db, s.MT, idx, account)
cpp, err := CreateAccountInTreeDB(s.db.DB(), s.MT, idx, account)
if err != nil { if err != nil {
return cpp, tracerr.Wrap(err) return cpp, tracerr.Wrap(err)
} }
@ -540,7 +277,7 @@ func CreateAccountInTreeDB(sto db.Storage, mt *merkletree.MerkleTree, idx common
// StateDB.mt==nil, MerkleTree is not affected, otherwise updates the // StateDB.mt==nil, MerkleTree is not affected, otherwise updates the
// MerkleTree, returning a CircomProcessorProof. // MerkleTree, returning a CircomProcessorProof.
func (s *StateDB) UpdateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) { func (s *StateDB) UpdateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) {
return UpdateAccountInTreeDB(s.db, s.MT, idx, account)
return UpdateAccountInTreeDB(s.db.DB(), s.MT, idx, account)
} }
// UpdateAccountInTreeDB is abstracted from StateDB to be used from StateDB and // UpdateAccountInTreeDB is abstracted from StateDB to be used from StateDB and
@ -623,68 +360,24 @@ func NewLocalStateDB(path string, keep int, synchronizerDB *StateDB, typ TypeSta
} }
// Reset performs a reset in the LocaStateDB. If fromSynchronizer is true, it // Reset performs a reset in the LocaStateDB. If fromSynchronizer is true, it
// gets the state from LocalStateDB.synchronizerStateDB for the given batchNum. If fromSynchronizer is false, get the state from LocalStateDB checkpoints.
// gets the state from LocalStateDB.synchronizerStateDB for the given batchNum.
// If fromSynchronizer is false, get the state from LocalStateDB checkpoints.
func (l *LocalStateDB) Reset(batchNum common.BatchNum, fromSynchronizer bool) error { func (l *LocalStateDB) Reset(batchNum common.BatchNum, fromSynchronizer bool) error {
if batchNum == 0 {
l.CurrentIdx = 0
return nil
}
synchronizerCheckpointPath := path.Join(l.synchronizerStateDB.path,
fmt.Sprintf("%s%d", PathBatchNum, batchNum))
checkpointPath := path.Join(l.path, fmt.Sprintf("%s%d", PathBatchNum, batchNum))
currentPath := path.Join(l.path, PathCurrent)
if fromSynchronizer { if fromSynchronizer {
// use checkpoint from SynchronizerStateDB
if _, err := os.Stat(synchronizerCheckpointPath); os.IsNotExist(err) {
// if synchronizerStateDB does not have checkpoint at batchNum, return err
return tracerr.Wrap(fmt.Errorf("Checkpoint \"%v\" not exist in Synchronizer",
synchronizerCheckpointPath))
}
if err := l.db.Pebble().Close(); err != nil {
return tracerr.Wrap(err)
}
// remove 'current'
err := os.RemoveAll(currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// copy synchronizer'BatchNumX' to 'current'
err = pebbleMakeCheckpoint(synchronizerCheckpointPath, currentPath)
if err != nil {
return tracerr.Wrap(err)
}
// copy synchronizer'BatchNumX' to 'BatchNumX'
err = pebbleMakeCheckpoint(synchronizerCheckpointPath, checkpointPath)
if err != nil {
return tracerr.Wrap(err)
}
// open the new 'current'
sto, err := pebble.NewPebbleStorage(currentPath, false)
if err != nil {
return tracerr.Wrap(err)
}
l.db = sto
// get currentBatch num
l.CurrentBatch, err = l.GetCurrentBatch()
err := l.db.ResetFromSynchronizer(batchNum, l.synchronizerStateDB.db)
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
// open the MT for the current s.db // open the MT for the current s.db
if l.MT != nil { if l.MT != nil {
mt, err := merkletree.NewMerkleTree(l.db.WithPrefix(PrefixKeyMT), l.MT.MaxLevels())
mt, err := merkletree.NewMerkleTree(l.db.StorageWithPrefix(PrefixKeyMT), l.MT.MaxLevels())
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
l.MT = mt l.MT = mt
} }
return nil return nil
} }
// use checkpoint from LocalStateDB // use checkpoint from LocalStateDB
return l.StateDB.reset(batchNum, true)
return l.StateDB.Reset(batchNum)
} }

+ 40
- 34
db/statedb/statedb_test.go

@ -55,13 +55,13 @@ func TestNewStateDBIntermediateState(t *testing.T) {
v1 := []byte("testvalue1") v1 := []byte("testvalue1")
// store some data // store some data
tx, err := sdb.db.NewTx()
tx, err := sdb.db.DB().NewTx()
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Put(k0, v0) err = tx.Put(k0, v0)
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Commit() err = tx.Commit()
assert.NoError(t, err) assert.NoError(t, err)
v, err := sdb.db.Get(k0)
v, err := sdb.db.DB().Get(k0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, v0, v) assert.Equal(t, v0, v)
@ -69,41 +69,41 @@ func TestNewStateDBIntermediateState(t *testing.T) {
// executing a Reset (discarding the last 'testkey0'&'testvalue0' data) // executing a Reset (discarding the last 'testkey0'&'testvalue0' data)
sdb, err = NewStateDB(dir, 128, TypeTxSelector, 0) sdb, err = NewStateDB(dir, 128, TypeTxSelector, 0)
assert.NoError(t, err) assert.NoError(t, err)
v, err = sdb.db.Get(k0)
v, err = sdb.db.DB().Get(k0)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err))
assert.Nil(t, v) assert.Nil(t, v)
// store the same data from the beginning that has ben lost since last NewStateDB // store the same data from the beginning that has ben lost since last NewStateDB
tx, err = sdb.db.NewTx()
tx, err = sdb.db.DB().NewTx()
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Put(k0, v0) err = tx.Put(k0, v0)
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Commit() err = tx.Commit()
assert.NoError(t, err) assert.NoError(t, err)
v, err = sdb.db.Get(k0)
v, err = sdb.db.DB().Get(k0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, v0, v) assert.Equal(t, v0, v)
// make checkpoints with the current state // make checkpoints with the current state
bn, err := sdb.GetCurrentBatch()
bn, err := sdb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(0), bn) assert.Equal(t, common.BatchNum(0), bn)
err = sdb.MakeCheckpoint()
err = sdb.db.MakeCheckpoint()
assert.NoError(t, err) assert.NoError(t, err)
bn, err = sdb.GetCurrentBatch()
bn, err = sdb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(1), bn) assert.Equal(t, common.BatchNum(1), bn)
// write more data // write more data
tx, err = sdb.db.NewTx()
tx, err = sdb.db.DB().NewTx()
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Put(k1, v1) err = tx.Put(k1, v1)
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Commit() err = tx.Commit()
assert.NoError(t, err) assert.NoError(t, err)
v, err = sdb.db.Get(k1)
v, err = sdb.db.DB().Get(k1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, v1, v) assert.Equal(t, v1, v)
@ -112,11 +112,11 @@ func TestNewStateDBIntermediateState(t *testing.T) {
sdb, err = NewStateDB(dir, 128, TypeTxSelector, 0) sdb, err = NewStateDB(dir, 128, TypeTxSelector, 0)
assert.NoError(t, err) assert.NoError(t, err)
v, err = sdb.db.Get(k0)
v, err = sdb.db.DB().Get(k0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, v0, v) assert.Equal(t, v0, v)
v, err = sdb.db.Get(k1)
v, err = sdb.db.DB().Get(k1)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err)) assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err))
assert.Nil(t, v) assert.Nil(t, v)
@ -228,6 +228,8 @@ func TestStateDBWithMT(t *testing.T) {
assert.Equal(t, accounts[0].Nonce, a.Nonce) assert.Equal(t, accounts[0].Nonce, a.Nonce)
} }
// TestCheckpoints performs almost the same test than kvdb/kvdb_test.go
// TestCheckpoints, but over the StateDB
func TestCheckpoints(t *testing.T) { func TestCheckpoints(t *testing.T) {
dir, err := ioutil.TempDir("", "sdb") dir, err := ioutil.TempDir("", "sdb")
require.NoError(t, err) require.NoError(t, err)
@ -249,17 +251,17 @@ func TestCheckpoints(t *testing.T) {
} }
// do checkpoints and check that currentBatch is correct // do checkpoints and check that currentBatch is correct
err = sdb.MakeCheckpoint()
err = sdb.db.MakeCheckpoint()
assert.NoError(t, err) assert.NoError(t, err)
cb, err := sdb.GetCurrentBatch()
cb, err := sdb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(1), cb) assert.Equal(t, common.BatchNum(1), cb)
for i := 1; i < 10; i++ { for i := 1; i < 10; i++ {
err = sdb.MakeCheckpoint()
err = sdb.db.MakeCheckpoint()
assert.NoError(t, err) assert.NoError(t, err)
cb, err = sdb.GetCurrentBatch()
cb, err = sdb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(i+1), cb) assert.Equal(t, common.BatchNum(i+1), cb)
} }
@ -276,24 +278,24 @@ func TestCheckpoints(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// check that currentBatch is as expected after Reset // check that currentBatch is as expected after Reset
cb, err = sdb.GetCurrentBatch()
cb, err = sdb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(3), cb) assert.Equal(t, common.BatchNum(3), cb)
// advance one checkpoint and check that currentBatch is fine // advance one checkpoint and check that currentBatch is fine
err = sdb.MakeCheckpoint()
err = sdb.db.MakeCheckpoint()
assert.NoError(t, err) assert.NoError(t, err)
cb, err = sdb.GetCurrentBatch()
cb, err = sdb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(4), cb) assert.Equal(t, common.BatchNum(4), cb)
err = sdb.DeleteCheckpoint(common.BatchNum(1))
err = sdb.db.DeleteCheckpoint(common.BatchNum(1))
assert.NoError(t, err) assert.NoError(t, err)
err = sdb.DeleteCheckpoint(common.BatchNum(2))
err = sdb.db.DeleteCheckpoint(common.BatchNum(2))
assert.NoError(t, err) assert.NoError(t, err)
err = sdb.DeleteCheckpoint(common.BatchNum(1)) // does not exist, should return err
err = sdb.db.DeleteCheckpoint(common.BatchNum(1)) // does not exist, should return err
assert.NotNil(t, err) assert.NotNil(t, err)
err = sdb.DeleteCheckpoint(common.BatchNum(2)) // does not exist, should return err
err = sdb.db.DeleteCheckpoint(common.BatchNum(2)) // does not exist, should return err
assert.NotNil(t, err) assert.NotNil(t, err)
// Create a LocalStateDB from the initial StateDB // Create a LocalStateDB from the initial StateDB
@ -307,13 +309,13 @@ func TestCheckpoints(t *testing.T) {
err = ldb.Reset(4, true) err = ldb.Reset(4, true)
assert.NoError(t, err) assert.NoError(t, err)
// check that currentBatch is 4 after the Reset // check that currentBatch is 4 after the Reset
cb, err = ldb.GetCurrentBatch()
cb, err = ldb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(4), cb) assert.Equal(t, common.BatchNum(4), cb)
// advance one checkpoint in ldb // advance one checkpoint in ldb
err = ldb.MakeCheckpoint()
err = ldb.db.MakeCheckpoint()
assert.NoError(t, err) assert.NoError(t, err)
cb, err = ldb.GetCurrentBatch()
cb, err = ldb.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(5), cb) assert.Equal(t, common.BatchNum(5), cb)
@ -328,13 +330,13 @@ func TestCheckpoints(t *testing.T) {
err = ldb2.Reset(4, true) err = ldb2.Reset(4, true)
assert.NoError(t, err) assert.NoError(t, err)
// check that currentBatch is 4 after the Reset // check that currentBatch is 4 after the Reset
cb, err = ldb2.GetCurrentBatch()
cb, err = ldb2.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(4), cb) assert.Equal(t, common.BatchNum(4), cb)
// advance one checkpoint in ldb2 // advance one checkpoint in ldb2
err = ldb2.MakeCheckpoint()
err = ldb2.db.MakeCheckpoint()
assert.NoError(t, err) assert.NoError(t, err)
cb, err = ldb2.GetCurrentBatch()
cb, err = ldb2.db.GetCurrentBatch()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.BatchNum(5), cb) assert.Equal(t, common.BatchNum(5), cb)
@ -464,6 +466,8 @@ func TestCheckAccountsTreeTestVectors(t *testing.T) {
assert.Equal(t, "17298264051379321456969039521810887093935433569451713402227686942080129181291", sdb.MT.Root().BigInt().String()) assert.Equal(t, "17298264051379321456969039521810887093935433569451713402227686942080129181291", sdb.MT.Root().BigInt().String())
} }
// TestListCheckpoints performs almost the same test than kvdb/kvdb_test.go
// TestListCheckpoints, but over the StateDB
func TestListCheckpoints(t *testing.T) { func TestListCheckpoints(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb") dir, err := ioutil.TempDir("", "tmpdb")
require.NoError(t, err) require.NoError(t, err)
@ -475,10 +479,10 @@ func TestListCheckpoints(t *testing.T) {
numCheckpoints := 16 numCheckpoints := 16
// do checkpoints // do checkpoints
for i := 0; i < numCheckpoints; i++ { for i := 0; i < numCheckpoints; i++ {
err = sdb.MakeCheckpoint()
err = sdb.db.MakeCheckpoint()
require.NoError(t, err) require.NoError(t, err)
} }
list, err := sdb.listCheckpoints()
list, err := sdb.db.ListCheckpoints()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, numCheckpoints, len(list)) assert.Equal(t, numCheckpoints, len(list))
assert.Equal(t, 1, list[0]) assert.Equal(t, 1, list[0])
@ -487,13 +491,15 @@ func TestListCheckpoints(t *testing.T) {
numReset := 10 numReset := 10
err = sdb.Reset(common.BatchNum(numReset)) err = sdb.Reset(common.BatchNum(numReset))
require.NoError(t, err) require.NoError(t, err)
list, err = sdb.listCheckpoints()
list, err = sdb.db.ListCheckpoints()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, numReset, len(list)) assert.Equal(t, numReset, len(list))
assert.Equal(t, 1, list[0]) assert.Equal(t, 1, list[0])
assert.Equal(t, numReset, list[len(list)-1]) assert.Equal(t, numReset, list[len(list)-1])
} }
// TestDeleteOldCheckpoints performs almost the same test than
// kvdb/kvdb_test.go TestDeleteOldCheckpoints, but over the StateDB
func TestDeleteOldCheckpoints(t *testing.T) { func TestDeleteOldCheckpoints(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb") dir, err := ioutil.TempDir("", "tmpdb")
require.NoError(t, err) require.NoError(t, err)
@ -507,9 +513,9 @@ func TestDeleteOldCheckpoints(t *testing.T) {
// do checkpoints and check that we never have more than `keep` // do checkpoints and check that we never have more than `keep`
// checkpoints // checkpoints
for i := 0; i < numCheckpoints; i++ { for i := 0; i < numCheckpoints; i++ {
err = sdb.MakeCheckpoint()
err = sdb.db.MakeCheckpoint()
require.NoError(t, err) require.NoError(t, err)
checkpoints, err := sdb.listCheckpoints()
checkpoints, err := sdb.db.ListCheckpoints()
require.NoError(t, err) require.NoError(t, err)
assert.LessOrEqual(t, len(checkpoints), keep) assert.LessOrEqual(t, len(checkpoints), keep)
} }

+ 3
- 3
db/statedb/utils.go

@ -48,7 +48,7 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk
// have an Idx stored in the DB, and if so, the already stored Idx is // have an Idx stored in the DB, and if so, the already stored Idx is
// bigger than the given one, so should be updated to the new one // bigger than the given one, so should be updated to the new one
// (smaller) // (smaller)
tx, err := s.db.NewTx()
tx, err := s.db.DB().NewTx()
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
@ -81,7 +81,7 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk
// not found in the StateDB. // not found in the StateDB.
func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address, tokenID common.TokenID) (common.Idx, error) { func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address, tokenID common.TokenID) (common.Idx, error) {
k := concatEthAddrTokenID(addr, tokenID) k := concatEthAddrTokenID(addr, tokenID)
b, err := s.db.Get(append(PrefixKeyAddr, k...))
b, err := s.db.DB().Get(append(PrefixKeyAddr, k...))
if err != nil { if err != nil {
return common.Idx(0), tracerr.Wrap(fmt.Errorf("GetIdxByEthAddr: %s: ToEthAddr: %s, TokenID: %d", return common.Idx(0), tracerr.Wrap(fmt.Errorf("GetIdxByEthAddr: %s: ToEthAddr: %s, TokenID: %d",
ErrToIdxNotFound, addr.Hex(), tokenID)) ErrToIdxNotFound, addr.Hex(), tokenID))
@ -107,7 +107,7 @@ func (s *StateDB) GetIdxByEthAddrBJJ(addr ethCommon.Address, pk babyjub.PublicKe
} else if !bytes.Equal(addr.Bytes(), common.EmptyAddr.Bytes()) && pk != common.EmptyBJJComp { } else if !bytes.Equal(addr.Bytes(), common.EmptyAddr.Bytes()) && pk != common.EmptyBJJComp {
// case ToEthAddr!=0 && ToBJJ!=0 // case ToEthAddr!=0 && ToBJJ!=0
k := concatEthAddrBJJTokenID(addr, pk, tokenID) k := concatEthAddrBJJTokenID(addr, pk, tokenID)
b, err := s.db.Get(append(PrefixKeyAddrBJJ, k...))
b, err := s.db.DB().Get(append(PrefixKeyAddrBJJ, k...))
if err != nil { if err != nil {
return common.Idx(0), tracerr.Wrap(fmt.Errorf("GetIdxByEthAddrBJJ: %s: ToEthAddr: %s, ToBJJ: %s, TokenID: %d", ErrToIdxNotFound, addr.Hex(), pk, tokenID)) return common.Idx(0), tracerr.Wrap(fmt.Errorf("GetIdxByEthAddrBJJ: %s: ToEthAddr: %s, ToBJJ: %s, TokenID: %d", ErrToIdxNotFound, addr.Hex(), pk, tokenID))
} }

+ 16
- 16
txprocessor/txprocessor.go

@ -114,10 +114,10 @@ func (tp *TxProcessor) ProcessTxs(coordIdxs []common.Idx, l1usertxs, l1coordinat
if tp.s.Typ == statedb.TypeBatchBuilder { if tp.s.Typ == statedb.TypeBatchBuilder {
tp.zki = common.NewZKInputs(tp.config.ChainID, tp.config.MaxTx, tp.config.MaxL1Tx, tp.zki = common.NewZKInputs(tp.config.ChainID, tp.config.MaxTx, tp.config.MaxL1Tx,
tp.config.MaxFeeTx, tp.config.NLevels, tp.s.CurrentBatch.BigInt())
tp.zki.OldLastIdx = tp.s.CurrentIdx.BigInt()
tp.config.MaxFeeTx, tp.config.NLevels, tp.s.CurrentBatch().BigInt())
tp.zki.OldLastIdx = tp.s.CurrentIdx().BigInt()
tp.zki.OldStateRoot = tp.s.MT.Root().BigInt() tp.zki.OldStateRoot = tp.s.MT.Root().BigInt()
tp.zki.Metadata.NewLastIdxRaw = tp.s.CurrentIdx
tp.zki.Metadata.NewLastIdxRaw = tp.s.CurrentIdx()
} }
// TBD if ExitTree is only in memory or stored in disk, for the moment // TBD if ExitTree is only in memory or stored in disk, for the moment
@ -169,7 +169,7 @@ func (tp *TxProcessor) ProcessTxs(coordIdxs []common.Idx, l1usertxs, l1coordinat
tp.zki.Metadata.L1TxsDataAvailability = tp.zki.Metadata.L1TxsDataAvailability =
append(tp.zki.Metadata.L1TxsDataAvailability, l1TxDataAvailability) append(tp.zki.Metadata.L1TxsDataAvailability, l1TxDataAvailability)
tp.zki.ISOutIdx[tp.i] = tp.s.CurrentIdx.BigInt()
tp.zki.ISOutIdx[tp.i] = tp.s.CurrentIdx().BigInt()
tp.zki.ISStateRoot[tp.i] = tp.s.MT.Root().BigInt() tp.zki.ISStateRoot[tp.i] = tp.s.MT.Root().BigInt()
if exitIdx == nil { if exitIdx == nil {
tp.zki.ISExitRoot[tp.i] = exitTree.Root().BigInt() tp.zki.ISExitRoot[tp.i] = exitTree.Root().BigInt()
@ -214,7 +214,7 @@ func (tp *TxProcessor) ProcessTxs(coordIdxs []common.Idx, l1usertxs, l1coordinat
tp.zki.Metadata.L1TxsDataAvailability = tp.zki.Metadata.L1TxsDataAvailability =
append(tp.zki.Metadata.L1TxsDataAvailability, l1TxDataAvailability) append(tp.zki.Metadata.L1TxsDataAvailability, l1TxDataAvailability)
tp.zki.ISOutIdx[tp.i] = tp.s.CurrentIdx.BigInt()
tp.zki.ISOutIdx[tp.i] = tp.s.CurrentIdx().BigInt()
tp.zki.ISStateRoot[tp.i] = tp.s.MT.Root().BigInt() tp.zki.ISStateRoot[tp.i] = tp.s.MT.Root().BigInt()
tp.i++ tp.i++
} }
@ -268,7 +268,7 @@ func (tp *TxProcessor) ProcessTxs(coordIdxs []common.Idx, l1usertxs, l1coordinat
// Intermediate States // Intermediate States
if tp.i < nTx-1 { if tp.i < nTx-1 {
tp.zki.ISOutIdx[tp.i] = tp.s.CurrentIdx.BigInt()
tp.zki.ISOutIdx[tp.i] = tp.s.CurrentIdx().BigInt()
tp.zki.ISStateRoot[tp.i] = tp.s.MT.Root().BigInt() tp.zki.ISStateRoot[tp.i] = tp.s.MT.Root().BigInt()
tp.zki.ISAccFeeOut[tp.i] = formatAccumulatedFees(collectedFees, tp.zki.FeePlanTokens) tp.zki.ISAccFeeOut[tp.i] = formatAccumulatedFees(collectedFees, tp.zki.FeePlanTokens)
if exitIdx == nil { if exitIdx == nil {
@ -296,7 +296,7 @@ func (tp *TxProcessor) ProcessTxs(coordIdxs []common.Idx, l1usertxs, l1coordinat
} }
for i := last; i < int(tp.config.MaxTx); i++ { for i := last; i < int(tp.config.MaxTx); i++ {
if i < int(tp.config.MaxTx)-1 { if i < int(tp.config.MaxTx)-1 {
tp.zki.ISOutIdx[i] = tp.s.CurrentIdx.BigInt()
tp.zki.ISOutIdx[i] = tp.s.CurrentIdx().BigInt()
tp.zki.ISStateRoot[i] = tp.s.MT.Root().BigInt() tp.zki.ISStateRoot[i] = tp.s.MT.Root().BigInt()
tp.zki.ISAccFeeOut[i] = formatAccumulatedFees(collectedFees, tp.zki.FeePlanTokens) tp.zki.ISAccFeeOut[i] = formatAccumulatedFees(collectedFees, tp.zki.FeePlanTokens)
tp.zki.ISExitRoot[i] = exitTree.Root().BigInt() tp.zki.ISExitRoot[i] = exitTree.Root().BigInt()
@ -541,7 +541,7 @@ func (tp *TxProcessor) ProcessL1Tx(exitTree *merkletree.MerkleTree, tx *common.L
(tx.Type == common.TxTypeCreateAccountDeposit || (tx.Type == common.TxTypeCreateAccountDeposit ||
tx.Type == common.TxTypeCreateAccountDepositTransfer) { tx.Type == common.TxTypeCreateAccountDepositTransfer) {
var err error var err error
createdAccount, err = tp.s.GetAccount(tp.s.CurrentIdx)
createdAccount, err = tp.s.GetAccount(tp.s.CurrentIdx())
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return nil, nil, false, nil, tracerr.Wrap(err) return nil, nil, false, nil, tracerr.Wrap(err)
@ -664,7 +664,7 @@ func (tp *TxProcessor) applyCreateAccount(tx *common.L1Tx) error {
EthAddr: tx.FromEthAddr, EthAddr: tx.FromEthAddr,
} }
p, err := tp.s.CreateAccount(common.Idx(tp.s.CurrentIdx+1), account)
p, err := tp.s.CreateAccount(common.Idx(tp.s.CurrentIdx()+1), account)
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
@ -685,9 +685,9 @@ func (tp *TxProcessor) applyCreateAccount(tx *common.L1Tx) error {
tp.zki.OldKey1[tp.i] = p.OldKey.BigInt() tp.zki.OldKey1[tp.i] = p.OldKey.BigInt()
tp.zki.OldValue1[tp.i] = p.OldValue.BigInt() tp.zki.OldValue1[tp.i] = p.OldValue.BigInt()
tp.zki.Metadata.NewLastIdxRaw = tp.s.CurrentIdx + 1
tp.zki.Metadata.NewLastIdxRaw = tp.s.CurrentIdx() + 1
tp.zki.AuxFromIdx[tp.i] = common.Idx(tp.s.CurrentIdx + 1).BigInt()
tp.zki.AuxFromIdx[tp.i] = common.Idx(tp.s.CurrentIdx() + 1).BigInt()
tp.zki.NewAccount[tp.i] = big.NewInt(1) tp.zki.NewAccount[tp.i] = big.NewInt(1)
if tp.i < len(tp.zki.ISOnChain) { // len(tp.zki.ISOnChain) == nTx if tp.i < len(tp.zki.ISOnChain) { // len(tp.zki.ISOnChain) == nTx
@ -696,7 +696,7 @@ func (tp *TxProcessor) applyCreateAccount(tx *common.L1Tx) error {
} }
} }
return tp.s.SetIdx(tp.s.CurrentIdx + 1)
return tp.s.SetCurrentIdx(tp.s.CurrentIdx() + 1)
} }
// applyDeposit updates the balance in the account of the depositer, if // applyDeposit updates the balance in the account of the depositer, if
@ -894,7 +894,7 @@ func (tp *TxProcessor) applyTransfer(coordIdxsMap map[common.TokenID]common.Idx,
// applyCreateAccountDepositTransfer, in a single tx, creates a new account, // applyCreateAccountDepositTransfer, in a single tx, creates a new account,
// makes a deposit, and performs a transfer to another account // makes a deposit, and performs a transfer to another account
func (tp *TxProcessor) applyCreateAccountDepositTransfer(tx *common.L1Tx) error { func (tp *TxProcessor) applyCreateAccountDepositTransfer(tx *common.L1Tx) error {
auxFromIdx := common.Idx(tp.s.CurrentIdx + 1)
auxFromIdx := common.Idx(tp.s.CurrentIdx() + 1)
accSender := &common.Account{ accSender := &common.Account{
TokenID: tx.TokenID, TokenID: tx.TokenID,
Nonce: 0, Nonce: 0,
@ -920,7 +920,7 @@ func (tp *TxProcessor) applyCreateAccountDepositTransfer(tx *common.L1Tx) error
accSender.Balance = new(big.Int).Sub(accSender.Balance, tx.EffectiveAmount) accSender.Balance = new(big.Int).Sub(accSender.Balance, tx.EffectiveAmount)
// create Account of the Sender // create Account of the Sender
p, err := tp.s.CreateAccount(common.Idx(tp.s.CurrentIdx+1), accSender)
p, err := tp.s.CreateAccount(common.Idx(tp.s.CurrentIdx()+1), accSender)
if err != nil { if err != nil {
return tracerr.Wrap(err) return tracerr.Wrap(err)
} }
@ -932,7 +932,7 @@ func (tp *TxProcessor) applyCreateAccountDepositTransfer(tx *common.L1Tx) error
tp.zki.OldKey1[tp.i] = p.OldKey.BigInt() tp.zki.OldKey1[tp.i] = p.OldKey.BigInt()
tp.zki.OldValue1[tp.i] = p.OldValue.BigInt() tp.zki.OldValue1[tp.i] = p.OldValue.BigInt()
tp.zki.Metadata.NewLastIdxRaw = tp.s.CurrentIdx + 1
tp.zki.Metadata.NewLastIdxRaw = tp.s.CurrentIdx() + 1
tp.zki.AuxFromIdx[tp.i] = auxFromIdx.BigInt() tp.zki.AuxFromIdx[tp.i] = auxFromIdx.BigInt()
tp.zki.NewAccount[tp.i] = big.NewInt(1) tp.zki.NewAccount[tp.i] = big.NewInt(1)
@ -976,7 +976,7 @@ func (tp *TxProcessor) applyCreateAccountDepositTransfer(tx *common.L1Tx) error
tp.zki.Siblings2[tp.i] = siblingsToZKInputFormat(p.Siblings) tp.zki.Siblings2[tp.i] = siblingsToZKInputFormat(p.Siblings)
} }
return tp.s.SetIdx(tp.s.CurrentIdx + 1)
return tp.s.SetCurrentIdx(tp.s.CurrentIdx() + 1)
} }
// It returns the ExitAccount and a boolean determining if the Exit created a // It returns the ExitAccount and a boolean determining if the Exit created a

Loading…
Cancel
Save