diff --git a/merkletree.go b/merkletree.go index 16e210b..b89fa29 100644 --- a/merkletree.go +++ b/merkletree.go @@ -310,7 +310,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error { } switch n.Type { case NodeTypeEmpty: - return nil + return ErrKeyNotFound case NodeTypeLeaf: if bytes.Equal(kHash[:], n.Entry[0][:]) { // remove and go up with the sibling @@ -337,6 +337,12 @@ func (mt *MerkleTree) Delete(k *big.Int) error { // rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values. func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error { + if len(siblings) == 0 { + mt.rootKey = &HashZero + mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + return tx.Commit() + } + toUpload := siblings[len(siblings)-1] if len(siblings) < 2 { mt.rootKey = siblings[0] diff --git a/merkletree_test.go b/merkletree_test.go index 5b48e3c..83e18c2 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -315,6 +315,12 @@ func TestDelete(t *testing.T) { // mt.PrintGraphViz(nil) assert.Nil(t, err) assert.Equal(t, "12820263606494630162816839760750120928463716794691735985748071431547370997091", mt.Root().BigInt().String()) + + err = mt.Delete(big.NewInt(1234)) + assert.Nil(t, err) + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) } func TestDelete2(t *testing.T) { @@ -420,3 +426,26 @@ func TestDelete5(t *testing.T) { assert.Nil(t, err) assert.Equal(t, mt2.Root(), mt.Root()) } + +func TestDeleteNonExistingKeys(t *testing.T) { + mt, err := NewMerkleTree(db.NewMemoryStorage(), 10) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + + err = mt.Delete(big.NewInt(33)) + assert.Nil(t, err) + err = mt.Delete(big.NewInt(33)) + assert.Equal(t, ErrKeyNotFound, err) + + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + + assert.Equal(t, "0", mt.Root().String()) + + err = mt.Delete(big.NewInt(33)) + assert.Equal(t, ErrKeyNotFound, err) +}