From 3cba052c2aa11c41c75f44a2b51ccc189747a419 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Thu, 23 Jul 2020 22:05:53 +0200 Subject: [PATCH] Add mt.Update(key, value) method --- merkletree.go | 69 +++++++++++++++++++++++++++++++++++++++++++++- merkletree_test.go | 55 ++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) diff --git a/merkletree.go b/merkletree.go index 3a19e21..66b702a 100644 --- a/merkletree.go +++ b/merkletree.go @@ -342,6 +342,73 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, error) { return nil, ErrKeyNotFound } +// Update updates the value of a specified key in the MerkleTree, and updates +// the path from the leaf to the Root with the new values. +func (mt *MerkleTree) Update(k, v *big.Int) error { + // verify that the MerkleTree is writable + if !mt.writable { + return ErrNotWritable + } + + // verfy that k & are valid and fit inside the Finite Field. + if !cryptoUtils.CheckBigIntInField(k) { + return errors.New("Key not inside the Finite Field") + } + if !cryptoUtils.CheckBigIntInField(v) { + return errors.New("Key not inside the Finite Field") + } + tx, err := mt.db.NewTx() + if err != nil { + return err + } + mt.Lock() + defer mt.Unlock() + + kHash := NewHashFromBigInt(k) + vHash := NewHashFromBigInt(v) + path := getPath(mt.maxLevels, kHash[:]) + + nextKey := mt.rootKey + var siblings []*Hash + for i := 0; i < mt.maxLevels; i++ { + n, err := mt.GetNode(nextKey) + if err != nil { + return err + } + switch n.Type { + case NodeTypeEmpty: + return ErrKeyNotFound + case NodeTypeLeaf: + if bytes.Equal(kHash[:], n.Entry[0][:]) { + // update leaf and upload to the root + newNodeLeaf := NewNodeLeaf(kHash, vHash) + _, err := mt.addNode(tx, newNodeLeaf) + newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings) + if err != nil { + return err + } + mt.rootKey = newRootKey + mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + return tx.Commit() + } else { + return ErrKeyNotFound + } + case NodeTypeMiddle: + if path[i] { + nextKey = n.ChildR + siblings = append(siblings, n.ChildL) + } else { + nextKey = n.ChildL + siblings = append(siblings, n.ChildR) + } + default: + return ErrInvalidNodeFound + } + } + + return ErrKeyNotFound +} + // Delete removes the specified Key from the MerkleTree and updates the path // from the deleted key to the Root with the new values. This method removes // the key from the MerkleTree, but does not remove the old nodes from the @@ -403,7 +470,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error { } } - return nil + return ErrKeyNotFound } // rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values. diff --git a/merkletree_test.go b/merkletree_test.go index 1a9ce41..4825de8 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -136,6 +136,60 @@ func TestGet(t *testing.T) { assert.Nil(t, v) } +func TestUpdate(t *testing.T) { + mt := newTestingMerkle(t, 140) + defer mt.db.Close() + + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(int64(i * 2)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + v, err := mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(20), v) + + err = mt.Update(big.NewInt(10), big.NewInt(1024)) + v, err = mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(1024), v) + + err = mt.Update(big.NewInt(1000), big.NewInt(1024)) + assert.Equal(t, ErrKeyNotFound, err) +} + +func TestUpdate2(t *testing.T) { + mt1 := newTestingMerkle(t, 140) + defer mt1.db.Close() + mt2 := newTestingMerkle(t, 140) + defer mt2.db.Close() + + err := mt1.Add(big.NewInt(1), big.NewInt(119)) + assert.Nil(t, err) + err = mt1.Add(big.NewInt(2), big.NewInt(229)) + assert.Nil(t, err) + err = mt1.Add(big.NewInt(9876), big.NewInt(6789)) + assert.Nil(t, err) + + err = mt2.Add(big.NewInt(1), big.NewInt(11)) + assert.Nil(t, err) + err = mt2.Add(big.NewInt(2), big.NewInt(22)) + assert.Nil(t, err) + err = mt2.Add(big.NewInt(9876), big.NewInt(10)) + assert.Nil(t, err) + + err = mt1.Update(big.NewInt(1), big.NewInt(11)) + assert.Nil(t, err) + err = mt1.Update(big.NewInt(2), big.NewInt(22)) + assert.Nil(t, err) + err = mt2.Update(big.NewInt(9876), big.NewInt(6789)) + assert.Nil(t, err) + + assert.Equal(t, mt1.Root(), mt2.Root()) +} + func TestGenerateAndVerifyProof128(t *testing.T) { mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140) require.Nil(t, err) @@ -346,6 +400,7 @@ func TestDelete(t *testing.T) { err = mt.Delete(big.NewInt(1)) assert.Nil(t, err) assert.Equal(t, "0", mt.Root().String()) + } func TestDelete2(t *testing.T) {