Browse Source

Merge pull request #74 from hermeznetwork/feature/statedb-processtxs

Add abstraction method of processTxs to StateDB
feature/sql-semaphore1
a_bennassar 4 years ago
committed by GitHub
parent
commit
88906beabe
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 149 additions and 85 deletions
  1. +1
    -1
      .github/workflows/lint.yml
  2. +3
    -21
      batchbuilder/batchbuilder.go
  3. +8
    -6
      common/account.go
  4. +6
    -6
      common/account_test.go
  5. +18
    -0
      common/batch.go
  6. +0
    -1
      common/l1tx.go
  7. +12
    -0
      common/l2tx.go
  8. +24
    -3
      common/pooll2tx.go
  9. +2
    -6
      common/pooll2tx_test.go
  10. +10
    -1
      common/tx.go
  11. +4
    -0
      common/zk.go
  12. +8
    -7
      coordinator/coordinator.go
  13. +5
    -9
      db/statedb/statedb.go
  14. +9
    -9
      db/statedb/statedb_test.go
  15. +39
    -15
      db/statedb/txprocessors.go

+ 1
- 1
.github/workflows/lint.yml

@ -13,4 +13,4 @@ jobs:
- name: Lint - name: Lint
run: | run: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.24.0 curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.24.0
$(go env GOPATH)/bin/golangci-lint run
$(go env GOPATH)/bin/golangci-lint run --timeout=5m

+ 3
- 21
batchbuilder/batchbuilder.go

@ -51,25 +51,7 @@ func (bb *BatchBuilder) Reset(batchNum uint64, fromSynchronizer bool) error {
} }
// BuildBatch takes the transactions and returns the common.ZKInputs of the next batch // BuildBatch takes the transactions and returns the common.ZKInputs of the next batch
func (bb *BatchBuilder) BuildBatch(configBatch *ConfigBatch, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.PoolL2Tx, tokenIDs []common.TokenID) (*common.ZKInputs, error) {
for _, tx := range l1usertxs {
err := bb.localStateDB.ProcessL1Tx(tx)
if err != nil {
return nil, err
}
}
for _, tx := range l1coordinatortxs {
err := bb.localStateDB.ProcessL1Tx(tx)
if err != nil {
return nil, err
}
}
for _, tx := range l2txs {
err := bb.localStateDB.ProcessPoolL2Tx(tx)
if err != nil {
return nil, err
}
}
return nil, nil
func (bb *BatchBuilder) BuildBatch(configBatch *ConfigBatch, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.L2Tx, tokenIDs []common.TokenID) (*common.ZKInputs, error) {
zkInputs, _, err := bb.localStateDB.ProcessTxs(l1usertxs, l1coordinatortxs, l2txs)
return zkInputs, err
} }

+ 8
- 6
common/account.go

