Browse Source

Merge pull request #3 from iden3/feature/delete-leaf

Add mt.Delete(key) with test vectors
circomproofs
arnau 3 years ago
committed by GitHub
parent
commit
e68d9e31d4
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 263 additions and 6 deletions
  1. +4
    -1
      README.md
  2. +126
    -2
      merkletree.go
  3. +130
    -0
      merkletree_test.go
  4. +3
    -3
      node.go

+ 4
- 1
README.md

@ -2,7 +2,7 @@
MerkleTree compatible with version from [circomlib](https://github.com/iden3/circomlib).
Adaptation of the merkletree from https://github.com/iden3/go-iden3-core/tree/v0.0.8
Adaptation of the merkletree from https://github.com/iden3/go-iden3-core/tree/v0.0.8 with several changes and more functionalities.
## Usage
More detailed examples can be found at the [tests](https://github.com/iden3/go-merkletree/blob/master/merkletree_test.go), and in the [documentation](https://godoc.org/github.com/iden3/go-merkletree).
@ -33,5 +33,8 @@ func TestExampleMerkleTree(t *testing.T) {
assert.Nil(t, err)
assert.True(t, VerifyProof(mt.Root(), proof, key, value))
err := mt.Delete(big.NewInt(1)) // delete the leaf of key=1
assert.Nil(t, err)
}
```

+ 126
- 2
merkletree.go

@ -27,9 +27,9 @@ var (
ErrNodeKeyAlreadyExists = errors.New("node already exists")
// ErrEntryIndexNotFound is used when no entry is found for an index.
ErrEntryIndexNotFound = errors.New("node index not found in the DB")
// ErrNodeDataBadSize is used when the data of a node has an incorrect
// ErrNodeBytesBadSize is used when the data of a node has an incorrect
// size and can't be parsed.
ErrNodeDataBadSize = errors.New("node data has incorrect size in the DB")
ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
// ErrReachedMaxLevel is used when a traversal of the MT reaches the
// maximum level.
ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
@ -47,6 +47,8 @@ var (
ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
// ErrNotWritable is used when the MerkleTree is not writable and a write function is called
ErrNotWritable = errors.New("Merkle Tree not writable")
// ErrKeyNotFound is used when a key is not found in the MerkleTree.
ErrKeyNotFound = errors.New("Key not found in the tree")
rootNodeValue = []byte("currentroot")
// HashZero is used at Empty nodes
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@ -277,6 +279,128 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
return k, nil
}
// Delete removes the specified Key from the MerkleTree, and updates the pad from the delted key to the Root with the new values.
// This method removes the key from the MerkleTree, but does not remove the old nodes from the key-value database; this means that if the tree is accessed by an old Root where the key was not deleted yet, the key will still exist. If is desired to remove the key-values from the database that are not under the current Root, an option could be to dump all the claims and import them in a new MerkleTree in a new database, but this will loose all the Root history of the MerkleTree
func (mt *MerkleTree) Delete(k *big.Int) error {
// verify that the MerkleTree is writable
if !mt.writable {
return ErrNotWritable
}
// verfy that the ElemBytes are valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) {
return errors.New("Key not inside the Finite Field")
}
tx, err := mt.db.NewTx()
if err != nil {
return err
}
mt.Lock()
defer mt.Unlock()
kHash := NewHashFromBigInt(k)
path := getPath(mt.maxLevels, kHash[:])
nextKey := mt.rootKey
var siblings []*Hash
for i := 0; i < mt.maxLevels; i++ {
n, err := mt.GetNode(nextKey)
if err != nil {
return err
}
switch n.Type {
case NodeTypeEmpty:
return nil
case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) {
// remove and go up with the sibling
err = mt.rmAndUpload(tx, path, kHash, siblings)
return err
} else {
return ErrKeyNotFound
}
case NodeTypeMiddle:
if path[i] {
nextKey = n.ChildR
siblings = append(siblings, n.ChildL)
} else {
nextKey = n.ChildL
siblings = append(siblings, n.ChildR)
}
default:
return ErrInvalidNodeFound
}
}
return nil
}
// rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values.
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
toUpload := siblings[len(siblings)-1]
if len(siblings) < 2 {
mt.rootKey = siblings[0]
mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
return tx.Commit()
}
for i := len(siblings) - 2; i >= 0; i-- {
if !bytes.Equal(siblings[i][:], HashZero[:]) {
var newNode *Node
if path[i] {
newNode = NewNodeMiddle(siblings[i], toUpload)
} else {
newNode = NewNodeMiddle(toUpload, siblings[i])
}
_, err := mt.addNode(tx, newNode)
if err != ErrNodeKeyAlreadyExists && err != nil {
return err
}
// go up until the root
newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:i])
if err != nil {
return err
}
mt.rootKey = newRootKey
mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
break
}
// if i==0 (root position), stop and store the sibling of the deleted leaf as root
if i == 0 {
mt.rootKey = toUpload
mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
break
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
// recalculatePathUntilRoot recalculates the nodes until the Root
func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node, siblings []*Hash) (*Hash, error) {
for i := len(siblings) - 1; i >= 0; i-- {
nodeKey, err := node.Key()
if err != nil {
return nil, err
}
if path[i] {
node = NewNodeMiddle(siblings[i], nodeKey)
} else {
node = NewNodeMiddle(nodeKey, siblings[i])
}
_, err = mt.addNode(tx, node)
if err != ErrNodeKeyAlreadyExists && err != nil {
return nil, err
}
}
// return last node added, which is the root
nodeKey, err := node.Key()
return nodeKey, err
}
// dbGet is a helper function to get the node of a key from the internal
// storage.
func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {

+ 130
- 0
merkletree_test.go

@ -290,3 +290,133 @@ node [fontname=Monospace,fontsize=10,shape=box]
mt.GraphViz(w, nil)
assert.Equal(t, []byte(expected), w.Bytes())
}
func TestDelete(t *testing.T) {
mt, err := NewMerkleTree(db.NewMemoryStorage(), 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String())
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String())
// mt.PrintGraphViz(nil)
err = mt.Delete(big.NewInt(33))
// mt.PrintGraphViz(nil)
assert.Nil(t, err)
assert.Equal(t, "12820263606494630162816839760750120928463716794691735985748071431547370997091", mt.Root().BigInt().String())
}
func TestDelete2(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
for i := 0; i < 8; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
expectedRoot := mt.Root()
k := big.NewInt(8)
v := big.NewInt(0)
err := mt.Add(k, v)
require.Nil(t, err)
err = mt.Delete(big.NewInt(8))
assert.Nil(t, err)
assert.Equal(t, expectedRoot, mt.Root())
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
for i := 0; i < 8; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt2.Add(k, v); err != nil {
t.Fatal(err)
}
}
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete3(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
err := mt.Add(big.NewInt(1), big.NewInt(1))
assert.Nil(t, err)
err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "2427629547967522489273866134471574861207714751136138191708011221765688788661", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "10822920717809411688334493481050035035708810950159417482558569847174767667301", mt.Root().BigInt().String())
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err = mt2.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete4(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
err := mt.Add(big.NewInt(1), big.NewInt(1))
assert.Nil(t, err)
err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)
assert.Equal(t, "16614298246517994771186095530428786749320098419259206061045083278756632941513", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "6117330520107511783353383870014397665359816230889739699667943862706617498952", mt.Root().BigInt().String())
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err = mt2.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete5(t *testing.T) {
mt, err := NewMerkleTree(db.NewMemoryStorage(), 10)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "12075524681474630909546786277734445038384732558409197537058769521806571391765", mt.Root().BigInt().String())
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err = mt2.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}

+ 3
- 3
node.go

@ -53,21 +53,21 @@ func NewNodeEmpty() *Node {
// NewNodeFromBytes creates a new node by parsing the input []byte.
func NewNodeFromBytes(b []byte) (*Node, error) {
if len(b) < 1 {
return nil, ErrNodeDataBadSize
return nil, ErrNodeBytesBadSize
}
n := Node{Type: NodeType(b[0])}
b = b[1:]
switch n.Type {
case NodeTypeMiddle:
if len(b) != 2*ElemBytesLen {
return nil, ErrNodeDataBadSize
return nil, ErrNodeBytesBadSize
}
n.ChildL, n.ChildR = &Hash{}, &Hash{}
copy(n.ChildL[:], b[:ElemBytesLen])
copy(n.ChildR[:], b[ElemBytesLen:ElemBytesLen*2])
case NodeTypeLeaf:
if len(b) != 2*ElemBytesLen {
return nil, ErrNodeDataBadSize
return nil, ErrNodeBytesBadSize
}
n.Entry = [2]*Hash{{}, {}}
copy(n.Entry[0][:], b[0:32])

Loading…
Cancel
Save