Browse Source

Update TxProcessor L2Tx logic for case ToIdx==0

feature/sql-semaphore1
arnaucube 4 years ago
parent
commit
11c45cfc2f
3 changed files with 48 additions and 44 deletions
  1. +28
    -28
      db/statedb/txprocessors.go
  2. +19
    -14
      db/statedb/utils.go
  3. +1
    -2
      db/statedb/utils_test.go

+ 28
- 28
db/statedb/txprocessors.go

@ -1,7 +1,6 @@
package statedb package statedb
import ( import (
"bytes"
"errors" "errors"
"math/big" "math/big"
@ -245,7 +244,7 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
case common.TxTypeForceTransfer, common.TxTypeTransfer: case common.TxTypeForceTransfer, common.TxTypeTransfer:
// go to the MT account of sender and receiver, and update balance // go to the MT account of sender and receiver, and update balance
// & nonce // & nonce
err := s.applyTransfer(tx.Tx())
err := s.applyTransfer(tx.Tx(), 0) // 0 for the parameter toIdx, as at L1Tx ToIdx can only be 0 in the Deposit type case.
if err != nil { if err != nil {
return nil, nil, false, err return nil, nil, false, err
} }
@ -304,6 +303,16 @@ func (s *StateDB) processL1Tx(exitTree *merkletree.MerkleTree, tx *common.L1Tx)
// the Exit created a new Leaf in the ExitTree. // the Exit created a new Leaf in the ExitTree.
func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.PoolL2Tx) (*common.Idx, *common.Account, bool, error) { func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.PoolL2Tx) (*common.Idx, *common.Account, bool, error) {
var err error var err error
var auxToIdx common.Idx
// if tx.ToIdx==0, get toIdx by ToEthAddr or ToBJJ
if tx.ToIdx == common.Idx(0) {
auxToIdx, err = s.GetIdxByEthAddrBJJ(tx.ToEthAddr, tx.ToBJJ)
if err != nil {
log.Error(err)
return nil, nil, false, err
}
}
// ZKInputs // ZKInputs
if s.zki != nil { if s.zki != nil {
// Txs // Txs
@ -314,28 +323,12 @@ func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.PoolL2
// fill AuxToIdx if needed // fill AuxToIdx if needed
if tx.ToIdx == common.Idx(0) { if tx.ToIdx == common.Idx(0) {
var idx common.Idx
if !bytes.Equal(tx.ToEthAddr.Bytes(), common.EmptyAddr.Bytes()) && tx.ToBJJ == nil {
// case ToEthAddr!=0 && ToBJJ=0
idx, err = s.GetIdxByEthAddr(tx.ToEthAddr)
if err != nil {
log.Error(err)
return nil, nil, false, ErrToIdxNotFound
}
} else if !bytes.Equal(tx.ToEthAddr.Bytes(), common.EmptyAddr.Bytes()) && tx.ToBJJ != nil {
// case ToEthAddr!=0 && ToBJJ!=0
idx, err = s.GetIdxByEthAddrBJJ(tx.ToEthAddr, tx.ToBJJ)
if err != nil {
log.Error(err)
return nil, nil, false, ErrToIdxNotFound
}
} else {
// rest of cases (included case ToEthAddr==0) are not possible
log.Error(err)
return nil, nil, false, ErrToIdxNotFound
}
s.zki.AuxToIdx[s.i] = idx.BigInt()
// use toIdx that can have been filled by tx.ToIdx or
// if tx.Idx==0 (this case), toIdx is filled by the Idx
// from db by ToEthAddr&ToBJJ
s.zki.AuxToIdx[s.i] = auxToIdx.BigInt()
} }
s.zki.ToBJJAy[s.i] = tx.ToBJJ.Y s.zki.ToBJJAy[s.i] = tx.ToBJJ.Y
s.zki.ToEthAddr[s.i] = common.EthAddrToBigInt(tx.ToEthAddr) s.zki.ToEthAddr[s.i] = common.EthAddrToBigInt(tx.ToEthAddr)
@ -356,7 +349,7 @@ func (s *StateDB) processL2Tx(exitTree *merkletree.MerkleTree, tx *common.PoolL2
case common.TxTypeTransfer: case common.TxTypeTransfer:
// go to the MT account of sender and receiver, and update // go to the MT account of sender and receiver, and update
// balance & nonce // balance & nonce
err = s.applyTransfer(tx.Tx())
err = s.applyTransfer(tx.Tx(), auxToIdx)
if err != nil { if err != nil {
return nil, nil, false, err return nil, nil, false, err
} }
@ -474,14 +467,21 @@ func (s *StateDB) applyDeposit(tx *common.L1Tx, transfer bool) error {
} }
// applyTransfer updates the balance & nonce in the account of the sender, and // applyTransfer updates the balance & nonce in the account of the sender, and
// the balance in the account of the receiver
func (s *StateDB) applyTransfer(tx *common.Tx) error {
// the balance in the account of the receiver.
// Parameter 'toIdx' should be at 0 if the tx already has tx.ToIdx!=0, if
// tx.ToIdx==0, then toIdx!=0, and will be used the toIdx parameter as Idx of
// the receiver. This parameter is used when the tx.ToIdx is not specified and
// the real ToIdx is found trhrough the ToEthAddr or ToBJJ.
func (s *StateDB) applyTransfer(tx *common.Tx, auxToIdx common.Idx) error {
if auxToIdx == 0 {
auxToIdx = tx.ToIdx
}
// get sender and receiver accounts from localStateDB // get sender and receiver accounts from localStateDB
accSender, err := s.GetAccount(tx.FromIdx) accSender, err := s.GetAccount(tx.FromIdx)
if err != nil { if err != nil {
return err return err
} }
accReceiver, err := s.GetAccount(tx.ToIdx)
accReceiver, err := s.GetAccount(auxToIdx)
if err != nil { if err != nil {
return err return err
} }
@ -512,7 +512,7 @@ func (s *StateDB) applyTransfer(tx *common.Tx) error {
} }
// update receiver account in localStateDB // update receiver account in localStateDB
pReceiver, err := s.UpdateAccount(tx.ToIdx, accReceiver)
pReceiver, err := s.UpdateAccount(auxToIdx, accReceiver)
if err != nil { if err != nil {
return err return err
} }

+ 19
- 14
db/statedb/utils.go

@ -1,6 +1,7 @@
package statedb package statedb
import ( import (
"bytes"
"math/big" "math/big"
ethCommon "github.com/ethereum/go-ethereum/common" ethCommon "github.com/ethereum/go-ethereum/common"
@ -72,11 +73,11 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk
func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address) (common.Idx, error) { func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address) (common.Idx, error) {
b, err := s.db.Get(addr.Bytes()) b, err := s.db.Get(addr.Bytes())
if err != nil { if err != nil {
return common.Idx(0), err
return common.Idx(0), ErrToIdxNotFound
} }
idx, err := common.IdxFromBytes(b) idx, err := common.IdxFromBytes(b)
if err != nil { if err != nil {
return common.Idx(0), err
return common.Idx(0), ErrToIdxNotFound
} }
return idx, nil return idx, nil
} }
@ -87,20 +88,24 @@ func (s *StateDB) GetIdxByEthAddr(addr ethCommon.Address) (common.Idx, error) {
// query. Will return common.Idx(0) and error in case that Idx is not found in // query. Will return common.Idx(0) and error in case that Idx is not found in
// the StateDB. // the StateDB.
func (s *StateDB) GetIdxByEthAddrBJJ(addr ethCommon.Address, pk *babyjub.PublicKey) (common.Idx, error) { func (s *StateDB) GetIdxByEthAddrBJJ(addr ethCommon.Address, pk *babyjub.PublicKey) (common.Idx, error) {
if pk == nil {
if !bytes.Equal(addr.Bytes(), common.EmptyAddr.Bytes()) && pk == nil {
// case ToEthAddr!=0 && ToBJJ=0
return s.GetIdxByEthAddr(addr) return s.GetIdxByEthAddr(addr)
} else if !bytes.Equal(addr.Bytes(), common.EmptyAddr.Bytes()) && pk != nil {
// case ToEthAddr!=0 && ToBJJ!=0
k := concatEthAddrBJJ(addr, pk)
b, err := s.db.Get(k)
if err != nil {
return common.Idx(0), ErrToIdxNotFound
}
idx, err := common.IdxFromBytes(b)
if err != nil {
return common.Idx(0), ErrToIdxNotFound
}
return idx, nil
} }
k := concatEthAddrBJJ(addr, pk)
b, err := s.db.Get(k)
if err != nil {
return common.Idx(0), err
}
idx, err := common.IdxFromBytes(b)
if err != nil {
return common.Idx(0), err
}
return idx, nil
// rest of cases (included case ToEthAddr==0) are not possible
return common.Idx(0), ErrToIdxNotFound
} }
func siblingsToZKInputFormat(s []*merkletree.Hash) []*big.Int { func siblingsToZKInputFormat(s []*merkletree.Hash) []*big.Int {

+ 1
- 2
db/statedb/utils_test.go

@ -8,7 +8,6 @@ import (
ethCommon "github.com/ethereum/go-ethereum/common" ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/hermeznetwork/hermez-node/common" "github.com/hermeznetwork/hermez-node/common"
"github.com/iden3/go-iden3-crypto/babyjub" "github.com/iden3/go-iden3-crypto/babyjub"
"github.com/iden3/go-merkletree/db"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -75,7 +74,7 @@ func TestGetIdx(t *testing.T) {
// expect error when trying to get Idx by addr2 & pk2 // expect error when trying to get Idx by addr2 & pk2
idxR, err = sdb.GetIdxByEthAddrBJJ(addr2, pk2) idxR, err = sdb.GetIdxByEthAddrBJJ(addr2, pk2)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, db.ErrNotFound, err)
assert.Equal(t, ErrToIdxNotFound, err)
assert.Equal(t, common.Idx(0), idxR) assert.Equal(t, common.Idx(0), idxR)
} }

Loading…
Cancel
Save