From aa0bde61d223fd2894f6a2256c0b6272a6b5f1d8 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 9 Sep 2020 14:40:55 +0200 Subject: [PATCH] Add StateDB ZKInputs generation for L1 & L2 Txs --- batchbuilder/batchbuilder.go | 4 +- common/l2tx.go | 25 +++++ coordinator/coordinator.go | 3 +- db/statedb/statedb.go | 2 + db/statedb/txprocessors.go | 167 ++++++++++++++++++++++++++++++-- db/statedb/txprocessors_test.go | 47 +++++++-- db/statedb/utils.go | 17 ++++ 7 files changed, 245 insertions(+), 20 deletions(-) create mode 100644 db/statedb/utils.go diff --git a/batchbuilder/batchbuilder.go b/batchbuilder/batchbuilder.go index 0f864e5..35c5ad3 100644 --- a/batchbuilder/batchbuilder.go +++ b/batchbuilder/batchbuilder.go @@ -51,8 +51,8 @@ func (bb *BatchBuilder) Reset(batchNum common.BatchNum, fromSynchronizer bool) e } // BuildBatch takes the transactions and returns the common.ZKInputs of the next batch -func (bb *BatchBuilder) BuildBatch(configBatch *ConfigBatch, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.L2Tx, tokenIDs []common.TokenID) (*common.ZKInputs, error) { - zkInputs, _, err := bb.localStateDB.ProcessTxs(false, l1usertxs, l1coordinatortxs, l2txs) +func (bb *BatchBuilder) BuildBatch(configBatch *ConfigBatch, l1usertxs, l1coordinatortxs []*common.L1Tx, pooll2txs []*common.PoolL2Tx, tokenIDs []common.TokenID) (*common.ZKInputs, error) { + zkInputs, _, err := bb.localStateDB.ProcessTxs(false, true, l1usertxs, l1coordinatortxs, pooll2txs) if err != nil { return nil, err } diff --git a/common/l2tx.go b/common/l2tx.go index 44a0283..3d49694 100644 --- a/common/l2tx.go +++ b/common/l2tx.go @@ -30,3 +30,28 @@ func (tx *L2Tx) Tx() *Tx { Type: tx.Type, } } + +// PoolL2Tx returns the data structure of PoolL2Tx with the parameters of a +// L2Tx filled +func (tx *L2Tx) PoolL2Tx() *PoolL2Tx { + return &PoolL2Tx{ + TxID: tx.TxID, + BatchNum: tx.BatchNum, + FromIdx: tx.FromIdx, + ToIdx: tx.ToIdx, + Amount: tx.Amount, + Fee: tx.Fee, + Nonce: tx.Nonce, + Type: tx.Type, + } +} + +// L2TxsToPoolL2Txs returns an array of []*PoolL2Tx from an array of []*L2Tx, +// where the PoolL2Tx only have the parameters of a L2Tx filled. +func L2TxsToPoolL2Txs(txs []*L2Tx) []*PoolL2Tx { + var r []*PoolL2Tx + for _, tx := range txs { + r = append(r, tx.PoolL2Tx()) + } + return r +} diff --git a/coordinator/coordinator.go b/coordinator/coordinator.go index a2d1d1b..85b6a98 100644 --- a/coordinator/coordinator.go +++ b/coordinator/coordinator.go @@ -221,8 +221,7 @@ func (c *Coordinator) forgeSequence() error { configBatch := &batchbuilder.ConfigBatch{ ForgerAddress: c.config.ForgerAddress, } - l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs) - zkInputs, err := c.batchBuilder.BuildBatch(configBatch, l1UserTxsExtra, l1OperatorTxs, l2Txs, nil) // TODO []common.TokenID --> feesInfo + zkInputs, err := c.batchBuilder.BuildBatch(configBatch, l1UserTxsExtra, l1OperatorTxs, poolL2Txs, nil) // TODO []common.TokenID --> feesInfo if err != nil { return err } diff --git a/db/statedb/statedb.go b/db/statedb/statedb.go index 5d69494..e5e83ff 100644 --- a/db/statedb/statedb.go +++ b/db/statedb/statedb.go @@ -43,6 +43,8 @@ type StateDB struct { mt *merkletree.MerkleTree // idx holds the current Idx that the BatchBuilder is using idx common.Idx + zki *common.ZKInputs + i int // i is used for zki } // NewStateDB creates a new StateDB, allowing to use an in-memory or in-disk diff --git a/db/statedb/txprocessors.go b/db/statedb/txprocessors.go index 280942c..d3ea768 100644 --- a/db/statedb/txprocessors.go +++ b/db/statedb/txprocessors.go @@ -1,8 +1,12 @@ package statedb import ( + "bytes" + "errors" + "fmt" "math/big" + ethCommon "github.com/ethereum/go-ethereum/common" "github.com/hermeznetwork/hermez-node/common" "github.com/iden3/go-iden3-crypto/poseidon" "github.com/iden3/go-merkletree" @@ -10,19 +14,42 @@ import ( "github.com/iden3/go-merkletree/db/memory" ) -// keyidx is used as key in the db to store the current Idx -var keyidx = []byte("idx") +var ( + // keyidx is used as key in the db to store the current Idx + keyidx = []byte("idx") + + ffAddr = ethCommon.HexToAddress("0xffffffffffffffffffffffffffffffffffffffff") +) + +func (s *StateDB) resetZKInputs() { + s.zki = nil + s.i = 0 +} // ProcessTxs process the given L1Txs & L2Txs applying the needed updates to // the StateDB depending on the transaction Type. Returns the common.ZKInputs // to generate the SnarkProof later used by the BatchBuilder, and if // cmpExitTree is set to true, returns common.ExitTreeLeaf that is later used // by the Synchronizer to update the HistoryDB. -func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.L2Tx) (*common.ZKInputs, []*common.ExitInfo, error) { +func (s *StateDB) ProcessTxs(cmpExitTree, cmpZKInputs bool, l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.PoolL2Tx) (*common.ZKInputs, []*common.ExitInfo, error) { var err error var exitTree *merkletree.MerkleTree exits := make(map[common.Idx]common.Account) + if s.zki != nil { + return nil, nil, errors.New("Expected StateDB.zki==nil, something went wrong ans is not empty") + } + defer s.resetZKInputs() + + nTx := len(l1usertxs) + len(l1coordinatortxs) + len(l2txs) + if nTx == 0 { + return nil, nil, nil // TBD if return an error in the case of no Txs to process + } + + if cmpZKInputs { + s.zki = common.NewZKInputs(nTx, 24, 32) // TODO this values will be parameters of the function + } + // TBD if ExitTree is only in memory or stored in disk, for the moment // only needed in memory exitTree, err = merkletree.NewMerkleTree(memory.NewMemoryStorage(), s.mt.MaxLevels()) @@ -30,7 +57,8 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co return nil, nil, err } - for _, tx := range l1coordinatortxs { + // assumption: l1usertx are sorted by L1Tx.Position + for _, tx := range l1usertxs { exitIdx, exitAccount, err := s.processL1Tx(exitTree, tx) if err != nil { return nil, nil, err @@ -38,8 +66,11 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co if exitIdx != nil && cmpExitTree { exits[*exitIdx] = *exitAccount } + if s.zki != nil { + s.i++ + } } - for _, tx := range l1usertxs { + for _, tx := range l1coordinatortxs { exitIdx, exitAccount, err := s.processL1Tx(exitTree, tx) if err != nil { return nil, nil, err @@ -47,6 +78,9 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co if exitIdx != nil && cmpExitTree { exits[*exitIdx] = *exitAccount } + if s.zki != nil { + s.i++ + } } for _, tx := range l2txs { exitIdx, exitAccount, err := s.processL2Tx(exitTree, tx) @@ -56,9 +90,12 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co if exitIdx != nil && cmpExitTree { exits[*exitIdx] = *exitAccount } + if s.zki != nil { + s.i++ + } } - if !cmpExitTree { + if !cmpExitTree && !cmpZKInputs { return nil, nil, nil } @@ -93,15 +130,78 @@ func (s *StateDB) ProcessTxs(cmpExitTree bool, l1usertxs, l1coordinatortxs []*co } exitInfos = append(exitInfos, ei) } + if !cmpZKInputs { + return nil, exitInfos, nil + } + + // compute last ZKInputs parameters + s.zki.OldLastIdx = (s.idx - 1).BigInt() + s.zki.OldStateRoot = s.mt.Root().BigInt() + s.zki.GlobalChainID = big.NewInt(0) // TODO, 0: ethereum, get this from config file + // zki.FeeIdxs = ? // TODO, this will be get from the config file + tokenIDs, err := s.getTokenIDsBigInt(l1usertxs, l1coordinatortxs, l2txs) + if err != nil { + return nil, nil, err + } + s.zki.FeePlanTokens = tokenIDs + + // s.zki.ISInitStateRootFee = s.mt.Root().BigInt() + // compute fees + + // once fees are computed // return exitInfos, so Synchronizer will be able to store it into // HistoryDB for the concrete BatchNum - return nil, exitInfos, nil + return s.zki, exitInfos, nil +} + +// getTokenIDsBigInt returns the list of TokenIDs in *big.Int format +func (s *StateDB) getTokenIDsBigInt(l1usertxs, l1coordinatortxs []*common.L1Tx, l2txs []*common.PoolL2Tx) ([]*big.Int, error) { + tokenIDs := make(map[common.TokenID]bool) + for i := 0; i < len(l1usertxs); i++ { + tokenIDs[l1usertxs[i].TokenID] = true + } + for i := 0; i < len(l1coordinatortxs); i++ { + tokenIDs[l1coordinatortxs[i].TokenID] = true + } + for i := 0; i < len(l2txs); i++ { + // as L2Tx does not have parameter TokenID, get it from the + // AccountsDB (in the StateDB) + acc, err := s.GetAccount(l2txs[i].ToIdx) + if err != nil { + return nil, err + } + tokenIDs[acc.TokenID] = true + } + var tBI []*big.Int + for t := range tokenIDs { + tBI = append(tBI, t.BigInt()) + } + return tBI, nil } // processL1Tx process the given L1Tx applying the needed updates to the // StateDB depending on the transaction Type. func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx) (*common.Idx, *common.Account, error) { + // ZKInputs + if s.zki != nil { + // Txs + // s.zki.TxCompressedData[s.i] = tx.TxCompressedData() // uncomment once L1Tx.TxCompressedData is ready + s.zki.FromIdx[s.i] = tx.FromIdx.BigInt() + s.zki.ToIdx[s.i] = tx.ToIdx.BigInt() + s.zki.OnChain[s.i] = big.NewInt(1) + + // L1Txs + s.zki.LoadAmountF[s.i] = tx.LoadAmount + s.zki.FromEthAddr[s.i] = common.EthAddrToBigInt(tx.FromEthAddr) + if tx.FromBJJ != nil { + s.zki.FromBJJCompressed[s.i] = common.BJJCompressedTo256BigInts(tx.FromBJJ.Compress()) + } + + // Intermediate States + s.zki.ISOnChain[s.i] = big.NewInt(1) + } + switch tx.Type { case common.TxTypeForceTransfer, common.TxTypeTransfer: // go to the MT account of sender and receiver, and update balance @@ -116,6 +216,11 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx) if err != nil { return nil, nil, err } + + if s.zki != nil { + s.zki.AuxFromIdx[s.i] = s.idx.BigInt() // last s.idx is the one used for creating the new account + s.zki.NewAccount[s.i] = big.NewInt(1) + } case common.TxTypeDeposit: // update balance of the MT account err := s.applyDeposit(tx, false) @@ -140,6 +245,11 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx) if err != nil { return nil, nil, err } + + if s.zki != nil { + s.zki.AuxFromIdx[s.i] = s.idx.BigInt() // last s.idx is the one used for creating the new account + s.zki.NewAccount[s.i] = big.NewInt(1) + } case common.TxTypeExit: // execute exit flow exitAccount, err := s.applyExit(exitTree, tx.Tx()) @@ -155,7 +265,48 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx) // processL2Tx process the given L2Tx applying the needed updates to // the StateDB depending on the transaction Type. -func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.L2Tx) (*common.Idx, *common.Account, error) { +func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.PoolL2Tx) (*common.Idx, *common.Account, error) { + // ZKInputs + if s.zki != nil { + // Txs + // s.zki.TxCompressedData[s.i] = tx.TxCompressedData() // uncomment once L1Tx.TxCompressedData is ready + // s.zki.TxCompressedDataV2[s.i] = tx.TxCompressedDataV2() // uncomment once L2Tx.TxCompressedDataV2 is ready + s.zki.FromIdx[s.i] = tx.FromIdx.BigInt() + s.zki.ToIdx[s.i] = tx.ToIdx.BigInt() + + // fill AuxToIdx if needed + if tx.ToIdx == common.Idx(0) { + // Idx not set in the Tx, get it from DB through ToEthAddr or ToBJJ + var idx common.Idx + if !bytes.Equal(tx.ToEthAddr.Bytes(), ffAddr.Bytes()) { + idx = s.getIdxByEthAddr(tx.ToEthAddr) + if idx == common.Idx(0) { + return nil, nil, fmt.Errorf("Idx can not be found for given tx.FromEthAddr") + } + } else { + idx = s.getIdxByBJJ(tx.ToBJJ) + if idx == common.Idx(0) { + return nil, nil, fmt.Errorf("Idx can not be found for given tx.FromBJJ") + } + } + s.zki.AuxToIdx[s.i] = idx.BigInt() + } + s.zki.ToBJJAy[s.i] = tx.ToBJJ.Y + s.zki.ToEthAddr[s.i] = common.EthAddrToBigInt(tx.ToEthAddr) + + s.zki.OnChain[s.i] = big.NewInt(0) + s.zki.NewAccount[s.i] = big.NewInt(0) + + // L2Txs + // s.zki.RqOffset[s.i] = // TODO + // s.zki.RqTxCompressedDataV2[s.i] = // TODO + // s.zki.RqToEthAddr[s.i] = common.EthAddrToBigInt(tx.RqToEthAddr) // TODO + // s.zki.RqToBJJAy[s.i] = tx.ToBJJ.Y // TODO + s.zki.S[s.i] = tx.Signature.S + s.zki.R8x[s.i] = tx.Signature.R8.X + s.zki.R8y[s.i] = tx.Signature.R8.Y + } + switch tx.Type { case common.TxTypeTransfer: // go to the MT account of sender and receiver, and update diff --git a/db/statedb/txprocessors_test.go b/db/statedb/txprocessors_test.go index 5a9dd06..3e95d92 100644 --- a/db/statedb/txprocessors_test.go +++ b/db/statedb/txprocessors_test.go @@ -1,6 +1,8 @@ package statedb import ( + "encoding/json" + "fmt" "io/ioutil" "strings" "testing" @@ -11,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +var debug = false + func TestProcessTxs(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.Nil(t, err) @@ -30,9 +34,9 @@ func TestProcessTxs(t *testing.T) { // iterate for each batch for i := 0; i < len(l1Txs); i++ { - l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs[i]) + // l2Txs := common.PoolL2TxsToL2Txs(poolL2Txs[i]) - _, _, err := sdb.ProcessTxs(true, l1Txs[i], coordinatorL1Txs[i], l2Txs) + _, _, err := sdb.ProcessTxs(true, true, l1Txs[i], coordinatorL1Txs[i], poolL2Txs[i]) require.Nil(t, err) } @@ -65,8 +69,8 @@ func TestProcessTxsBatchByBatch(t *testing.T) { assert.Equal(t, 7, len(poolL2Txs[2])) // use first batch - l2txs := common.PoolL2TxsToL2Txs(poolL2Txs[0]) - _, exitInfos, err := sdb.ProcessTxs(true, l1Txs[0], coordinatorL1Txs[0], l2txs) + // l2txs := common.PoolL2TxsToL2Txs(poolL2Txs[0]) + _, exitInfos, err := sdb.ProcessTxs(true, true, l1Txs[0], coordinatorL1Txs[0], poolL2Txs[0]) require.Nil(t, err) assert.Equal(t, 0, len(exitInfos)) acc, err := sdb.GetAccount(common.Idx(1)) @@ -74,8 +78,8 @@ func TestProcessTxsBatchByBatch(t *testing.T) { assert.Equal(t, "28", acc.Balance.String()) // use second batch - l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[1]) - _, exitInfos, err = sdb.ProcessTxs(true, l1Txs[1], coordinatorL1Txs[1], l2txs) + // l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[1]) + _, exitInfos, err = sdb.ProcessTxs(true, true, l1Txs[1], coordinatorL1Txs[1], poolL2Txs[1]) require.Nil(t, err) assert.Equal(t, 5, len(exitInfos)) acc, err = sdb.GetAccount(common.Idx(1)) @@ -83,11 +87,38 @@ func TestProcessTxsBatchByBatch(t *testing.T) { assert.Equal(t, "48", acc.Balance.String()) // use third batch - l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[2]) - _, exitInfos, err = sdb.ProcessTxs(true, l1Txs[2], coordinatorL1Txs[2], l2txs) + // l2txs = common.PoolL2TxsToL2Txs(poolL2Txs[2]) + _, exitInfos, err = sdb.ProcessTxs(true, true, l1Txs[2], coordinatorL1Txs[2], poolL2Txs[2]) require.Nil(t, err) assert.Equal(t, 1, len(exitInfos)) acc, err = sdb.GetAccount(common.Idx(1)) assert.Nil(t, err) assert.Equal(t, "23", acc.Balance.String()) } + +func TestZKInputsGeneration(t *testing.T) { + dir, err := ioutil.TempDir("", "tmpdb") + require.Nil(t, err) + + sdb, err := NewStateDB(dir, true, 32) + assert.Nil(t, err) + + // generate test transactions from test.SetTest0 code + parser := test.NewParser(strings.NewReader(test.SetTest0)) + instructions, err := parser.Parse() + assert.Nil(t, err) + + l1Txs, coordinatorL1Txs, poolL2Txs := test.GenerateTestTxs(t, instructions) + assert.Equal(t, 29, len(l1Txs[0])) + assert.Equal(t, 0, len(coordinatorL1Txs[0])) + assert.Equal(t, 21, len(poolL2Txs[0])) + + zki, _, err := sdb.ProcessTxs(false, true, l1Txs[0], coordinatorL1Txs[0], poolL2Txs[0]) + require.Nil(t, err) + + s, err := json.Marshal(zki) + require.Nil(t, err) + if debug { + fmt.Println(string(s)) + } +} diff --git a/db/statedb/utils.go b/db/statedb/utils.go new file mode 100644 index 0000000..ce46240 --- /dev/null +++ b/db/statedb/utils.go @@ -0,0 +1,17 @@ +package statedb + +import ( + ethCommon "github.com/ethereum/go-ethereum/common" + "github.com/hermeznetwork/hermez-node/common" + "github.com/iden3/go-iden3-crypto/babyjub" +) + +// TODO +func (s *StateDB) getIdxByEthAddr(addr ethCommon.Address) common.Idx { + return common.Idx(0) +} + +// TODO +func (s *StateDB) getIdxByBJJ(pk *babyjub.PublicKey) common.Idx { + return common.Idx(0) +}