Browse Source

Unify StateDB MT{Create/Update}Account

Previously as the txprocessor.go methods were specific for the BatchBuilder,
the MTCreateAccount & CreateAccount and MTUpdateAccount & UpdateAccount were
also designed to be used by BatchBuilder and TxSelector depending on the
MerkleTree usage calling one kind of method or anotherone.

But now that this methods are being called directly by the StateDB (through the
methods in txprocessors.go), to allow also the methods usage from the
Synchronizer, there can not be the MT and no-MT methods separated, so this
commit unifies MTCreateAccount with CreateAccount, and MTUpdateAccount with
UpdateAccount, which internally will update the MerkleTree depending if the
specific StateDB in usage has the MerkleTree defined or not.
feature/sql-semaphore1
arnaucube 4 years ago
parent
commit
12aa31e46b
3 changed files with 42 additions and 70 deletions
  1. +27
    -49
      db/statedb/statedb.go
  2. +7
    -16
      db/statedb/statedb_test.go
  3. +8
    -5
      db/statedb/txprocessors.go

+ 27
- 49
db/statedb/statedb.go

@ -201,95 +201,73 @@ func (s *StateDB) GetAccount(idx common.Idx) (*common.Account, error) {
return common.AccountFromBytes(b) return common.AccountFromBytes(b)
} }
// CreateAccount creates a new Account in the StateDB for the given Idx.
// MerkleTree is not affected.
func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) error {
// CreateAccount creates a new Account in the StateDB for the given Idx. If
// StateDB.mt==nil, MerkleTree is not affected, otherwise updates the
// MerkleTree, returning a CircomProcessorProof.
func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) {
// store at the DB the key: v, and value: leaf.Bytes() // store at the DB the key: v, and value: leaf.Bytes()
v, err := account.HashValue() v, err := account.HashValue()
if err != nil { if err != nil {
return err
return nil, err
} }
accountBytes, err := account.Bytes() accountBytes, err := account.Bytes()
if err != nil { if err != nil {
return err
return nil, err
} }
// store the Leaf value // store the Leaf value
tx, err := s.db.NewTx() tx, err := s.db.NewTx()
if err != nil { if err != nil {
return err
return nil, err
} }
_, err = tx.Get(idx.Bytes()) _, err = tx.Get(idx.Bytes())
if err != db.ErrNotFound { if err != db.ErrNotFound {
return ErrAccountAlreadyExists
return nil, ErrAccountAlreadyExists
} }
tx.Put(v.Bytes(), accountBytes[:]) tx.Put(v.Bytes(), accountBytes[:])
tx.Put(idx.Bytes(), v.Bytes()) tx.Put(idx.Bytes(), v.Bytes())
return tx.Commit()
if err := tx.Commit(); err != nil {
return nil, err
}
if s.mt != nil {
return s.mt.AddAndGetCircomProof(idx.BigInt(), v)
}
return nil, nil
} }
// UpdateAccount updates the Account in the StateDB for the given Idx.
// MerkleTree is not affected.
func (s *StateDB) UpdateAccount(idx common.Idx, account *common.Account) error {
// UpdateAccount updates the Account in the StateDB for the given Idx. If
// StateDB.mt==nil, MerkleTree is not affected, otherwise updates the
// MerkleTree, returning a CircomProcessorProof.
func (s *StateDB) UpdateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) {
// store at the DB the key: v, and value: leaf.Bytes() // store at the DB the key: v, and value: leaf.Bytes()
v, err := account.HashValue() v, err := account.HashValue()
if err != nil { if err != nil {
return err
return nil, err
} }
accountBytes, err := account.Bytes() accountBytes, err := account.Bytes()
if err != nil { if err != nil {
return err
return nil, err
} }
tx, err := s.db.NewTx() tx, err := s.db.NewTx()
if err != nil { if err != nil {
return err
return nil, err
} }
tx.Put(v.Bytes(), accountBytes[:]) tx.Put(v.Bytes(), accountBytes[:])
tx.Put(idx.Bytes(), v.Bytes()) tx.Put(idx.Bytes(), v.Bytes())
return tx.Commit()
}
// MTCreateAccount creates a new Account in the StateDB for the given Idx,
// and updates the MerkleTree, returning a CircomProcessorProof
func (s *StateDB) MTCreateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) {
if s.mt == nil {
return nil, ErrStateDBWithoutMT
}
err := s.CreateAccount(idx, account)
if err != nil {
return nil, err
}
v, err := account.HashValue() // already computed in s.CreateAccount, next iteration reuse first computation
if err != nil {
return nil, err
}
// Add k & v into the MT
return s.mt.AddAndGetCircomProof(idx.BigInt(), v)
}
// MTUpdateAccount updates the Account in the StateDB for the given Idx, and
// updates the MerkleTree, returning a CircomProcessorProof
func (s *StateDB) MTUpdateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) {
if s.mt == nil {
return nil, ErrStateDBWithoutMT
}
err := s.UpdateAccount(idx, account)
if err != nil {
if err := tx.Commit(); err != nil {
return nil, err return nil, err
} }
v, err := account.HashValue() // already computed in s.CreateAccount, next iteration reuse first computation
if err != nil {
return nil, err
if s.mt != nil {
return s.mt.Update(idx.BigInt(), v)
} }
// Add k & v into the MT
return s.mt.Update(idx.BigInt(), v)
return nil, nil
} }
// MTGetProof returns the CircomVerifierProof for a given Idx // MTGetProof returns the CircomVerifierProof for a given Idx

