Add methods for ZKInputs IntermStates generation

- Add L1Tx TxCompressedData method
- Add PoolL2Tx TxCompressedDataV2 method
- Update ProcessTxs logic
- Add ZKInputs Intermediate States & Fee parameters calculation
This commit is contained in:
arnaucube
2020-11-18 22:16:33 +01:00
parent bf88eb60b8
commit d3a38a3ee1
10 changed files with 388 additions and 76 deletions

View File

@@ -2,6 +2,7 @@ package statedb
import (
"errors"
"fmt"
"io/ioutil"
"math/big"
"os"
@@ -72,6 +73,10 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
}
defer s.resetZKInputs()
if len(coordIdxs) >= int(ptc.MaxFeeTx) {
return nil, fmt.Errorf("CoordIdxs (%d) length must be smaller than MaxFeeTx (%d)", len(coordIdxs), ptc.MaxFeeTx)
}
s.accumulatedFees = make(map[common.Idx]*big.Int)
nTx := len(l1usertxs) + len(l1coordinatortxs) + len(l2txs)
@@ -94,7 +99,7 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
}
// TBD if ExitTree is only in memory or stored in disk, for the moment
// only needed in memory
// is only needed in memory
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
tmpDir, err := ioutil.TempDir("", "hermez-statedb-exittree")
if err != nil {
@@ -122,17 +127,6 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
if err != nil {
return nil, err
}
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
if exitIdx != nil && exitTree != nil {
exits[s.i] = processedExit{
exit: true,
newExit: newExit,
idx: *exitIdx,
acc: *exitAccount,
}
}
s.i++
}
if s.typ == TypeSynchronizer && createdAccount != nil {
createdAccounts = append(createdAccounts, *createdAccount)
}
@@ -143,6 +137,22 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
return nil, err
}
s.zki.Metadata.L1TxsData = append(s.zki.Metadata.L1TxsData, l1TxData)
if s.i < nTx-1 {
s.zki.ISOutIdx[s.i] = s.idx.BigInt()
s.zki.ISStateRoot[s.i] = s.mt.Root().BigInt()
}
}
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
if exitIdx != nil && exitTree != nil {
exits[s.i] = processedExit{
exit: true,
newExit: newExit,
idx: *exitIdx,
acc: *exitAccount,
}
}
s.i++
}
}
@@ -164,6 +174,12 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
return nil, err
}
s.zki.Metadata.L1TxsData = append(s.zki.Metadata.L1TxsData, l1TxData)
if s.i < nTx-1 {
s.zki.ISOutIdx[s.i] = s.idx.BigInt()
s.zki.ISStateRoot[s.i] = s.mt.Root().BigInt()
}
s.i++
}
}
@@ -179,20 +195,49 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
if err != nil {
return nil, err
}
// collectedFees will contain the amount of fee collected for each
// TokenID
var collectedFees map[common.TokenID]*big.Int
if s.typ == TypeSynchronizer {
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
collectedFees = make(map[common.TokenID]*big.Int)
for tokenID := range coordIdxsMap {
collectedFees[tokenID] = big.NewInt(0)
}
}
if s.zki != nil {
// get the feePlanTokens
feePlanTokens, err := s.getFeePlanTokens(coordIdxs, l2txs)
if err != nil {
log.Error(err)
return nil, err
}
copy(s.zki.FeePlanTokens, feePlanTokens)
}
// Process L2Txs
for i := 0; i < len(l2txs); i++ {
exitIdx, exitAccount, newExit, err := s.processL2Tx(coordIdxsMap, collectedFees, exitTree, &l2txs[i])
if err != nil {
return nil, err
}
if s.zki != nil {
l2TxData, err := l2txs[i].L2Tx().Bytes(s.zki.Metadata.NLevels)
if err != nil {
return nil, err
}
s.zki.Metadata.L2TxsData = append(s.zki.Metadata.L2TxsData, l2TxData)
if s.i < nTx-1 {
// Intermediate States
s.zki.ISOutIdx[s.i] = s.idx.BigInt()
s.zki.ISStateRoot[s.i] = s.mt.Root().BigInt()
s.zki.ISAccFeeOut[s.i] = formatAccumulatedFees(collectedFees, s.zki.FeePlanTokens)
}
if s.i == nTx-1 {
s.zki.ISFinalAccFee = formatAccumulatedFees(collectedFees, s.zki.FeePlanTokens)
}
}
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
if exitIdx != nil && exitTree != nil {
exits[s.i] = processedExit{
@@ -204,13 +249,11 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
}
s.i++
}
if s.zki != nil {
l2TxData, err := l2txs[i].L2Tx().Bytes(s.zki.Metadata.NLevels)
if err != nil {
return nil, err
}
s.zki.Metadata.L2TxsData = append(s.zki.Metadata.L2TxsData, l2TxData)
}
}
if s.zki != nil {
// before computing the Fees txs, set the ISInitStateRootFee
s.zki.ISInitStateRootFee = s.mt.Root().BigInt()
}
// distribute the AccumulatedFees from the processed L2Txs into the
@@ -242,6 +285,8 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
// add Coord Idx to ZKInputs.FeeTxsData
s.zki.FeeIdxs[iFee] = idx.BigInt()
s.zki.ISStateRootFee[iFee] = s.mt.Root().BigInt()
}
iFee++
}
@@ -293,6 +338,10 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
}
s.zki.OldKey2[i] = p.OldKey.BigInt()
s.zki.OldValue2[i] = p.OldValue.BigInt()
if i < nTx-1 {
s.zki.ISExitRoot[i] = exitTree.Root().BigInt()
}
}
}
if s.typ == TypeSynchronizer {
@@ -310,17 +359,9 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
// compute last ZKInputs parameters
s.zki.GlobalChainID = big.NewInt(0) // TODO, 0: ethereum, this will be get from config file
// zki.FeeIdxs = ? // TODO, this will be get from the config file
tokenIDs, err := s.getTokenIDsBigInt(l1usertxs, l1coordinatortxs, l2txs)
if err != nil {
log.Error(err)
return nil, err
}
s.zki.FeePlanTokens = tokenIDs
s.zki.Metadata.NewStateRootRaw = s.mt.Root()
s.zki.Metadata.NewExitRootRaw = exitTree.Root()
// s.zki.ISInitStateRootFee = s.mt.Root().BigInt()
// return ZKInputs as the BatchBuilder will return it to forge the Batch
return &ProcessTxOutput{
ZKInputs: s.zki,
@@ -331,15 +372,22 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
}, nil
}
// getTokenIDsBigInt returns the list of TokenIDs in *big.Int format
func (s *StateDB) getTokenIDsBigInt(l1usertxs, l1coordinatortxs []common.L1Tx, l2txs []common.PoolL2Tx) ([]*big.Int, error) {
// getFeePlanTokens returns an array of *big.Int containing a list of tokenIDs
// corresponding to the given CoordIdxs and the processed L2Txs
func (s *StateDB) getFeePlanTokens(coordIdxs []common.Idx, l2txs []common.PoolL2Tx) ([]*big.Int, error) {
// get Coordinator TokenIDs corresponding to the Idxs where the Fees
// will be sent
coordTokenIDs := make(map[common.TokenID]bool)
for i := 0; i < len(coordIdxs); i++ {
acc, err := s.GetAccount(coordIdxs[i])
if err != nil {
log.Errorf("could not get account to determine TokenID of CoordIdx %d not found: %s", coordIdxs[i], err.Error())
return nil, err
}
coordTokenIDs[acc.TokenID] = true
}
tokenIDs := make(map[common.TokenID]bool)
for i := 0; i < len(l1usertxs); i++ {
tokenIDs[l1usertxs[i].TokenID] = true
}
for i := 0; i < len(l1coordinatortxs); i++ {
tokenIDs[l1coordinatortxs[i].TokenID] = true
}
for i := 0; i < len(l2txs); i++ {
// as L2Tx does not have parameter TokenID, get it from the
// AccountsDB (in the StateDB)
@@ -348,7 +396,9 @@ func (s *StateDB) getTokenIDsBigInt(l1usertxs, l1coordinatortxs []common.L1Tx, l
log.Errorf("could not get account to determine TokenID of L2Tx: FromIdx %d not found: %s", l2txs[i].FromIdx, err.Error())
return nil, err
}
tokenIDs[acc.TokenID] = true
if _, ok := coordTokenIDs[acc.TokenID]; ok {
tokenIDs[acc.TokenID] = true
}
}
var tBI []*big.Int
for t := range tokenIDs {
@@ -368,7 +418,12 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
// ZKInputs
if s.zki != nil {
// Txs
// s.zki.TxCompressedData[s.i] = tx.TxCompressedData() // uncomment once L1Tx.TxCompressedData is ready
var err error
s.zki.TxCompressedData[s.i], err = tx.TxCompressedData()
if err != nil {
log.Error(err)
return nil, nil, false, nil, err
}
s.zki.FromIdx[s.i] = tx.FromIdx.BigInt()
s.zki.ToIdx[s.i] = tx.ToIdx.BigInt()
s.zki.OnChain[s.i] = big.NewInt(1)
@@ -408,11 +463,6 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
// TODO applyCreateAccount will return the created account,
// which in the case type==TypeSynchronizer will be added to an
// array of created accounts that will be returned
if s.zki != nil {
s.zki.AuxFromIdx[s.i] = s.idx.BigInt() // last s.idx is the one used for creating the new account
s.zki.NewAccount[s.i] = big.NewInt(1)
}
case common.TxTypeDeposit:
// update balance of the MT account
err := s.applyDeposit(tx, false)
@@ -436,11 +486,6 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
log.Error(err)
return nil, nil, false, nil, err
}
if s.zki != nil {
s.zki.AuxFromIdx[s.i] = s.idx.BigInt() // last s.idx is the one used for creating the new account
s.zki.NewAccount[s.i] = big.NewInt(1)
}
case common.TxTypeForceExit:
// execute exit flow
// coordIdxsMap is 'nil', as at L1Txs there is no L2 fees
@@ -485,8 +530,14 @@ func (s *StateDB) processL2Tx(coordIdxsMap map[common.TokenID]common.Idx, collec
// ZKInputs
if s.zki != nil {
// Txs
// s.zki.TxCompressedData[s.i] = tx.TxCompressedData() // uncomment once L1Tx.TxCompressedData is ready
// s.zki.TxCompressedDataV2[s.i] = tx.TxCompressedDataV2() // uncomment once L2Tx.TxCompressedDataV2 is ready
s.zki.TxCompressedData[s.i], err = tx.TxCompressedData()
if err != nil {
return nil, nil, false, err
}
s.zki.TxCompressedDataV2[s.i], err = tx.TxCompressedDataV2()
if err != nil {
return nil, nil, false, err
}
s.zki.FromIdx[s.i] = tx.FromIdx.BigInt()
s.zki.ToIdx[s.i] = tx.ToIdx.BigInt()
@@ -587,6 +638,14 @@ func (s *StateDB) applyCreateAccount(tx *common.L1Tx) error {
s.zki.OldValue1[s.i] = p.OldValue.BigInt()
s.zki.Metadata.NewLastIdxRaw = s.idx + 1
s.zki.AuxFromIdx[s.i] = common.Idx(s.idx + 1).BigInt()
s.zki.NewAccount[s.i] = big.NewInt(1)
if s.i < len(s.zki.ISOnChain) { // len(s.zki.ISOnChain) == nTx
// intermediate states
s.zki.ISOnChain[s.i] = big.NewInt(1)
}
}
s.idx = s.idx + 1
@@ -695,7 +754,7 @@ func (s *StateDB) applyTransfer(coordIdxsMap map[common.TokenID]common.Idx, coll
accumulated := s.accumulatedFees[accCoord.Idx]
accumulated.Add(accumulated, fee)
if s.typ == TypeSynchronizer {
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
collected := collectedFees[accCoord.TokenID]
collected.Add(collected, fee)
}
@@ -800,6 +859,12 @@ func (s *StateDB) applyCreateAccountDepositTransfer(tx *common.L1Tx) error {
s.zki.OldValue1[s.i] = p.OldValue.BigInt()
s.zki.Metadata.NewLastIdxRaw = s.idx + 1
s.zki.AuxFromIdx[s.i] = common.Idx(s.idx + 1).BigInt()
s.zki.NewAccount[s.i] = big.NewInt(1)
// intermediate states
s.zki.ISOnChain[s.i] = big.NewInt(1)
}
// update receiver account in localStateDB
@@ -854,7 +919,7 @@ func (s *StateDB) applyExit(coordIdxsMap map[common.TokenID]common.Idx, collecte
accumulated := s.accumulatedFees[accCoord.Idx]
accumulated.Add(accumulated, fee)
if s.typ == TypeSynchronizer {
if s.typ == TypeSynchronizer || s.typ == TypeBatchBuilder {
collected := collectedFees[accCoord.TokenID]
collected.Add(collected, fee)
}

View File

@@ -477,6 +477,80 @@ func TestProcessTxsRootTestVectors(t *testing.T) {
assert.Equal(t, "9827704113668630072730115158977131501210702363656902211840117643154933433410", sdb.mt.Root().BigInt().String())
}
func TestCircomTest(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb")
require.Nil(t, err)
defer assert.Nil(t, os.RemoveAll(dir))
sdb, err := NewStateDB(dir, TypeBatchBuilder, 8)
assert.Nil(t, err)
// same values than in the js test
bjj0, err := common.BJJFromStringWithChecksum("21b0a1688b37f77b1d1d5539ec3b826db5ac78b2513f574a04c50a7d4f8246d7")
assert.Nil(t, err)
l1Txs := []common.L1Tx{
{
FromIdx: 0,
// LoadAmount: big.NewInt(10400),
LoadAmount: big.NewInt(16000000),
Amount: big.NewInt(0),
TokenID: 1,
FromBJJ: bjj0,
FromEthAddr: ethCommon.HexToAddress("0x7e5f4552091a69125d5dfcb7b8c2659029395bdf"),
ToIdx: 0,
Type: common.TxTypeCreateAccountDeposit,
},
}
l2Txs := []common.PoolL2Tx{
{
FromIdx: 256,
ToIdx: 256,
TokenID: 1,
Amount: big.NewInt(1000),
Nonce: 0,
Fee: 126,
Type: common.TxTypeTransfer,
},
}
ptc := ProcessTxsConfig{
NLevels: 8,
MaxFeeTx: 2,
MaxTx: 5,
MaxL1Tx: 2,
}
ptOut, err := sdb.ProcessTxs(ptc, nil, l1Txs, nil, l2Txs)
require.Nil(t, err)
// check expected account keys values from tx inputs
acc, err := sdb.GetAccount(common.Idx(256))
require.Nil(t, err)
assert.Equal(t, "d746824f7d0ac5044a573f51b278acb56d823bec39551d1d7bf7378b68a1b021", acc.PublicKey.Compress().String())
assert.Equal(t, "0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf", acc.EthAddr.Hex())
// check that there no exist more accounts
_, err = sdb.GetAccount(common.Idx(257))
require.NotNil(t, err)
ptOut.ZKInputs.FeeIdxs[0] = common.Idx(256).BigInt()
s, err := json.Marshal(ptOut.ZKInputs)
require.Nil(t, err)
debug := false
if debug {
fmt.Println("\nCopy&Paste into js circom test:\n let zkInput = JSON.parse(`" + string(s) + "`);")
h, err := ptOut.ZKInputs.HashGlobalData()
require.Nil(t, err)
fmt.Printf(`
const output={
hashGlobalInputs: "%s",
};
await circuit.assertOut(w, output);
`, h.String())
fmt.Println("")
}
}
func TestZKInputsHashTestVector0(t *testing.T) {
dir, err := ioutil.TempDir("", "tmpdb")
require.Nil(t, err)

View File

@@ -160,3 +160,19 @@ func BJJCompressedTo256BigInts(pkComp babyjub.PublicKeyComp) [256]*big.Int {
return r
}
// formatAccumulatedFees returns an array of [nFeeAccounts]*big.Int containing
// the balance of each FeeAccount, taken from the 'collectedFees' map, in the
// order of the 'orderTokenIDs'
func formatAccumulatedFees(collectedFees map[common.TokenID]*big.Int, orderTokenIDs []*big.Int) []*big.Int {
accFeeOut := make([]*big.Int, len(orderTokenIDs))
for i := 0; i < len(orderTokenIDs); i++ {
tokenID := common.TokenIDFromBigInt(orderTokenIDs[i])
if _, ok := collectedFees[tokenID]; ok {
accFeeOut[i] = collectedFees[tokenID]
} else {
accFeeOut[i] = big.NewInt(0)
}
}
return accFeeOut
}