mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-07 14:31:28 +01:00
Add tree.Update(k, v) method
This commit is contained in:
52
tree.go
52
tree.go
@@ -13,6 +13,7 @@ package arbo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@@ -371,6 +372,57 @@ func getPath(numLevels int, k []byte) []bool {
|
||||
return path
|
||||
}
|
||||
|
||||
// Update updates the value for a given existing key. If the given key does not
|
||||
// exist, returns an error.
|
||||
func (t *Tree) Update(k, v []byte) error {
|
||||
t.updateAccessTime()
|
||||
|
||||
tx, err := t.db.NewTx()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyPath := make([]byte, t.hashFunction.Len())
|
||||
copy(keyPath[:], k)
|
||||
path := getPath(t.maxLevels, keyPath)
|
||||
|
||||
var siblings [][]byte
|
||||
_, valueAtBottom, siblings, err := t.down(tx, k, t.root, siblings, path, 0, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oldKey, _ := readLeafValue(valueAtBottom)
|
||||
if !bytes.Equal(oldKey, k) {
|
||||
return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
|
||||
}
|
||||
|
||||
leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Put(leafKey, leafValue); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// go up to the root
|
||||
if len(siblings) == 0 {
|
||||
t.root = leafKey
|
||||
return tx.Commit()
|
||||
}
|
||||
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.root = root
|
||||
// store root to db
|
||||
if err := tx.Put(dbKeyRoot, t.root); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GenProof generates a MerkleTree proof for the given key. If the key exists in
|
||||
// the Tree, the proof will be of existence, if the key does not exist in the
|
||||
// tree, the proof will be of non-existence.
|
||||
|
||||
62
tree_test.go
62
tree_test.go
@@ -104,8 +104,8 @@ func TestAddDifferentOrder(t *testing.T) {
|
||||
defer tree1.db.Close()
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
k := SwapEndianness(big.NewInt(int64(i)).Bytes())
|
||||
v := SwapEndianness(big.NewInt(0).Bytes())
|
||||
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||
v := BigIntToBytes(big.NewInt(0))
|
||||
if err := tree1.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -116,8 +116,8 @@ func TestAddDifferentOrder(t *testing.T) {
|
||||
defer tree2.db.Close()
|
||||
|
||||
for i := 16 - 1; i >= 0; i-- {
|
||||
k := big.NewInt(int64(i)).Bytes()
|
||||
v := big.NewInt(0).Bytes()
|
||||
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||
v := BigIntToBytes(big.NewInt(0))
|
||||
if err := tree2.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -134,8 +134,8 @@ func TestAddRepeatedIndex(t *testing.T) {
|
||||
c.Assert(err, qt.IsNil)
|
||||
defer tree.db.Close()
|
||||
|
||||
k := big.NewInt(int64(3)).Bytes()
|
||||
v := big.NewInt(int64(12)).Bytes()
|
||||
k := BigIntToBytes(big.NewInt(int64(3)))
|
||||
v := BigIntToBytes(big.NewInt(int64(12)))
|
||||
if err := tree.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -144,6 +144,56 @@ func TestAddRepeatedIndex(t *testing.T) {
|
||||
c.Check(err, qt.ErrorMatches, "max virtual level 100")
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
c.Assert(err, qt.IsNil)
|
||||
defer tree.db.Close()
|
||||
|
||||
k := BigIntToBytes(big.NewInt(int64(20)))
|
||||
v := BigIntToBytes(big.NewInt(int64(12)))
|
||||
if err := tree.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
v = BigIntToBytes(big.NewInt(int64(11)))
|
||||
err = tree.Update(k, v)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
gettedKey, gettedValue, err := tree.Get(k)
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Check(gettedKey, qt.DeepEquals, k)
|
||||
c.Check(gettedValue, qt.DeepEquals, v)
|
||||
|
||||
// add more leafs to the tree to do another test
|
||||
for i := 0; i < 16; i++ {
|
||||
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||
v := BigIntToBytes(big.NewInt(int64(i * 2)))
|
||||
if err := tree.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
k = BigIntToBytes(big.NewInt(int64(3)))
|
||||
v = BigIntToBytes(big.NewInt(int64(11)))
|
||||
// check that before the Update, value for 3 is !=11
|
||||
gettedKey, gettedValue, err = tree.Get(k)
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Check(gettedKey, qt.DeepEquals, k)
|
||||
c.Check(gettedValue, qt.Not(qt.DeepEquals), v)
|
||||
c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(6)))
|
||||
|
||||
err = tree.Update(k, v)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
// check that after Update, the value for 3 is ==11
|
||||
gettedKey, gettedValue, err = tree.Get(k)
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Check(gettedKey, qt.DeepEquals, k)
|
||||
c.Check(gettedValue, qt.DeepEquals, v)
|
||||
c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11)))
|
||||
}
|
||||
|
||||
func TestAux(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
|
||||
Reference in New Issue
Block a user