diff --git a/db/statedb/statedb.go b/db/statedb/statedb.go index 3113982..c9b1a31 100644 --- a/db/statedb/statedb.go +++ b/db/statedb/statedb.go @@ -201,95 +201,73 @@ func (s *StateDB) GetAccount(idx common.Idx) (*common.Account, error) { return common.AccountFromBytes(b) } -// CreateAccount creates a new Account in the StateDB for the given Idx. -// MerkleTree is not affected. -func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) error { +// CreateAccount creates a new Account in the StateDB for the given Idx. If +// StateDB.mt==nil, MerkleTree is not affected, otherwise updates the +// MerkleTree, returning a CircomProcessorProof. +func (s *StateDB) CreateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) { // store at the DB the key: v, and value: leaf.Bytes() v, err := account.HashValue() if err != nil { - return err + return nil, err } accountBytes, err := account.Bytes() if err != nil { - return err + return nil, err } // store the Leaf value tx, err := s.db.NewTx() if err != nil { - return err + return nil, err } _, err = tx.Get(idx.Bytes()) if err != db.ErrNotFound { - return ErrAccountAlreadyExists + return nil, ErrAccountAlreadyExists } tx.Put(v.Bytes(), accountBytes[:]) tx.Put(idx.Bytes(), v.Bytes()) - return tx.Commit() + if err := tx.Commit(); err != nil { + return nil, err + } + + if s.mt != nil { + return s.mt.AddAndGetCircomProof(idx.BigInt(), v) + } + return nil, nil } -// UpdateAccount updates the Account in the StateDB for the given Idx. -// MerkleTree is not affected. -func (s *StateDB) UpdateAccount(idx common.Idx, account *common.Account) error { +// UpdateAccount updates the Account in the StateDB for the given Idx. If +// StateDB.mt==nil, MerkleTree is not affected, otherwise updates the +// MerkleTree, returning a CircomProcessorProof. +func (s *StateDB) UpdateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) { // store at the DB the key: v, and value: leaf.Bytes() v, err := account.HashValue() if err != nil { - return err + return nil, err } accountBytes, err := account.Bytes() if err != nil { - return err + return nil, err } tx, err := s.db.NewTx() if err != nil { - return err + return nil, err } tx.Put(v.Bytes(), accountBytes[:]) tx.Put(idx.Bytes(), v.Bytes()) - return tx.Commit() -} - -// MTCreateAccount creates a new Account in the StateDB for the given Idx, -// and updates the MerkleTree, returning a CircomProcessorProof -func (s *StateDB) MTCreateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) { - if s.mt == nil { - return nil, ErrStateDBWithoutMT - } - err := s.CreateAccount(idx, account) - if err != nil { - return nil, err - } - - v, err := account.HashValue() // already computed in s.CreateAccount, next iteration reuse first computation - if err != nil { - return nil, err - } - // Add k & v into the MT - return s.mt.AddAndGetCircomProof(idx.BigInt(), v) -} - -// MTUpdateAccount updates the Account in the StateDB for the given Idx, and -// updates the MerkleTree, returning a CircomProcessorProof -func (s *StateDB) MTUpdateAccount(idx common.Idx, account *common.Account) (*merkletree.CircomProcessorProof, error) { - if s.mt == nil { - return nil, ErrStateDBWithoutMT - } - err := s.UpdateAccount(idx, account) - if err != nil { + if err := tx.Commit(); err != nil { return nil, err } - v, err := account.HashValue() // already computed in s.CreateAccount, next iteration reuse first computation - if err != nil { - return nil, err + if s.mt != nil { + return s.mt.Update(idx.BigInt(), v) } - // Add k & v into the MT - return s.mt.Update(idx.BigInt(), v) + return nil, nil } // MTGetProof returns the CircomVerifierProof for a given Idx diff --git a/db/statedb/statedb_test.go b/db/statedb/statedb_test.go index 96e8cd8..60c558e 100644 --- a/db/statedb/statedb_test.go +++ b/db/statedb/statedb_test.go @@ -55,7 +55,7 @@ func TestStateDBWithoutMT(t *testing.T) { // add test accounts for i := 0; i < len(accounts); i++ { - err = sdb.CreateAccount(common.Idx(i), accounts[i]) + _, err = sdb.CreateAccount(common.Idx(i), accounts[i]) assert.Nil(t, err) } @@ -68,26 +68,17 @@ func TestStateDBWithoutMT(t *testing.T) { // try already existing idx and get error _, err = sdb.GetAccount(common.Idx(1)) // check that exist assert.Nil(t, err) - err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice + _, err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice assert.NotNil(t, err) assert.Equal(t, ErrAccountAlreadyExists, err) // update accounts for i := 0; i < len(accounts); i++ { accounts[i].Nonce = accounts[i].Nonce + 1 - err = sdb.UpdateAccount(common.Idx(i), accounts[i]) + _, err = sdb.UpdateAccount(common.Idx(i), accounts[i]) assert.Nil(t, err) } - // check that can not call MerkleTree methods of the StateDB - _, err = sdb.MTCreateAccount(common.Idx(1), accounts[1]) - assert.NotNil(t, err) - assert.Equal(t, ErrStateDBWithoutMT, err) - - _, err = sdb.MTUpdateAccount(common.Idx(1), accounts[1]) - assert.NotNil(t, err) - assert.Equal(t, ErrStateDBWithoutMT, err) - _, err = sdb.MTGetProof(common.Idx(1)) assert.NotNil(t, err) assert.Equal(t, ErrStateDBWithoutMT, err) @@ -113,7 +104,7 @@ func TestStateDBWithMT(t *testing.T) { // add test accounts for i := 0; i < len(accounts); i++ { - _, err = sdb.MTCreateAccount(common.Idx(i), accounts[i]) + _, err = sdb.CreateAccount(common.Idx(i), accounts[i]) assert.Nil(t, err) } @@ -126,7 +117,7 @@ func TestStateDBWithMT(t *testing.T) { // try already existing idx and get error _, err = sdb.GetAccount(common.Idx(1)) // check that exist assert.Nil(t, err) - _, err = sdb.MTCreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice + _, err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice assert.NotNil(t, err) assert.Equal(t, ErrAccountAlreadyExists, err) @@ -136,7 +127,7 @@ func TestStateDBWithMT(t *testing.T) { // update accounts for i := 0; i < len(accounts); i++ { accounts[i].Nonce = accounts[i].Nonce + 1 - _, err = sdb.MTUpdateAccount(common.Idx(i), accounts[i]) + _, err = sdb.UpdateAccount(common.Idx(i), accounts[i]) assert.Nil(t, err) } a, err := sdb.GetAccount(common.Idx(1)) // check that account value has been updated @@ -159,7 +150,7 @@ func TestCheckpoints(t *testing.T) { // add test accounts for i := 0; i < len(accounts); i++ { - _, err = sdb.MTCreateAccount(common.Idx(i), accounts[i]) + _, err = sdb.CreateAccount(common.Idx(i), accounts[i]) assert.Nil(t, err) } diff --git a/db/statedb/txprocessors.go b/db/statedb/txprocessors.go index a0148ad..14180ea 100644 --- a/db/statedb/txprocessors.go +++ b/db/statedb/txprocessors.go @@ -89,7 +89,7 @@ func (s *StateDB) applyCreateAccount(tx *common.L1Tx) error { EthAddr: tx.FromEthAddr, } - err := s.CreateAccount(common.Idx(s.idx+1), account) + _, err := s.CreateAccount(common.Idx(s.idx+1), account) if err != nil { return err } @@ -120,13 +120,13 @@ func (s *StateDB) applyDeposit(tx *common.L1Tx, transfer bool) error { // add amount to the receiver accReceiver.Balance = new(big.Int).Add(accReceiver.Balance, tx.Amount) // update receiver account in localStateDB - err = s.UpdateAccount(tx.ToIdx, accReceiver) + _, err = s.UpdateAccount(tx.ToIdx, accReceiver) if err != nil { return err } } // update sender account in localStateDB - err = s.UpdateAccount(tx.FromIdx, accSender) + _, err = s.UpdateAccount(tx.FromIdx, accSender) if err != nil { return err } @@ -146,18 +146,21 @@ func (s *StateDB) applyTransfer(tx *common.Tx) error { return err } + // increment nonce + accSender.Nonce++ + // substract amount to the sender accSender.Balance = new(big.Int).Sub(accSender.Balance, tx.Amount) // add amount to the receiver accReceiver.Balance = new(big.Int).Add(accReceiver.Balance, tx.Amount) // update receiver account in localStateDB - err = s.UpdateAccount(tx.ToIdx, accReceiver) + _, err = s.UpdateAccount(tx.ToIdx, accReceiver) if err != nil { return err } // update sender account in localStateDB - err = s.UpdateAccount(tx.FromIdx, accSender) + _, err = s.UpdateAccount(tx.FromIdx, accSender) if err != nil { return err }