From 3b1436b49547e1df8133eeba64655afa467b2669 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 8 Jul 2020 00:05:30 +0200 Subject: [PATCH] Add mt.Delete(key) with test vectors --- README.md | 5 +- merkletree.go | 128 +++++++++++++++++++++++++++++++++++++++++++- merkletree_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++++ node.go | 6 +-- 4 files changed, 263 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index a28475c..424e881 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ MerkleTree compatible with version from [circomlib](https://github.com/iden3/circomlib). -Adaptation of the merkletree from https://github.com/iden3/go-iden3-core/tree/v0.0.8 +Adaptation of the merkletree from https://github.com/iden3/go-iden3-core/tree/v0.0.8 with several changes and more functionalities. ## Usage More detailed examples can be found at the [tests](https://github.com/iden3/go-merkletree/blob/master/merkletree_test.go), and in the [documentation](https://godoc.org/github.com/iden3/go-merkletree). @@ -33,5 +33,8 @@ func TestExampleMerkleTree(t *testing.T) { assert.Nil(t, err) assert.True(t, VerifyProof(mt.Root(), proof, key, value)) + + err := mt.Delete(big.NewInt(1)) // delete the leaf of key=1 + assert.Nil(t, err) } ``` diff --git a/merkletree.go b/merkletree.go index 7f4c551..16e210b 100644 --- a/merkletree.go +++ b/merkletree.go @@ -27,9 +27,9 @@ var ( ErrNodeKeyAlreadyExists = errors.New("node already exists") // ErrEntryIndexNotFound is used when no entry is found for an index. ErrEntryIndexNotFound = errors.New("node index not found in the DB") - // ErrNodeDataBadSize is used when the data of a node has an incorrect + // ErrNodeBytesBadSize is used when the data of a node has an incorrect // size and can't be parsed. - ErrNodeDataBadSize = errors.New("node data has incorrect size in the DB") + ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB") // ErrReachedMaxLevel is used when a traversal of the MT reaches the // maximum level. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree") @@ -47,6 +47,8 @@ var ( ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree") // ErrNotWritable is used when the MerkleTree is not writable and a write function is called ErrNotWritable = errors.New("Merkle Tree not writable") + // ErrKeyNotFound is used when a key is not found in the MerkleTree. + ErrKeyNotFound = errors.New("Key not found in the tree") rootNodeValue = []byte("currentroot") // HashZero is used at Empty nodes HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} @@ -277,6 +279,128 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) { return k, nil } +// Delete removes the specified Key from the MerkleTree, and updates the pad from the delted key to the Root with the new values. +// This method removes the key from the MerkleTree, but does not remove the old nodes from the key-value database; this means that if the tree is accessed by an old Root where the key was not deleted yet, the key will still exist. If is desired to remove the key-values from the database that are not under the current Root, an option could be to dump all the claims and import them in a new MerkleTree in a new database, but this will loose all the Root history of the MerkleTree +func (mt *MerkleTree) Delete(k *big.Int) error { + // verify that the MerkleTree is writable + if !mt.writable { + return ErrNotWritable + } + + // verfy that the ElemBytes are valid and fit inside the Finite Field. + if !cryptoUtils.CheckBigIntInField(k) { + return errors.New("Key not inside the Finite Field") + } + tx, err := mt.db.NewTx() + if err != nil { + return err + } + mt.Lock() + defer mt.Unlock() + + kHash := NewHashFromBigInt(k) + path := getPath(mt.maxLevels, kHash[:]) + + nextKey := mt.rootKey + var siblings []*Hash + for i := 0; i < mt.maxLevels; i++ { + n, err := mt.GetNode(nextKey) + if err != nil { + return err + } + switch n.Type { + case NodeTypeEmpty: + return nil + case NodeTypeLeaf: + if bytes.Equal(kHash[:], n.Entry[0][:]) { + // remove and go up with the sibling + err = mt.rmAndUpload(tx, path, kHash, siblings) + return err + } else { + return ErrKeyNotFound + } + case NodeTypeMiddle: + if path[i] { + nextKey = n.ChildR + siblings = append(siblings, n.ChildL) + } else { + nextKey = n.ChildL + siblings = append(siblings, n.ChildR) + } + default: + return ErrInvalidNodeFound + } + } + + return nil +} + +// rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values. +func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error { + toUpload := siblings[len(siblings)-1] + if len(siblings) < 2 { + mt.rootKey = siblings[0] + mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + return tx.Commit() + } + for i := len(siblings) - 2; i >= 0; i-- { + if !bytes.Equal(siblings[i][:], HashZero[:]) { + var newNode *Node + if path[i] { + newNode = NewNodeMiddle(siblings[i], toUpload) + } else { + newNode = NewNodeMiddle(toUpload, siblings[i]) + } + _, err := mt.addNode(tx, newNode) + if err != ErrNodeKeyAlreadyExists && err != nil { + return err + } + // go up until the root + newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:i]) + if err != nil { + return err + } + mt.rootKey = newRootKey + mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + break + } + // if i==0 (root position), stop and store the sibling of the deleted leaf as root + if i == 0 { + mt.rootKey = toUpload + mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + break + } + } + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// recalculatePathUntilRoot recalculates the nodes until the Root +func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node, siblings []*Hash) (*Hash, error) { + for i := len(siblings) - 1; i >= 0; i-- { + nodeKey, err := node.Key() + if err != nil { + return nil, err + } + if path[i] { + node = NewNodeMiddle(siblings[i], nodeKey) + } else { + node = NewNodeMiddle(nodeKey, siblings[i]) + } + _, err = mt.addNode(tx, node) + if err != ErrNodeKeyAlreadyExists && err != nil { + return nil, err + } + } + + // return last node added, which is the root + nodeKey, err := node.Key() + return nodeKey, err +} + // dbGet is a helper function to get the node of a key from the internal // storage. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) { diff --git a/merkletree_test.go b/merkletree_test.go index 99bbc1c..5b48e3c 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -290,3 +290,133 @@ node [fontname=Monospace,fontsize=10,shape=box] mt.GraphViz(w, nil) assert.Equal(t, []byte(expected), w.Bytes()) } + +func TestDelete(t *testing.T) { + mt, err := NewMerkleTree(db.NewMemoryStorage(), 10) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String()) + + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String()) + + err = mt.Add(big.NewInt(1234), big.NewInt(9876)) + assert.Nil(t, err) + assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String()) + + // mt.PrintGraphViz(nil) + + err = mt.Delete(big.NewInt(33)) + // mt.PrintGraphViz(nil) + assert.Nil(t, err) + assert.Equal(t, "12820263606494630162816839760750120928463716794691735985748071431547370997091", mt.Root().BigInt().String()) +} + +func TestDelete2(t *testing.T) { + mt := newTestingMerkle(t, 140) + defer mt.db.Close() + for i := 0; i < 8; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + + expectedRoot := mt.Root() + + k := big.NewInt(8) + v := big.NewInt(0) + err := mt.Add(k, v) + require.Nil(t, err) + + err = mt.Delete(big.NewInt(8)) + assert.Nil(t, err) + assert.Equal(t, expectedRoot, mt.Root()) + + mt2 := newTestingMerkle(t, 140) + defer mt2.db.Close() + for i := 0; i < 8; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt2.Add(k, v); err != nil { + t.Fatal(err) + } + } + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDelete3(t *testing.T) { + mt := newTestingMerkle(t, 140) + defer mt.db.Close() + + err := mt.Add(big.NewInt(1), big.NewInt(1)) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + + assert.Equal(t, "2427629547967522489273866134471574861207714751136138191708011221765688788661", mt.Root().BigInt().String()) + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "10822920717809411688334493481050035035708810950159417482558569847174767667301", mt.Root().BigInt().String()) + + mt2 := newTestingMerkle(t, 140) + defer mt2.db.Close() + err = mt2.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDelete4(t *testing.T) { + mt := newTestingMerkle(t, 140) + defer mt.db.Close() + + err := mt.Add(big.NewInt(1), big.NewInt(1)) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(3), big.NewInt(3)) + assert.Nil(t, err) + + assert.Equal(t, "16614298246517994771186095530428786749320098419259206061045083278756632941513", mt.Root().BigInt().String()) + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "6117330520107511783353383870014397665359816230889739699667943862706617498952", mt.Root().BigInt().String()) + + mt2 := newTestingMerkle(t, 140) + defer mt2.db.Close() + err = mt2.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + err = mt2.Add(big.NewInt(3), big.NewInt(3)) + assert.Nil(t, err) + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDelete5(t *testing.T) { + mt, err := NewMerkleTree(db.NewMemoryStorage(), 10) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String()) + + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "12075524681474630909546786277734445038384732558409197537058769521806571391765", mt.Root().BigInt().String()) + + mt2 := newTestingMerkle(t, 140) + defer mt2.db.Close() + err = mt2.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, mt2.Root(), mt.Root()) +} diff --git a/node.go b/node.go index af73b6f..57ea09a 100644 --- a/node.go +++ b/node.go @@ -53,21 +53,21 @@ func NewNodeEmpty() *Node { // NewNodeFromBytes creates a new node by parsing the input []byte. func NewNodeFromBytes(b []byte) (*Node, error) { if len(b) < 1 { - return nil, ErrNodeDataBadSize + return nil, ErrNodeBytesBadSize } n := Node{Type: NodeType(b[0])} b = b[1:] switch n.Type { case NodeTypeMiddle: if len(b) != 2*ElemBytesLen { - return nil, ErrNodeDataBadSize + return nil, ErrNodeBytesBadSize } n.ChildL, n.ChildR = &Hash{}, &Hash{} copy(n.ChildL[:], b[:ElemBytesLen]) copy(n.ChildR[:], b[ElemBytesLen:ElemBytesLen*2]) case NodeTypeLeaf: if len(b) != 2*ElemBytesLen { - return nil, ErrNodeDataBadSize + return nil, ErrNodeBytesBadSize } n.Entry = [2]*Hash{{}, {}} copy(n.Entry[0][:], b[0:32])