diff --git a/txprocessor/txprocessor.go b/txprocessor/txprocessor.go index 38de86a..7e5fcee 100644 --- a/txprocessor/txprocessor.go +++ b/txprocessor/txprocessor.go @@ -233,6 +233,22 @@ func (tp *TxProcessor) ProcessTxs(coordIdxs []common.Idx, l1usertxs, l1coordinat } } + // remove repeated CoordIdxs that are for the same TokenID (use the + // first occurrence) + usedCoordTokenIDs := make(map[common.TokenID]bool) + var filteredCoordIdxs []common.Idx + for i := 0; i < len(coordIdxs); i++ { + accCoord, err := tp.s.GetAccount(coordIdxs[i]) + if err != nil { + return nil, tracerr.Wrap(err) + } + if !usedCoordTokenIDs[accCoord.TokenID] { + usedCoordTokenIDs[accCoord.TokenID] = true + filteredCoordIdxs = append(filteredCoordIdxs, coordIdxs[i]) + } + } + coordIdxs = filteredCoordIdxs + tp.AccumulatedFees = make(map[common.Idx]*big.Int) for _, idx := range coordIdxs { tp.AccumulatedFees[idx] = big.NewInt(0) diff --git a/txprocessor/txprocessor_test.go b/txprocessor/txprocessor_test.go index 1386849..9d9ec48 100644 --- a/txprocessor/txprocessor_test.go +++ b/txprocessor/txprocessor_test.go @@ -25,6 +25,12 @@ func checkBalance(t *testing.T, tc *til.Context, sdb *statedb.StateDB, username assert.Equal(t, expected, acc.Balance.String()) } +func checkBalanceByIdx(t *testing.T, sdb *statedb.StateDB, idx common.Idx, expected string) { + acc, err := sdb.GetAccount(idx) + require.NoError(t, err) + assert.Equal(t, expected, acc.Balance.String()) +} + func TestComputeEffectiveAmounts(t *testing.T) { dir, err := ioutil.TempDir("", "tmpdb") require.NoError(t, err) @@ -688,3 +694,98 @@ func TestCreateAccountDepositMaxValue(t *testing.T) { require.NoError(t, err) assert.Equal(t, daMax1BI, acc.Balance) } + +func initTestMultipleCoordIdxForTokenID(t *testing.T) (*TxProcessor, *til.Context, []common.BlockData) { + dir, err := ioutil.TempDir("", "tmpdb") + require.NoError(t, err) + defer assert.NoError(t, os.RemoveAll(dir)) + + sdb, err := statedb.NewStateDB(dir, 128, statedb.TypeBatchBuilder, 32) + assert.NoError(t, err) + + chainID := uint16(1) + + // generate test transactions from test.SetBlockchain0 code + tc := til.NewContext(chainID, common.RollupConstMaxL1UserTx) + + set := ` + Type: Blockchain + + CreateAccountDeposit(0) A: 200 + + > batchL1 // freeze L1User{1} + + CreateAccountCoordinator(0) Coord + CreateAccountCoordinator(0) B + + Transfer(0) A-B: 100 (126) + + > batchL1 // forge L1User{1}, forge L1Coord{4}, forge L2{2} + > block + ` + blocks, err := tc.GenerateBlocks(set) + require.NoError(t, err) + + config := Config{ + NLevels: 32, + MaxFeeTx: 64, + MaxTx: 512, + MaxL1Tx: 16, + ChainID: chainID, + } + tp := NewTxProcessor(sdb, config) + // batch1 + _, err = tp.ProcessTxs(nil, nil, nil, nil) // to simulate the first batch from the Til set + require.NoError(t, err) + + return tp, tc, blocks +} + +func TestMultipleCoordIdxForTokenID(t *testing.T) { + // Check that ProcessTxs always uses the first occurrence of the + // CoordIdx for each TokenID + + coordIdxs := []common.Idx{257, 257, 257} + tp, tc, blocks := initTestMultipleCoordIdxForTokenID(t) + l1UserTxs := til.L1TxsToCommonL1Txs(tc.Queues[*blocks[0].Rollup.Batches[1].Batch.ForgeL1TxsNum]) + l1CoordTxs := blocks[0].Rollup.Batches[1].L1CoordinatorTxs + l1CoordTxs = append(l1CoordTxs, l1CoordTxs[0]) // duplicate the CoordAccount for TokenID=0 + l2Txs := common.L2TxsToPoolL2Txs(blocks[0].Rollup.Batches[1].L2Txs) + _, err := tp.ProcessTxs(coordIdxs, l1UserTxs, l1CoordTxs, l2Txs) + require.NoError(t, err) + + checkBalanceByIdx(t, tp.s, 256, "90") // A + checkBalanceByIdx(t, tp.s, 257, "10") // Coord0 + checkBalanceByIdx(t, tp.s, 258, "100") // B + checkBalanceByIdx(t, tp.s, 259, "0") // Coord0 + + // reset StateDB values + coordIdxs = []common.Idx{259, 257} + tp, tc, blocks = initTestMultipleCoordIdxForTokenID(t) + l1UserTxs = til.L1TxsToCommonL1Txs(tc.Queues[*blocks[0].Rollup.Batches[1].Batch.ForgeL1TxsNum]) + l1CoordTxs = blocks[0].Rollup.Batches[1].L1CoordinatorTxs + l1CoordTxs = append(l1CoordTxs, l1CoordTxs[0]) // duplicate the CoordAccount for TokenID=0 + l2Txs = common.L2TxsToPoolL2Txs(blocks[0].Rollup.Batches[1].L2Txs) + _, err = tp.ProcessTxs(coordIdxs, l1UserTxs, l1CoordTxs, l2Txs) + require.NoError(t, err) + + checkBalanceByIdx(t, tp.s, 256, "90") // A + checkBalanceByIdx(t, tp.s, 257, "0") // Coord0 + checkBalanceByIdx(t, tp.s, 258, "100") // B + checkBalanceByIdx(t, tp.s, 259, "10") // Coord0 + + // reset StateDB values + coordIdxs = []common.Idx{257, 259} + tp, tc, blocks = initTestMultipleCoordIdxForTokenID(t) + l1UserTxs = til.L1TxsToCommonL1Txs(tc.Queues[*blocks[0].Rollup.Batches[1].Batch.ForgeL1TxsNum]) + l1CoordTxs = blocks[0].Rollup.Batches[1].L1CoordinatorTxs + l1CoordTxs = append(l1CoordTxs, l1CoordTxs[0]) // duplicate the CoordAccount for TokenID=0 + l2Txs = common.L2TxsToPoolL2Txs(blocks[0].Rollup.Batches[1].L2Txs) + _, err = tp.ProcessTxs(coordIdxs, l1UserTxs, l1CoordTxs, l2Txs) + require.NoError(t, err) + + checkBalanceByIdx(t, tp.s, 256, "90") // A + checkBalanceByIdx(t, tp.s, 257, "10") // Coord0 + checkBalanceByIdx(t, tp.s, 258, "100") // B + checkBalanceByIdx(t, tp.s, 259, "0") // Coord0 +}