@ -17,7 +17,7 @@ const NLEAFELEMS = 4
// Account is a struct that gives information of the holdings of an address and a specific token. Is the data structure that generates the Value stored in the leaf of the MerkleTree // Account is a struct that gives information of the holdings of an address and a specific token. Is the data structure that generates the Value stored in the leaf of the MerkleTree
type Account struct { type Account struct {
TokenID TokenID TokenID TokenID
Nonce uint64 // max of 40 bits used
Nonce Nonce // max of 40 bits used
Balance *big.Int // max of 192 bits used Balance *big.Int // max of 192 bits used
PublicKey *babyjub.PublicKey PublicKey *babyjub.PublicKey
EthAddr ethCommon.Address EthAddr ethCommon.Address
@ -44,8 +44,10 @@ func (a *Account) Bytes() ([32 * NLEAFELEMS]byte, error) {
return b, fmt.Errorf("%s Balance", ErrNumOverflow) return b, fmt.Errorf("%s Balance", ErrNumOverflow)
} }
var nonceBytes [8]byte
binary.LittleEndian.PutUint64(nonceBytes[:], a.Nonce)
nonceBytes, err := a.Nonce.Bytes()
if err != nil {
return b, err
}
copy(b[0:4], a.TokenID.Bytes()) copy(b[0:4], a.TokenID.Bytes())
copy(b[4:9], nonceBytes[:]) copy(b[4:9], nonceBytes[:])
@ -107,9 +109,9 @@ func AccountFromBigInts(e [NLEAFELEMS]*big.Int) (*Account, error) {
// AccountFromBytes returns a Account from a byte array // AccountFromBytes returns a Account from a byte array
func AccountFromBytes(b [32 * NLEAFELEMS]byte) (*Account, error) { func AccountFromBytes(b [32 * NLEAFELEMS]byte) (*Account, error) {
tokenID := binary.LittleEndian.Uint32(b[0:4]) tokenID := binary.LittleEndian.Uint32(b[0:4])
var nonceBytes [8]byte
copy(nonceBytes[:], b[4:9])
nonce := binary.LittleEndian.Uint64(nonceBytes[:])
var nonceBytes5 [5]byte
copy(nonceBytes5[:], b[4:9])
nonce := NonceFromBytes(nonceBytes5)
sign := b[10] == 1 sign := b[10] == 1
balance := new(big.Int).SetBytes(SwapEndianness(b[32:56])) // b[32:56], as Balance is 192 bits (24 bytes) balance := new(big.Int).SetBytes(SwapEndianness(b[32:56])) // b[32:56], as Balance is 192 bits (24 bytes)
if !bytes.Equal(b[56:64], []byte{0, 0, 0, 0, 0, 0, 0, 0}) { if !bytes.Equal(b[56:64], []byte{0, 0, 0, 0, 0, 0, 0, 0}) {

+ 6
- 6
common/account_test.go

@ -23,7 +23,7 @@ func TestAccount(t *testing.T) {
account := &Account{ account := &Account{
TokenID: TokenID(1), TokenID: TokenID(1),
Nonce: uint64(1234),
Nonce: Nonce(1234),
Balance: big.NewInt(1000), Balance: big.NewInt(1000),
PublicKey: pk, PublicKey: pk,
EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"), EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"),
@ -66,7 +66,7 @@ func TestAccountLoop(t *testing.T) {
account := &Account{ account := &Account{
TokenID: TokenID(i), TokenID: TokenID(i),
Nonce: uint64(i),
Nonce: Nonce(i),
Balance: big.NewInt(1000), Balance: big.NewInt(1000),
PublicKey: pk, PublicKey: pk,
EthAddr: address, EthAddr: address,
@ -98,7 +98,7 @@ func TestAccountHashValue(t *testing.T) {
account := &Account{ account := &Account{
TokenID: TokenID(1), TokenID: TokenID(1),
Nonce: uint64(1234),
Nonce: Nonce(1234),
Balance: big.NewInt(1000), Balance: big.NewInt(1000),
PublicKey: pk, PublicKey: pk,
EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"), EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"),
@ -142,7 +142,7 @@ func TestAccountErrNumOverflowNonce(t *testing.T) {
// check limit // check limit
account := &Account{ account := &Account{
TokenID: TokenID(1), TokenID: TokenID(1),
Nonce: uint64(math.Pow(2, 40) - 1),
Nonce: Nonce(math.Pow(2, 40) - 1),
Balance: big.NewInt(1000), Balance: big.NewInt(1000),
PublicKey: pk, PublicKey: pk,
EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"), EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"),
@ -151,7 +151,7 @@ func TestAccountErrNumOverflowNonce(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
// force value overflow // force value overflow
account.Nonce = uint64(math.Pow(2, 40))
account.Nonce = Nonce(math.Pow(2, 40))
b, err := account.Bytes() b, err := account.Bytes()
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, fmt.Errorf("%s Nonce", ErrNumOverflow), err) assert.Equal(t, fmt.Errorf("%s Nonce", ErrNumOverflow), err)
@ -169,7 +169,7 @@ func TestAccountErrNumOverflowBalance(t *testing.T) {
// check limit // check limit
account := &Account{ account := &Account{
TokenID: TokenID(1), TokenID: TokenID(1),
Nonce: uint64(math.Pow(2, 40) - 1),
Nonce: Nonce(math.Pow(2, 40) - 1),
Balance: new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil), big.NewInt(1)), Balance: new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(192), nil), big.NewInt(1)),
PublicKey: pk, PublicKey: pk,
EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"), EthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"),

+ 18
- 0
common/batch.go

@ -1,6 +1,8 @@
package common package common
import ( import (
"encoding/binary"
"fmt"
"math/big" "math/big"
ethCommon "github.com/ethereum/go-ethereum/common" ethCommon "github.com/ethereum/go-ethereum/common"
@ -26,3 +28,19 @@ type Batch struct {
// BatchNum identifies a batch // BatchNum identifies a batch
type BatchNum uint32 type BatchNum uint32
// Bytes returns a byte array of length 4 representing the BatchNum
func (bn BatchNum) Bytes() []byte {
var batchNumBytes [4]byte
binary.LittleEndian.PutUint32(batchNumBytes[:], uint32(bn))
return batchNumBytes[:]
}
// BatchNumFromBytes returns BatchNum from a []byte
func BatchNumFromBytes(b []byte) (BatchNum, error) {
if len(b) != 4 {
return 0, fmt.Errorf("can not parse BatchNumFromBytes, bytes len %d, expected 4", len(b))
}
batchNum := binary.LittleEndian.Uint32(b[:4])
return BatchNum(batchNum), nil
}

+ 0
- 1
common/l1tx.go

@ -31,7 +31,6 @@ func (tx *L1Tx) Tx() *Tx {
TxID: tx.TxID, TxID: tx.TxID,
FromIdx: tx.FromIdx, FromIdx: tx.FromIdx,
ToIdx: tx.ToIdx, ToIdx: tx.ToIdx,
TokenID: tx.TokenID,
Amount: tx.Amount, Amount: tx.Amount,
Nonce: 0, Nonce: 0,
Fee: 0, Fee: 0,

+ 12
- 0
common/l2tx.go

@ -18,3 +18,15 @@ type L2Tx struct {
// Extra metadata, may be uninitialized // Extra metadata, may be uninitialized
Type TxType `meddler:"-"` // optional, descrives which kind of tx it's Type TxType `meddler:"-"` // optional, descrives which kind of tx it's
} }
func (tx *L2Tx) Tx() *Tx {
return &Tx{
TxID: tx.TxID,
FromIdx: tx.FromIdx,
ToIdx: tx.ToIdx,
Amount: tx.Amount,
Nonce: tx.Nonce,
Fee: tx.Fee,
Type: tx.Type,
}
}

+ 24
- 3
common/pooll2tx.go

@ -27,11 +27,12 @@ func (n Nonce) Bytes() ([5]byte, error) {
return b, nil return b, nil
} }
func NonceFromBytes(b [5]byte) (Nonce, error) {
// NonceFromBytes returns Nonce from a [5]byte
func NonceFromBytes(b [5]byte) Nonce {
var nonceBytes [8]byte var nonceBytes [8]byte
copy(nonceBytes[:], b[:5]) copy(nonceBytes[:], b[:5])
nonce := binary.LittleEndian.Uint64(nonceBytes[:]) nonce := binary.LittleEndian.Uint64(nonceBytes[:])
return Nonce(nonce), nil
return Nonce(nonce)
} }
// PoolL2Tx is a struct that represents a L2Tx sent by an account to the coordinator hat is waiting to be forged // PoolL2Tx is a struct that represents a L2Tx sent by an account to the coordinator hat is waiting to be forged
@ -171,12 +172,24 @@ func (tx *PoolL2Tx) VerifySignature(pk *babyjub.PublicKey) bool {
return pk.VerifyPoseidon(h, tx.Signature) return pk.VerifyPoseidon(h, tx.Signature)
} }
func (tx *PoolL2Tx) L2Tx() *L2Tx {
return &L2Tx{
TxID: tx.TxID,
BatchNum: tx.BatchNum,
FromIdx: tx.FromIdx,
ToIdx: tx.ToIdx,
Amount: tx.Amount,
Fee: tx.Fee,
Nonce: tx.Nonce,
Type: tx.Type,
}
}
func (tx *PoolL2Tx) Tx() *Tx { func (tx *PoolL2Tx) Tx() *Tx {
return &Tx{ return &Tx{
TxID: tx.TxID, TxID: tx.TxID,
FromIdx: tx.FromIdx, FromIdx: tx.FromIdx,
ToIdx: tx.ToIdx, ToIdx: tx.ToIdx,
TokenID: tx.TokenID,
Amount: tx.Amount, Amount: tx.Amount,
Nonce: tx.Nonce, Nonce: tx.Nonce,
Fee: tx.Fee, Fee: tx.Fee,
@ -184,6 +197,14 @@ func (tx *PoolL2Tx) Tx() *Tx {
} }
} }
func PoolL2TxsToL2Txs(txs []*PoolL2Tx) []*L2Tx {
var r []*L2Tx
for _, tx := range txs {
r = append(r, tx.L2Tx())
}
return r
}
// PoolL2TxState is a struct that represents the status of a L2 transaction // PoolL2TxState is a struct that represents the status of a L2 transaction
type PoolL2TxState string type PoolL2TxState string

+ 2
- 6
common/pooll2tx_test.go

@ -16,8 +16,7 @@ func TestNonceParser(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 5, len(nBytes)) assert.Equal(t, 5, len(nBytes))
assert.Equal(t, "0100000000", hex.EncodeToString(nBytes[:])) assert.Equal(t, "0100000000", hex.EncodeToString(nBytes[:]))
n2, err := NonceFromBytes(nBytes)
assert.Nil(t, err)
n2 := NonceFromBytes(nBytes)
assert.Equal(t, n, n2) assert.Equal(t, n, n2)
// value before overflow // value before overflow
@ -26,8 +25,7 @@ func TestNonceParser(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 5, len(nBytes)) assert.Equal(t, 5, len(nBytes))
assert.Equal(t, "ffffffffff", hex.EncodeToString(nBytes[:])) assert.Equal(t, "ffffffffff", hex.EncodeToString(nBytes[:]))
n2, err = NonceFromBytes(nBytes)
assert.Nil(t, err)
n2 = NonceFromBytes(nBytes)
assert.Equal(t, n, n2) assert.Equal(t, n, n2)
// expect value overflow // expect value overflow
@ -35,8 +33,6 @@ func TestNonceParser(t *testing.T) {
nBytes, err = n.Bytes() nBytes, err = n.Bytes()
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrNonceOverflow, err) assert.Equal(t, ErrNonceOverflow, err)
_, err = NonceFromBytes(nBytes)
assert.Nil(t, err)
} }
func TestTxCompressedData(t *testing.T) { func TestTxCompressedData(t *testing.T) {

+ 10
- 1
common/tx.go

@ -2,6 +2,7 @@ package common
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"math/big" "math/big"
) )
@ -20,6 +21,15 @@ func (idx Idx) BigInt() *big.Int {
return big.NewInt(int64(idx)) return big.NewInt(int64(idx))
} }
// IdxFromBytes returns Idx from a byte array
func IdxFromBytes(b []byte) (Idx, error) {
if len(b) != 4 {
return 0, fmt.Errorf("can not parse Idx, bytes len %d, expected 4", len(b))
}
idx := binary.LittleEndian.Uint32(b[:4])
return Idx(idx), nil
}
// IdxFromBigInt converts a *big.Int to Idx type // IdxFromBigInt converts a *big.Int to Idx type
func IdxFromBigInt(b *big.Int) (Idx, error) { func IdxFromBigInt(b *big.Int) (Idx, error) {
if b.Int64() > 0xffffffff { // 2**32-1 if b.Int64() > 0xffffffff { // 2**32-1
@ -64,7 +74,6 @@ type Tx struct {
TxID TxID TxID TxID
FromIdx Idx // FromIdx is used by L1Tx/Deposit to indicate the Idx receiver of the L1Tx.LoadAmount (deposit) FromIdx Idx // FromIdx is used by L1Tx/Deposit to indicate the Idx receiver of the L1Tx.LoadAmount (deposit)
ToIdx Idx // ToIdx is ignored in L1Tx/Deposit, but used in the L1Tx/DepositTransfer ToIdx Idx // ToIdx is ignored in L1Tx/Deposit, but used in the L1Tx/DepositTransfer
TokenID TokenID
Amount *big.Int Amount *big.Int
Nonce Nonce // effective 40 bits used Nonce Nonce // effective 40 bits used
Fee FeeSelector Fee FeeSelector

+ 4
- 0
common/zk.go

@ -58,3 +58,7 @@ type ZKInputs struct {
type CallDataForge struct { type CallDataForge struct {
// TBD // TBD
} }
type ExitTreeLeaf struct {
// TBD
}

+ 8
- 7
coordinator/coordinator.go

@ -98,21 +98,21 @@ func (c *Coordinator) forgeSequence() error {
c.batchNum = c.batchNum + 1 c.batchNum = c.batchNum + 1
batchInfo := NewBatchInfo(c.batchNum, serverProofInfo) // to accumulate metadata of the batch batchInfo := NewBatchInfo(c.batchNum, serverProofInfo) // to accumulate metadata of the batch
var l2Txs []*common.PoolL2Tx
var poolL2Txs []*common.PoolL2Tx
// var feesInfo // var feesInfo
var l1UserTxsExtra, l1OperatorTxs []*common.L1Tx var l1UserTxsExtra, l1OperatorTxs []*common.L1Tx
// 1. Decide if we forge L2Tx or L1+L2Tx // 1. Decide if we forge L2Tx or L1+L2Tx
if c.shouldL1L2Batch() { if c.shouldL1L2Batch() {
// 2a: L1+L2 txs // 2a: L1+L2 txs
// l1UserTxs, toForgeL1TxsNumber := c.synchronizer.GetNextL1UserTxs() // TODO once synchronizer is ready, uncomment // l1UserTxs, toForgeL1TxsNumber := c.synchronizer.GetNextL1UserTxs() // TODO once synchronizer is ready, uncomment
var l1UserTxs []*common.L1Tx = nil // tmp, depends on synchronizer
l1UserTxsExtra, l1OperatorTxs, l2Txs, err = c.txsel.GetL1L2TxSelection(c.batchNum, l1UserTxs) // TODO once feesInfo is added to method return, add the var
var l1UserTxs []*common.L1Tx = nil // tmp, depends on synchronizer
l1UserTxsExtra, l1OperatorTxs, poolL2Txs, err = c.txsel.GetL1L2TxSelection(c.batchNum, l1UserTxs) // TODO once feesInfo is added to method return, add the var
if err != nil { if err != nil {
return err return err
} }
} else { } else {
// 2b: only L2 txs // 2b: only L2 txs
l2Txs, err = c.txsel.GetL2TxSelection(c.batchNum) // TODO once feesInfo is added to method return, add the var
poolL2Txs, err = c.txsel.GetL2TxSelection(c.batchNum) // TODO once feesInfo is added to method return, add the var
if err != nil { if err != nil {
return err return err
} }
@ -121,21 +121,22 @@ func (c *Coordinator) forgeSequence() error {
} }
// Run purger to invalidate transactions that become invalid beause of // Run purger to invalidate transactions that become invalid beause of
// the l2Txs selected. Will mark as invalid the txs that have a
// the poolL2Txs selected. Will mark as invalid the txs that have a
// (fromIdx, nonce) which already appears in the selected txs (includes // (fromIdx, nonce) which already appears in the selected txs (includes
// all the nonces smaller than the current one) // all the nonces smaller than the current one)
err = c.purgeInvalidDueToL2TxsSelection(l2Txs)
err = c.purgeInvalidDueToL2TxsSelection(poolL2Txs)
if err != nil { if err != nil {
return err return err
} }
// 3. Save metadata from TxSelector output for BatchNum // 3. Save metadata from TxSelector output for BatchNum
batchInfo.SetTxsInfo(l1UserTxsExtra, l1OperatorTxs, l2Txs) // TODO feesInfo
batchInfo.SetTxsInfo(l1UserTxsExtra, l1OperatorTxs, poolL2Txs) // TODO feesInfo
// 4. Call BatchBuilder with TxSelector output // 4. Call BatchBuilder with TxSelector output
configBatch := &batchbuilder.ConfigBatch{ configBatch := &batchbuilder.ConfigBatch{
ForgerAddress: c.config.ForgerAddress, ForgerAddress: c.config.ForgerAddress,
} }
l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs)
zkInputs, err := c.batchBuilder.BuildBatch(configBatch, l1UserTxsExtra, l1OperatorTxs, l2Txs, nil) // TODO []common.TokenID --> feesInfo zkInputs, err := c.batchBuilder.BuildBatch(configBatch, l1UserTxsExtra, l1OperatorTxs, l2Txs, nil) // TODO []common.TokenID --> feesInfo
if err != nil { if err != nil {
return err return err

+ 5
- 9
db/statedb/statedb.go

@ -1,7 +1,6 @@
package statedb package statedb
import ( import (
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -31,11 +30,11 @@ const PATHCURRENT = "/current"
// StateDB represents the StateDB object // StateDB represents the StateDB object
type StateDB struct { type StateDB struct {
path string path string
currentBatch uint64
currentBatch common.BatchNum
db *pebble.PebbleStorage db *pebble.PebbleStorage
mt *merkletree.MerkleTree mt *merkletree.MerkleTree
// idx holds the current Idx that the BatchBuilder is using // idx holds the current Idx that the BatchBuilder is using
idx uint64
idx common.Idx
} }
// NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk // NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk
@ -77,7 +76,7 @@ func (s *StateDB) DB() *pebble.PebbleStorage {
} }
// GetCurrentBatch returns the current BatchNum stored in the StateDB // GetCurrentBatch returns the current BatchNum stored in the StateDB
func (s *StateDB) GetCurrentBatch() (uint64, error) {
func (s *StateDB) GetCurrentBatch() (common.BatchNum, error) {
cbBytes, err := s.db.Get(KEYCURRENTBATCH) cbBytes, err := s.db.Get(KEYCURRENTBATCH)
if err == db.ErrNotFound { if err == db.ErrNotFound {
return 0, nil return 0, nil
@ -85,8 +84,7 @@ func (s *StateDB) GetCurrentBatch() (uint64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
cb := binary.LittleEndian.Uint64(cbBytes[:8])
return cb, nil
return common.BatchNumFromBytes(cbBytes)
} }
// setCurrentBatch stores the current BatchNum in the StateDB // setCurrentBatch stores the current BatchNum in the StateDB
@ -95,9 +93,7 @@ func (s *StateDB) setCurrentBatch() error {
if err != nil { if err != nil {
return err return err
} }
var cbBytes [8]byte
binary.LittleEndian.PutUint64(cbBytes[:], s.currentBatch)
tx.Put(KEYCURRENTBATCH, cbBytes[:])
tx.Put(KEYCURRENTBATCH, s.currentBatch.Bytes())
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }

+ 9
- 9
db/statedb/statedb_test.go

@ -27,7 +27,7 @@ func newAccount(t *testing.T, i int) *common.Account {
return &common.Account{ return &common.Account{
TokenID: common.TokenID(i), TokenID: common.TokenID(i),
Nonce: uint64(i),
Nonce: common.Nonce(i),
Balance: big.NewInt(1000), Balance: big.NewInt(1000),
PublicKey: pk, PublicKey: pk,
EthAddr: address, EthAddr: address,
@ -159,7 +159,7 @@ func TestCheckpoints(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
cb, err := sdb.GetCurrentBatch() cb, err := sdb.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(1), cb)
assert.Equal(t, common.BatchNum(1), cb)
for i := 1; i < 10; i++ { for i := 1; i < 10; i++ {
err = sdb.MakeCheckpoint() err = sdb.MakeCheckpoint()
@ -167,7 +167,7 @@ func TestCheckpoints(t *testing.T) {
cb, err = sdb.GetCurrentBatch() cb, err = sdb.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(i+1), cb)
assert.Equal(t, common.BatchNum(i+1), cb)
} }
// printCheckpoints(t, sdb.path) // printCheckpoints(t, sdb.path)
@ -184,14 +184,14 @@ func TestCheckpoints(t *testing.T) {
// check that currentBatch is as expected after Reset // check that currentBatch is as expected after Reset
cb, err = sdb.GetCurrentBatch() cb, err = sdb.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(3), cb)
assert.Equal(t, common.BatchNum(3), cb)
// advance one checkpoint and check that currentBatch is fine // advance one checkpoint and check that currentBatch is fine
err = sdb.MakeCheckpoint() err = sdb.MakeCheckpoint()
assert.Nil(t, err) assert.Nil(t, err)
cb, err = sdb.GetCurrentBatch() cb, err = sdb.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(4), cb)
assert.Equal(t, common.BatchNum(4), cb)
err = sdb.DeleteCheckpoint(uint64(9)) err = sdb.DeleteCheckpoint(uint64(9))
assert.Nil(t, err) assert.Nil(t, err)
@ -214,13 +214,13 @@ func TestCheckpoints(t *testing.T) {
// check that currentBatch is 4 after the Reset // check that currentBatch is 4 after the Reset
cb, err = ldb.GetCurrentBatch() cb, err = ldb.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(4), cb)
assert.Equal(t, common.BatchNum(4), cb)
// advance one checkpoint in ldb // advance one checkpoint in ldb
err = ldb.MakeCheckpoint() err = ldb.MakeCheckpoint()
assert.Nil(t, err) assert.Nil(t, err)
cb, err = ldb.GetCurrentBatch() cb, err = ldb.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(5), cb)
assert.Equal(t, common.BatchNum(5), cb)
// Create a 2nd LocalStateDB from the initial StateDB // Create a 2nd LocalStateDB from the initial StateDB
dirLocal2, err := ioutil.TempDir("", "ldb2") dirLocal2, err := ioutil.TempDir("", "ldb2")
@ -234,13 +234,13 @@ func TestCheckpoints(t *testing.T) {
// check that currentBatch is 4 after the Reset // check that currentBatch is 4 after the Reset
cb, err = ldb2.GetCurrentBatch() cb, err = ldb2.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(4), cb)
assert.Equal(t, common.BatchNum(4), cb)
// advance one checkpoint in ldb2 // advance one checkpoint in ldb2
err = ldb2.MakeCheckpoint() err = ldb2.MakeCheckpoint()
assert.Nil(t, err) assert.Nil(t, err)
cb, err = ldb2.GetCurrentBatch() cb, err = ldb2.GetCurrentBatch()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, uint64(5), cb)
assert.Equal(t, common.BatchNum(5), cb)
debug := false debug := false
if debug { if debug {

+ 39
- 15
db/statedb/txprocessors.go

@ -1,19 +1,46 @@
package statedb package statedb
import ( import (
"encoding/binary"
"math/big" "math/big"
"github.com/hermeznetwork/hermez-node/common" "github.com/hermeznetwork/hermez-node/common"
"github.com/iden3/go-merkletree/db" "github.com/iden3/go-merkletree/db"
) )
// KEYIDX is used as key in the db to store the current Idx
var KEYIDX = []byte("idx")
// keyidx is used as key in the db to store the current Idx
var keyidx = []byte("idx")
// ProcessTxs process the given L1Txs & L2Txs applying the needed updates
// to the StateDB depending on the transaction Type. Returns the
// common.ZKInputs to generate the SnarkProof later used by the BatchBuilder,
// and returns common.ExitTreeLeaf that is later used by the Synchronizer to
// update the HistoryDB.
func (s *StateDB) ProcessTxs(l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.L2Tx) (*common.ZKInputs, []*common.ExitTreeLeaf, error) {
for _, tx := range l1usertxs {
err := s.processL1Tx(tx)
if err != nil {
return nil, nil, err
}
}
for _, tx := range l1coordinatortxs {
err := s.processL1Tx(tx)
if err != nil {
return nil, nil, err
}
}
for _, tx := range l2txs {
err := s.processL2Tx(tx)
if err != nil {
return nil, nil, err
}
}
return nil, nil, nil
}
// ProcessPoolL2Tx process the given PoolL2Tx applying the needed updates to
// processL2Tx process the given L2Tx applying the needed updates to
// the StateDB depending on the transaction Type. // the StateDB depending on the transaction Type.
func (s *StateDB) ProcessPoolL2Tx(tx *common.PoolL2Tx) error {
func (s *StateDB) processL2Tx(tx *common.L2Tx) error {
switch tx.Type { switch tx.Type {
case common.TxTypeTransfer: case common.TxTypeTransfer:
// go to the MT account of sender and receiver, and update // go to the MT account of sender and receiver, and update
@ -29,9 +56,9 @@ func (s *StateDB) ProcessPoolL2Tx(tx *common.PoolL2Tx) error {
return nil return nil
} }
// ProcessL1Tx process the given L1Tx applying the needed updates to the
// processL1Tx process the given L1Tx applying the needed updates to the
// StateDB depending on the transaction Type. // StateDB depending on the transaction Type.
func (s *StateDB) ProcessL1Tx(tx *common.L1Tx) error {
func (s *StateDB) processL1Tx(tx *common.L1Tx) error {
switch tx.Type { switch tx.Type {
case common.TxTypeForceTransfer, common.TxTypeTransfer: case common.TxTypeForceTransfer, common.TxTypeTransfer:
// go to the MT account of sender and receiver, and update balance // go to the MT account of sender and receiver, and update balance
@ -170,27 +197,24 @@ func (s *StateDB) applyTransfer(tx *common.Tx) error {
// getIdx returns the stored Idx from the localStateDB, which is the last Idx // getIdx returns the stored Idx from the localStateDB, which is the last Idx
// used for an Account in the localStateDB. // used for an Account in the localStateDB.
func (s *StateDB) getIdx() (uint64, error) {
idxBytes, err := s.DB().Get(KEYIDX)
func (s *StateDB) getIdx() (common.Idx, error) {
idxBytes, err := s.DB().Get(keyidx)
if err == db.ErrNotFound { if err == db.ErrNotFound {
return 0, nil return 0, nil
} }
if err != nil { if err != nil {
return 0, err return 0, err
} }
idx := binary.LittleEndian.Uint64(idxBytes[:8])
return idx, nil
return common.IdxFromBytes(idxBytes[:4])
} }
// setIdx stores Idx in the localStateDB // setIdx stores Idx in the localStateDB
func (s *StateDB) setIdx(idx uint64) error {
func (s *StateDB) setIdx(idx common.Idx) error {
tx, err := s.DB().NewTx() tx, err := s.DB().NewTx()
if err != nil { if err != nil {
return err return err
} }
var idxBytes [8]byte
binary.LittleEndian.PutUint64(idxBytes[:], idx)
tx.Put(KEYIDX, idxBytes[:])
tx.Put(keyidx, idx.Bytes())
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }

Loading…
Cancel
Save