Browse Source

Add StateDB ZKInputs generation for L1 & L2 Txs

feature/sql-semaphore1
arnaucube 4 years ago
parent
commit
aa0bde61d2
7 changed files with 245 additions and 20 deletions
  1. +2
    -2
      batchbuilder/batchbuilder.go
  2. +25
    -0
      common/l2tx.go
  3. +1
    -2
      coordinator/coordinator.go
  4. +2
    -0
      db/statedb/statedb.go
  5. +159
    -8
      db/statedb/txprocessors.go
  6. +39
    -8
      db/statedb/txprocessors_test.go
  7. +17
    -0
      db/statedb/utils.go

+ 2
- 2
batchbuilder/batchbuilder.go

@ -51,8 +51,8 @@ func (bb *BatchBuilder) Reset(batchNum common.BatchNum, fromSynchronizer bool) e
} }
// BuildBatch takes the transactions and returns the common.ZKInputs of the next batch // BuildBatch takes the transactions and returns the common.ZKInputs of the next batch
func (bb *BatchBuilder) BuildBatch(configBatch *ConfigBatch, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.L2Tx, tokenIDs []common.TokenID) (*common.ZKInputs, error) {
zkInputs, _, err := bb.localStateDB.ProcessTxs(false, l1usertxs, l1coordinatortxs, l2txs)
func (bb *BatchBuilder) BuildBatch(configBatch *ConfigBatch, l1usertxs, l1coordinatortxs []*common.L1Tx, pooll2txs []*common.PoolL2Tx, tokenIDs []common.TokenID) (*common.ZKInputs, error) {
zkInputs, _, err := bb.localStateDB.ProcessTxs(false, true, l1usertxs, l1coordinatortxs, pooll2txs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

+ 25
- 0
common/l2tx.go

@ -30,3 +30,28 @@ func (tx *L2Tx) Tx() *Tx {
Type: tx.Type, Type: tx.Type,
} }
} }
// PoolL2Tx returns the data structure of PoolL2Tx with the parameters of a
// L2Tx filled
func (tx *L2Tx) PoolL2Tx() *PoolL2Tx {
return &PoolL2Tx{
TxID: tx.TxID,
BatchNum: tx.BatchNum,
FromIdx: tx.FromIdx,
ToIdx: tx.ToIdx,
Amount: tx.Amount,
Fee: tx.Fee,
Nonce: tx.Nonce,
Type: tx.Type,
}
}
// L2TxsToPoolL2Txs returns an array of []*PoolL2Tx from an array of []*L2Tx,
// where the PoolL2Tx only have the parameters of a L2Tx filled.
func L2TxsToPoolL2Txs(txs []*L2Tx) []*PoolL2Tx {
var r []*PoolL2Tx
for _, tx := range txs {
r = append(r, tx.PoolL2Tx())
}
return r
}

+ 1
- 2
coordinator/coordinator.go

@ -221,8 +221,7 @@ func (c *Coordinator) forgeSequence() error {
configBatch := &batchbuilder.ConfigBatch{ configBatch := &batchbuilder.ConfigBatch{
ForgerAddress: c.config.ForgerAddress, ForgerAddress: c.config.ForgerAddress,
} }
l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs)
zkInputs, err := c.batchBuilder.BuildBatch(configBatch, l1UserTxsExtra, l1OperatorTxs, l2Txs, nil) // TODO []common.TokenID --> feesInfo
zkInputs, err := c.batchBuilder.BuildBatch(configBatch, l1UserTxsExtra, l1OperatorTxs, poolL2Txs, nil) // TODO []common.TokenID --> feesInfo
if err != nil { if err != nil {
return err return err
} }

+ 2
- 0
db/statedb/statedb.go

@ -43,6 +43,8 @@ type StateDB struct {
mt *merkletree.MerkleTree mt *merkletree.MerkleTree
// idx holds the current Idx that the BatchBuilder is using // idx holds the current Idx that the BatchBuilder is using
idx common.Idx idx common.Idx
zki *common.ZKInputs
i int // i is used for zki
} }
// NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk // NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk

+ 159
- 8
db/statedb/txprocessors.go

