diff --git a/db/statedb/statedb.go b/db/statedb/statedb.go index d1c451c..f0ce3dd 100644 --- a/db/statedb/statedb.go +++ b/db/statedb/statedb.go @@ -372,7 +372,7 @@ func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) (*merkl return cpp, err } // store idx by EthAddr & BJJ - err = s.setIdxByEthAddrBJJ(idx, account.EthAddr, account.PublicKey) + err = s.setIdxByEthAddrBJJ(idx, account.EthAddr, account.PublicKey, account.TokenID) return cpp, err } diff --git a/db/statedb/txprocessors.go b/db/statedb/txprocessors.go index cc2f182..efcd324 100644 --- a/db/statedb/txprocessors.go +++ b/db/statedb/txprocessors.go @@ -378,7 +378,7 @@ func (s *StateDB) processL2Tx(coordIdxsMap map[common.TokenID]common.Idx, exitTr // if tx.ToIdx==0, get toIdx by ToEthAddr or ToBJJ if tx.ToIdx == common.Idx(0) && tx.AuxToIdx == common.Idx(0) { // case when tx.Type== common.TxTypeTransferToEthAddr or common.TxTypeTransferToBJJ - tx.AuxToIdx, err = s.GetIdxByEthAddrBJJ(tx.ToEthAddr, tx.ToBJJ) + tx.AuxToIdx, err = s.GetIdxByEthAddrBJJ(tx.ToEthAddr, tx.ToBJJ, tx.TokenID) if err != nil { log.Error(err) return nil, nil, false, err diff --git a/db/statedb/txprocessors_test.go b/db/statedb/txprocessors_test.go index 10498fc..70869c0 100644 --- a/db/statedb/txprocessors_test.go +++ b/db/statedb/txprocessors_test.go @@ -44,6 +44,8 @@ func TestProcessTxsSynchronizer(t *testing.T) { // Idx of user 'A' idxA1 := tc.Users["A"].Accounts[common.TokenID(1)].Idx + // Process the 1st batch, which contains the L1CoordinatorTxs necessary + // to create the Coordinator accounts to receive the fees log.Debug("1st batch, 1st block, only L1CoordinatorTxs") ptOut, err := sdb.ProcessTxs(nil, nil, blocks[0].Batches[0].L1CoordinatorTxs, nil) require.Nil(t, err) diff --git a/db/statedb/utils.go b/db/statedb/utils.go index 95bdddf..360ab9c 100644 --- a/db/statedb/utils.go +++ b/db/statedb/utils.go @@ -12,11 +12,18 @@ import ( "github.com/iden3/go-merkletree" ) -func concatEthAddrBJJ(addr ethCommon.Address, pk *babyjub.PublicKey) []byte { +func concatEthAddrTokenID(addr ethCommon.Address, tokenID common.TokenID) []byte { + var b []byte + b = append(b, addr.Bytes()...) + b = append(b[:], tokenID.Bytes()[:]...) + return b +} +func concatEthAddrBJJTokenID(addr ethCommon.Address, pk *babyjub.PublicKey, tokenID common.TokenID) []byte { pkComp := pk.Compress() var b []byte b = append(b, addr.Bytes()...) b = append(b[:], pkComp[:]...) + b = append(b[:], tokenID.Bytes()[:]...) return b } @@ -25,8 +32,8 @@ func concatEthAddrBJJ(addr ethCommon.Address, pk *babyjub.PublicKey) []byte { // - key: EthAddr & BabyJubJub PublicKey Compressed, value: idx // If Idx already exist for the given EthAddr & BJJ, the remaining Idx will be // always the smallest one. -func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk *babyjub.PublicKey) error { - oldIdx, err := s.GetIdxByEthAddrBJJ(addr, pk) +func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk *babyjub.PublicKey, tokenID common.TokenID) error { + oldIdx, err := s.GetIdxByEthAddrBJJ(addr, pk, tokenID) if err == nil { // EthAddr & BJJ already have an Idx // check which Idx is smaller @@ -46,7 +53,7 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk if err != nil { return err } - k := concatEthAddrBJJ(addr, pk) + k := concatEthAddrBJJTokenID(addr, pk, tokenID) // store Addr&BJJ-idx idxBytes, err := idx.Bytes() if err != nil { @@ -57,7 +64,8 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk return err } // store Addr-idx - err = tx.Put(append(PrefixKeyAddr, addr.Bytes()...), idxBytes[:]) + k = concatEthAddrTokenID(addr, tokenID) + err = tx.Put(append(PrefixKeyAddr, k...), idxBytes[:]) if err != nil { return err } @@ -71,8 +79,9 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk // GetIdxByEthAddr returns the smallest Idx in the StateDB for the given // Ethereum Address. Will return common.Idx(0) and error in case that Idx is // not found in the StateDB. -func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address) (common.Idx, error) { - b, err := s.db.Get(append(PrefixKeyAddr, addr.Bytes()...)) +func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address, tokenID common.TokenID) (common.Idx, error) { + k := concatEthAddrTokenID(addr, tokenID) + b, err := s.db.Get(append(PrefixKeyAddr, k...)) if err != nil { return common.Idx(0), ErrToIdxNotFound } @@ -88,13 +97,13 @@ func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address) (common.Idx, error) { // address, it's ignored in the query. If `pk` is nil, it's ignored in the // query. Will return common.Idx(0) and error in case that Idx is not found in // the StateDB. -func (s *StateDB) GetIdxByEthAddrBJJ(addr ethCommon.Address, pk *babyjub.PublicKey) (common.Idx, error) { +func (s *StateDB) GetIdxByEthAddrBJJ(addr ethCommon.Address, pk *babyjub.PublicKey, tokenID common.TokenID) (common.Idx, error) { if !bytes.Equal(addr.Bytes(), common.EmptyAddr.Bytes()) && pk == nil { // case ToEthAddr!=0 && ToBJJ=0 - return s.GetIdxByEthAddr(addr) + return s.GetIdxByEthAddr(addr, tokenID) } else if !bytes.Equal(addr.Bytes(), common.EmptyAddr.Bytes()) && pk != nil { // case ToEthAddr!=0 && ToBJJ!=0 - k := concatEthAddrBJJ(addr, pk) + k := concatEthAddrBJJTokenID(addr, pk, tokenID) b, err := s.db.Get(append(PrefixKeyAddrBJJ, k...)) if err != nil { return common.Idx(0), ErrToIdxNotFound diff --git a/db/statedb/utils_test.go b/db/statedb/utils_test.go index b608acf..84a4ab8 100644 --- a/db/statedb/utils_test.go +++ b/db/statedb/utils_test.go @@ -33,51 +33,60 @@ func TestGetIdx(t *testing.T) { idx2 := common.Idx(12345) idx3 := common.Idx(1233) + tokenID0 := common.TokenID(0) + tokenID1 := common.TokenID(1) + // store the keys for idx by Addr & BJJ - err = sdb.setIdxByEthAddrBJJ(idx, addr, pk) + err = sdb.setIdxByEthAddrBJJ(idx, addr, pk, tokenID0) require.Nil(t, err) - idxR, err := sdb.GetIdxByEthAddrBJJ(addr, pk) + idxR, err := sdb.GetIdxByEthAddrBJJ(addr, pk, tokenID0) assert.Nil(t, err) assert.Equal(t, idx, idxR) // expect error when getting only by EthAddr, as value does not exist // in the db for only EthAddr - _, err = sdb.GetIdxByEthAddr(addr) + _, err = sdb.GetIdxByEthAddr(addr, tokenID0) assert.Nil(t, err) - _, err = sdb.GetIdxByEthAddr(addr2) + _, err = sdb.GetIdxByEthAddr(addr2, tokenID0) + assert.NotNil(t, err) + // expect error when getting by EthAddr and BJJ, but for another TokenID + _, err = sdb.GetIdxByEthAddrBJJ(addr, pk, tokenID1) assert.NotNil(t, err) // expect to fail - idxR, err = sdb.GetIdxByEthAddrBJJ(addr2, pk) + idxR, err = sdb.GetIdxByEthAddrBJJ(addr2, pk, tokenID0) assert.NotNil(t, err) assert.Equal(t, common.Idx(0), idxR) - idxR, err = sdb.GetIdxByEthAddrBJJ(addr, pk2) + idxR, err = sdb.GetIdxByEthAddrBJJ(addr, pk2, tokenID0) assert.NotNil(t, err) assert.Equal(t, common.Idx(0), idxR) // try to store bigger idx, will not affect as already exist a smaller // Idx for that Addr & BJJ - err = sdb.setIdxByEthAddrBJJ(idx2, addr, pk) + err = sdb.setIdxByEthAddrBJJ(idx2, addr, pk, tokenID0) assert.Nil(t, err) // store smaller idx - err = sdb.setIdxByEthAddrBJJ(idx3, addr, pk) + err = sdb.setIdxByEthAddrBJJ(idx3, addr, pk, tokenID0) assert.Nil(t, err) - idxR, err = sdb.GetIdxByEthAddrBJJ(addr, pk) + idxR, err = sdb.GetIdxByEthAddrBJJ(addr, pk, tokenID0) assert.Nil(t, err) assert.Equal(t, idx3, idxR) // by EthAddr should work - idxR, err = sdb.GetIdxByEthAddr(addr) + idxR, err = sdb.GetIdxByEthAddr(addr, tokenID0) assert.Nil(t, err) assert.Equal(t, idx3, idxR) // expect error when trying to get Idx by addr2 & pk2 - idxR, err = sdb.GetIdxByEthAddrBJJ(addr2, pk2) + idxR, err = sdb.GetIdxByEthAddrBJJ(addr2, pk2, tokenID0) assert.NotNil(t, err) assert.Equal(t, ErrToIdxNotFound, err) assert.Equal(t, common.Idx(0), idxR) + // expect error when trying to get Idx by addr with not used TokenID + _, err = sdb.GetIdxByEthAddr(addr, tokenID1) + assert.NotNil(t, err) } func TestBJJCompressedTo256BigInt(t *testing.T) {