+ 7
- 16
db/statedb/statedb_test.go

@ -55,7 +55,7 @@ func TestStateDBWithoutMT(t *testing.T) {
// add test accounts // add test accounts
for i := 0; i < len(accounts); i++ { for i := 0; i < len(accounts); i++ {
err = sdb.CreateAccount(common.Idx(i), accounts[i])
_, err = sdb.CreateAccount(common.Idx(i), accounts[i])
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -68,26 +68,17 @@ func TestStateDBWithoutMT(t *testing.T) {
// try already existing idx and get error // try already existing idx and get error
_, err = sdb.GetAccount(common.Idx(1)) // check that exist _, err = sdb.GetAccount(common.Idx(1)) // check that exist
assert.Nil(t, err) assert.Nil(t, err)
err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
_, err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrAccountAlreadyExists, err) assert.Equal(t, ErrAccountAlreadyExists, err)
// update accounts // update accounts
for i := 0; i < len(accounts); i++ { for i := 0; i < len(accounts); i++ {
accounts[i].Nonce = accounts[i].Nonce + 1 accounts[i].Nonce = accounts[i].Nonce + 1
err = sdb.UpdateAccount(common.Idx(i), accounts[i])
_, err = sdb.UpdateAccount(common.Idx(i), accounts[i])
assert.Nil(t, err) assert.Nil(t, err)
} }
// check that can not call MerkleTree methods of the StateDB
_, err = sdb.MTCreateAccount(common.Idx(1), accounts[1])
assert.NotNil(t, err)
assert.Equal(t, ErrStateDBWithoutMT, err)
_, err = sdb.MTUpdateAccount(common.Idx(1), accounts[1])
assert.NotNil(t, err)
assert.Equal(t, ErrStateDBWithoutMT, err)
_, err = sdb.MTGetProof(common.Idx(1)) _, err = sdb.MTGetProof(common.Idx(1))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrStateDBWithoutMT, err) assert.Equal(t, ErrStateDBWithoutMT, err)
@ -113,7 +104,7 @@ func TestStateDBWithMT(t *testing.T) {
// add test accounts // add test accounts
for i := 0; i < len(accounts); i++ { for i := 0; i < len(accounts); i++ {
_, err = sdb.MTCreateAccount(common.Idx(i), accounts[i])
_, err = sdb.CreateAccount(common.Idx(i), accounts[i])
assert.Nil(t, err) assert.Nil(t, err)
} }
@ -126,7 +117,7 @@ func TestStateDBWithMT(t *testing.T) {
// try already existing idx and get error // try already existing idx and get error
_, err = sdb.GetAccount(common.Idx(1)) // check that exist _, err = sdb.GetAccount(common.Idx(1)) // check that exist
assert.Nil(t, err) assert.Nil(t, err)
_, err = sdb.MTCreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
_, err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrAccountAlreadyExists, err) assert.Equal(t, ErrAccountAlreadyExists, err)
@ -136,7 +127,7 @@ func TestStateDBWithMT(t *testing.T) {
// update accounts // update accounts
for i := 0; i < len(accounts); i++ { for i := 0; i < len(accounts); i++ {
accounts[i].Nonce = accounts[i].Nonce + 1 accounts[i].Nonce = accounts[i].Nonce + 1
_, err = sdb.MTUpdateAccount(common.Idx(i), accounts[i])
_, err = sdb.UpdateAccount(common.Idx(i), accounts[i])
assert.Nil(t, err) assert.Nil(t, err)
} }
a, err := sdb.GetAccount(common.Idx(1)) // check that account value has been updated a, err := sdb.GetAccount(common.Idx(1)) // check that account value has been updated
@ -159,7 +150,7 @@ func TestCheckpoints(t *testing.T) {
// add test accounts // add test accounts
for i := 0; i < len(accounts); i++ { for i := 0; i < len(accounts); i++ {
_, err = sdb.MTCreateAccount(common.Idx(i), accounts[i])
_, err = sdb.CreateAccount(common.Idx(i), accounts[i])
assert.Nil(t, err) assert.Nil(t, err)
} }

+ 8
- 5
db/statedb/txprocessors.go

@ -89,7 +89,7 @@ func (s *StateDB) applyCreateAccount(tx *common.L1Tx) error {
EthAddr: tx.FromEthAddr, EthAddr: tx.FromEthAddr,
} }
err := s.CreateAccount(common.Idx(s.idx+1), account)
_, err := s.CreateAccount(common.Idx(s.idx+1), account)
if err != nil { if err != nil {
return err return err
} }
@ -120,13 +120,13 @@ func (s *StateDB) applyDeposit(tx *common.L1Tx, transfer bool) error {
// add amount to the receiver // add amount to the receiver
accReceiver.Balance = new(big.Int).Add(accReceiver.Balance, tx.Amount) accReceiver.Balance = new(big.Int).Add(accReceiver.Balance, tx.Amount)
// update receiver account in localStateDB // update receiver account in localStateDB
err = s.UpdateAccount(tx.ToIdx, accReceiver)
_, err = s.UpdateAccount(tx.ToIdx, accReceiver)
if err != nil { if err != nil {
return err return err
} }
} }
// update sender account in localStateDB // update sender account in localStateDB
err = s.UpdateAccount(tx.FromIdx, accSender)
_, err = s.UpdateAccount(tx.FromIdx, accSender)
if err != nil { if err != nil {
return err return err
} }
@ -146,18 +146,21 @@ func (s *StateDB) applyTransfer(tx *common.Tx) error {
return err return err
} }
// increment nonce
accSender.Nonce++
// substract amount to the sender // substract amount to the sender
accSender.Balance = new(big.Int).Sub(accSender.Balance, tx.Amount) accSender.Balance = new(big.Int).Sub(accSender.Balance, tx.Amount)
// add amount to the receiver // add amount to the receiver
accReceiver.Balance = new(big.Int).Add(accReceiver.Balance, tx.Amount) accReceiver.Balance = new(big.Int).Add(accReceiver.Balance, tx.Amount)
// update receiver account in localStateDB // update receiver account in localStateDB
err = s.UpdateAccount(tx.ToIdx, accReceiver)
_, err = s.UpdateAccount(tx.ToIdx, accReceiver)
if err != nil { if err != nil {
return err return err
} }
// update sender account in localStateDB // update sender account in localStateDB
err = s.UpdateAccount(tx.FromIdx, accSender)
_, err = s.UpdateAccount(tx.FromIdx, accSender)
if err != nil { if err != nil {
return err return err
} }

Loading…
Cancel
Save