@ -1,8 +1,12 @@
package statedb package statedb
import ( import (
"bytes"
"errors"
"fmt"
"math/big" "math/big"
ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/hermeznetwork/hermez-node/common" "github.com/hermeznetwork/hermez-node/common"
"github.com/iden3/go-iden3-crypto/poseidon" "github.com/iden3/go-iden3-crypto/poseidon"
"github.com/iden3/go-merkletree" "github.com/iden3/go-merkletree"
@ -10,19 +14,42 @@ import (
"github.com/iden3/go-merkletree/db/memory" "github.com/iden3/go-merkletree/db/memory"
) )
// keyidx is used as key in the db to store the current Idx
var keyidx = []byte("idx")
var (
// keyidx is used as key in the db to store the current Idx
keyidx = []byte("idx")
ffAddr = ethCommon.HexToAddress("0xffffffffffffffffffffffffffffffffffffffff")
)
func (s *StateDB) resetZKInputs() {
s.zki = nil
s.i = 0
}
// ProcessTxs process the given L1Txs & L2Txs applying the needed updates to // ProcessTxs process the given L1Txs & L2Txs applying the needed updates to
// the StateDB depending on the transaction Type. Returns the common.ZKInputs // the StateDB depending on the transaction Type. Returns the common.ZKInputs
// to generate the SnarkProof later used by the BatchBuilder, and if // to generate the SnarkProof later used by the BatchBuilder, and if
// cmpExitTree is set to true, returns common.ExitTreeLeaf that is later used // cmpExitTree is set to true, returns common.ExitTreeLeaf that is later used
// by the Synchronizer to update the HistoryDB. // by the Synchronizer to update the HistoryDB.
func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.L2Tx) (*common.ZKInputs, []*common.ExitInfo, error) {
func (s *StateDB) ProcessTxs(cmpExitTree, cmpZKInputs bool, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.PoolL2Tx) (*common.ZKInputs, []*common.ExitInfo, error) {
var err error var err error
var exitTree *merkletree.MerkleTree var exitTree *merkletree.MerkleTree
exits := make(map[common.Idx]common.Account) exits := make(map[common.Idx]common.Account)
if s.zki != nil {
return nil, nil, errors.New("Expected StateDB.zki==nil, something went wrong ans is not empty")
}
defer s.resetZKInputs()
nTx := len(l1usertxs) + len(l1coordinatortxs) + len(l2txs)
if nTx == 0 {
return nil, nil, nil // TBD if return an error in the case of no Txs to process
}
if cmpZKInputs {
s.zki = common.NewZKInputs(nTx, 24, 32) // TODO this values will be parameters of the function
}
// TBD if ExitTree is only in memory or stored in disk, for the moment // TBD if ExitTree is only in memory or stored in disk, for the moment
// only needed in memory // only needed in memory
exitTree, err = merkletree.NewMerkleTree(memory.NewMemoryStorage(), s.mt.MaxLevels()) exitTree, err = merkletree.NewMerkleTree(memory.NewMemoryStorage(), s.mt.MaxLevels())
@ -30,7 +57,8 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co
return nil, nil, err return nil, nil, err
} }
for _, tx := range l1coordinatortxs {
// assumption: l1usertx are sorted by L1Tx.Position
for _, tx := range l1usertxs {
exitIdx, exitAccount, err := s.processL1Tx(exitTree, tx) exitIdx, exitAccount, err := s.processL1Tx(exitTree, tx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -38,8 +66,11 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co
if exitIdx != nil && cmpExitTree { if exitIdx != nil && cmpExitTree {
exits[*exitIdx] = *exitAccount exits[*exitIdx] = *exitAccount
} }
if s.zki != nil {
s.i++
}
} }
for _, tx := range l1usertxs {
for _, tx := range l1coordinatortxs {
exitIdx, exitAccount, err := s.processL1Tx(exitTree, tx) exitIdx, exitAccount, err := s.processL1Tx(exitTree, tx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -47,6 +78,9 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co
if exitIdx != nil && cmpExitTree { if exitIdx != nil && cmpExitTree {
exits[*exitIdx] = *exitAccount exits[*exitIdx] = *exitAccount
} }
if s.zki != nil {
s.i++
}
} }
for _, tx := range l2txs { for _, tx := range l2txs {
exitIdx, exitAccount, err := s.processL2Tx(exitTree, tx) exitIdx, exitAccount, err := s.processL2Tx(exitTree, tx)
@ -56,9 +90,12 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co
if exitIdx != nil && cmpExitTree { if exitIdx != nil && cmpExitTree {
exits[*exitIdx] = *exitAccount exits[*exitIdx] = *exitAccount
} }
if s.zki != nil {
s.i++
}
} }
if !cmpExitTree {
if !cmpExitTree && !cmpZKInputs {
return nil, nil, nil return nil, nil, nil
} }
@ -93,15 +130,78 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co
} }
exitInfos = append(exitInfos, ei) exitInfos = append(exitInfos, ei)
} }
if !cmpZKInputs {
return nil, exitInfos, nil
}
// compute last ZKInputs parameters
s.zki.OldLastIdx = (s.idx - 1).BigInt()
s.zki.OldStateRoot = s.mt.Root().BigInt()
s.zki.GlobalChainID = big.NewInt(0) // TODO, 0: ethereum, get this from config file
// zki.FeeIdxs = ? // TODO, this will be get from the config file
tokenIDs, err := s.getTokenIDsBigInt(l1usertxs, l1coordinatortxs, l2txs)
if err != nil {
return nil, nil, err
}
s.zki.FeePlanTokens = tokenIDs
// s.zki.ISInitStateRootFee = s.mt.Root().BigInt()
// compute fees
// once fees are computed
// return exitInfos, so Synchronizer will be able to store it into // return exitInfos, so Synchronizer will be able to store it into
// HistoryDB for the concrete BatchNum // HistoryDB for the concrete BatchNum
return nil, exitInfos, nil
return s.zki, exitInfos, nil
}
// getTokenIDsBigInt returns the list of TokenIDs in *big.Int format
func (s *StateDB) getTokenIDsBigInt(l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.PoolL2Tx) ([]*big.Int, error) {
tokenIDs := make(map[common.TokenID]bool)
for i := 0; i < len(l1usertxs); i++ {
tokenIDs[l1usertxs[i].TokenID] = true
}
for i := 0; i < len(l1coordinatortxs); i++ {
tokenIDs[l1coordinatortxs[i].TokenID] = true
}
for i := 0; i < len(l2txs); i++ {
// as L2Tx does not have parameter TokenID, get it from the
// AccountsDB (in the StateDB)
acc, err := s.GetAccount(l2txs[i].ToIdx)
if err != nil {
return nil, err
}
tokenIDs[acc.TokenID] = true
}
var tBI []*big.Int
for t := range tokenIDs {
tBI = append(tBI, t.BigInt())
}
return tBI, nil
} }
// processL1Tx process the given L1Tx applying the needed updates to the // processL1Tx process the given L1Tx applying the needed updates to the
// StateDB depending on the transaction Type. // StateDB depending on the transaction Type.
func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx) (*common.Idx, *common.Account, error) { func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx) (*common.Idx, *common.Account, error) {
// ZKInputs
if s.zki != nil {
// Txs
// s.zki.TxCompressedData[s.i] = tx.TxCompressedData() // uncomment once L1Tx.TxCompressedData is ready
s.zki.FromIdx[s.i] = tx.FromIdx.BigInt()
s.zki.ToIdx[s.i] = tx.ToIdx.BigInt()
s.zki.OnChain[s.i] = big.NewInt(1)
// L1Txs
s.zki.LoadAmountF[s.i] = tx.LoadAmount
s.zki.FromEthAddr[s.i] = common.EthAddrToBigInt(tx.FromEthAddr)
if tx.FromBJJ != nil {
s.zki.FromBJJCompressed[s.i] = common.BJJCompressedTo256BigInts(tx.FromBJJ.Compress())
}
// Intermediate States
s.zki.ISOnChain[s.i] = big.NewInt(1)
}
switch tx.Type { switch tx.Type {
case common.TxTypeForceTransfer, common.TxTypeTransfer: case common.TxTypeForceTransfer, common.TxTypeTransfer:
// go to the MT account of sender and receiver, and update balance // go to the MT account of sender and receiver, and update balance
@ -116,6 +216,11 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if s.zki != nil {
s.zki.AuxFromIdx[s.i] = s.idx.BigInt() // last s.idx is the one used for creating the new account
s.zki.NewAccount[s.i] = big.NewInt(1)
}
case common.TxTypeDeposit: case common.TxTypeDeposit:
// update balance of the MT account // update balance of the MT account
err := s.applyDeposit(tx, false) err := s.applyDeposit(tx, false)
@ -140,6 +245,11 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if s.zki != nil {
s.zki.AuxFromIdx[s.i] = s.idx.BigInt() // last s.idx is the one used for creating the new account
s.zki.NewAccount[s.i] = big.NewInt(1)
}
case common.TxTypeExit: case common.TxTypeExit:
// execute exit flow // execute exit flow
exitAccount, err := s.applyExit(exitTree, tx.Tx()) exitAccount, err := s.applyExit(exitTree, tx.Tx())
@ -155,7 +265,48 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
// processL2Tx process the given L2Tx applying the needed updates to // processL2Tx process the given L2Tx applying the needed updates to
// the StateDB depending on the transaction Type. // the StateDB depending on the transaction Type.
func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.L2Tx) (*common.Idx, *common.Account, error) {
func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.PoolL2Tx) (*common.Idx, *common.Account, error) {
// ZKInputs
if s.zki != nil {
// Txs
// s.zki.TxCompressedData[s.i] = tx.TxCompressedData() // uncomment once L1Tx.TxCompressedData is ready
// s.zki.TxCompressedDataV2[s.i] = tx.TxCompressedDataV2() // uncomment once L2Tx.TxCompressedDataV2 is ready
s.zki.FromIdx[s.i] = tx.FromIdx.BigInt()
s.zki.ToIdx[s.i] = tx.ToIdx.BigInt()
// fill AuxToIdx if needed
if tx.ToIdx == common.Idx(0) {
// Idx not set in the Tx, get it from DB through ToEthAddr or ToBJJ
var idx common.Idx
if !bytes.Equal(tx.ToEthAddr.Bytes(), ffAddr.Bytes()) {
idx = s.getIdxByEthAddr(tx.ToEthAddr)
if idx == common.Idx(0) {
return nil, nil, fmt.Errorf("Idx can not be found for given tx.FromEthAddr")
}
} else {
idx = s.getIdxByBJJ(tx.ToBJJ)
if idx == common.Idx(0) {
return nil, nil, fmt.Errorf("Idx can not be found for given tx.FromBJJ")
}
}
s.zki.AuxToIdx[s.i] = idx.BigInt()
}
s.zki.ToBJJAy[s.i] = tx.ToBJJ.Y
s.zki.ToEthAddr[s.i] = common.EthAddrToBigInt(tx.ToEthAddr)
s.zki.OnChain[s.i] = big.NewInt(0)
s.zki.NewAccount[s.i] = big.NewInt(0)
// L2Txs
// s.zki.RqOffset[s.i] = // TODO
// s.zki.RqTxCompressedDataV2[s.i] = // TODO
// s.zki.RqToEthAddr[s.i] = common.EthAddrToBigInt(tx.RqToEthAddr) // TODO
// s.zki.RqToBJJAy[s.i] = tx.ToBJJ.Y // TODO
s.zki.S[s.i] = tx.Signature.S
s.zki.R8x[s.i] = tx.Signature.R8.X
s.zki.R8y[s.i] = tx.Signature.R8.Y
}
switch tx.Type { switch tx.Type {
case common.TxTypeTransfer: case common.TxTypeTransfer:
// go to the MT account of sender and receiver, and update // go to the MT account of sender and receiver, and update

+ 39
- 8
db/statedb/txprocessors_test.go

@ -1,6 +1,8 @@
package statedb package statedb
import ( import (
"encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"strings" "strings"
"testing" "testing"
@ -11,6 +13,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
var debug = false
func TestProcessTxs(t *testing.T) { func TestProcessTxs(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb") dir, err := ioutil.TempDir("", "tmpdb")
require.Nil(t, err) require.Nil(t, err)
@ -30,9 +34,9 @@ func TestProcessTxs(t *testing.T) {
// iterate for each batch // iterate for each batch
for i := 0; i < len(l1Txs); i++ { for i := 0; i < len(l1Txs); i++ {
l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs[i])
// l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs[i])
_, _, err := sdb.ProcessTxs(true, l1Txs[i], coordinatorL1Txs[i], l2Txs)
_, _, err := sdb.ProcessTxs(true, true, l1Txs[i], coordinatorL1Txs[i], poolL2Txs[i])
require.Nil(t, err) require.Nil(t, err)
} }
@ -65,8 +69,8 @@ func TestProcessTxsBatchByBatch(t *testing.T) {
assert.Equal(t, 7, len(poolL2Txs[2])) assert.Equal(t, 7, len(poolL2Txs[2]))
// use first batch // use first batch
l2txs := common.PoolL2TxsToL2Txs(poolL2Txs[0])
_, exitInfos, err := sdb.ProcessTxs(true, l1Txs[0], coordinatorL1Txs[0], l2txs)
// l2txs := common.PoolL2TxsToL2Txs(poolL2Txs[0])
_, exitInfos, err := sdb.ProcessTxs(true, true, l1Txs[0], coordinatorL1Txs[0], poolL2Txs[0])
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 0, len(exitInfos)) assert.Equal(t, 0, len(exitInfos))
acc, err := sdb.GetAccount(common.Idx(1)) acc, err := sdb.GetAccount(common.Idx(1))
@ -74,8 +78,8 @@ func TestProcessTxsBatchByBatch(t *testing.T) {
assert.Equal(t, "28", acc.Balance.String()) assert.Equal(t, "28", acc.Balance.String())
// use second batch // use second batch
l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[1])
_, exitInfos, err = sdb.ProcessTxs(true, l1Txs[1], coordinatorL1Txs[1], l2txs)
// l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[1])
_, exitInfos, err = sdb.ProcessTxs(true, true, l1Txs[1], coordinatorL1Txs[1], poolL2Txs[1])
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 5, len(exitInfos)) assert.Equal(t, 5, len(exitInfos))
acc, err = sdb.GetAccount(common.Idx(1)) acc, err = sdb.GetAccount(common.Idx(1))
@ -83,11 +87,38 @@ func TestProcessTxsBatchByBatch(t *testing.T) {
assert.Equal(t, "48", acc.Balance.String()) assert.Equal(t, "48", acc.Balance.String())
// use third batch // use third batch
l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[2])
_, exitInfos, err = sdb.ProcessTxs(true, l1Txs[2], coordinatorL1Txs[2], l2txs)
// l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[2])
_, exitInfos, err = sdb.ProcessTxs(true, true, l1Txs[2], coordinatorL1Txs[2], poolL2Txs[2])
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, 1, len(exitInfos)) assert.Equal(t, 1, len(exitInfos))
acc, err = sdb.GetAccount(common.Idx(1)) acc, err = sdb.GetAccount(common.Idx(1))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "23", acc.Balance.String()) assert.Equal(t, "23", acc.Balance.String())
} }
func TestZKInputsGeneration(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb")
require.Nil(t, err)
sdb, err := NewStateDB(dir, true, 32)
assert.Nil(t, err)
// generate test transactions from test.SetTest0 code
parser := test.NewParser(strings.NewReader(test.SetTest0))
instructions, err := parser.Parse()
assert.Nil(t, err)
l1Txs, coordinatorL1Txs, poolL2Txs := test.GenerateTestTxs(t, instructions)
assert.Equal(t, 29, len(l1Txs[0]))
assert.Equal(t, 0, len(coordinatorL1Txs[0]))
assert.Equal(t, 21, len(poolL2Txs[0]))
zki, _, err := sdb.ProcessTxs(false, true, l1Txs[0], coordinatorL1Txs[0], poolL2Txs[0])
require.Nil(t, err)
s, err := json.Marshal(zki)
require.Nil(t, err)
if debug {
fmt.Println(string(s))
}
}

+ 17
- 0
db/statedb/utils.go

@ -0,0 +1,17 @@
package statedb
import (
ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/hermeznetwork/hermez-node/common"
"github.com/iden3/go-iden3-crypto/babyjub"
)
// TODO
func (s *StateDB) getIdxByEthAddr(addr ethCommon.Address) common.Idx {
return common.Idx(0)
}
// TODO
func (s *StateDB) getIdxByBJJ(pk *babyjub.PublicKey) common.Idx {
return common.Idx(0)
}

Loading…
Cancel
Save