diff --git a/tree.go b/tree.go index 394165f..23d80bc 100644 --- a/tree.go +++ b/tree.go @@ -13,6 +13,7 @@ package arbo import ( "bytes" + "encoding/hex" "fmt" "io" "math" @@ -371,6 +372,57 @@ func getPath(numLevels int, k []byte) []bool { return path } +// Update updates the value for a given existing key. If the given key does not +// exist, returns an error. +func (t *Tree) Update(k, v []byte) error { + t.updateAccessTime() + + tx, err := t.db.NewTx() + if err != nil { + return err + } + + keyPath := make([]byte, t.hashFunction.Len()) + copy(keyPath[:], k) + path := getPath(t.maxLevels, keyPath) + + var siblings [][]byte + _, valueAtBottom, siblings, err := t.down(tx, k, t.root, siblings, path, 0, true) + if err != nil { + return err + } + oldKey, _ := readLeafValue(valueAtBottom) + if !bytes.Equal(oldKey, k) { + return fmt.Errorf("key %s does not exist", hex.EncodeToString(k)) + } + + leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v) + if err != nil { + return err + } + + if err := tx.Put(leafKey, leafValue); err != nil { + return err + } + + // go up to the root + if len(siblings) == 0 { + t.root = leafKey + return tx.Commit() + } + root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1) + if err != nil { + return err + } + + t.root = root + // store root to db + if err := tx.Put(dbKeyRoot, t.root); err != nil { + return err + } + return tx.Commit() +} + // GenProof generates a MerkleTree proof for the given key. If the key exists in // the Tree, the proof will be of existence, if the key does not exist in the // tree, the proof will be of non-existence. diff --git a/tree_test.go b/tree_test.go index 7444836..bb67954 100644 --- a/tree_test.go +++ b/tree_test.go @@ -104,8 +104,8 @@ func TestAddDifferentOrder(t *testing.T) { defer tree1.db.Close() for i := 0; i < 16; i++ { - k := SwapEndianness(big.NewInt(int64(i)).Bytes()) - v := SwapEndianness(big.NewInt(0).Bytes()) + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(0)) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } @@ -116,8 +116,8 @@ func TestAddDifferentOrder(t *testing.T) { defer tree2.db.Close() for i := 16 - 1; i >= 0; i-- { - k := big.NewInt(int64(i)).Bytes() - v := big.NewInt(0).Bytes() + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(0)) if err := tree2.Add(k, v); err != nil { t.Fatal(err) } @@ -134,8 +134,8 @@ func TestAddRepeatedIndex(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() - k := big.NewInt(int64(3)).Bytes() - v := big.NewInt(int64(12)).Bytes() + k := BigIntToBytes(big.NewInt(int64(3))) + v := BigIntToBytes(big.NewInt(int64(12))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } @@ -144,6 +144,56 @@ func TestAddRepeatedIndex(t *testing.T) { c.Check(err, qt.ErrorMatches, "max virtual level 100") } +func TestUpdate(t *testing.T) { + c := qt.New(t) + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + k := BigIntToBytes(big.NewInt(int64(20))) + v := BigIntToBytes(big.NewInt(int64(12))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + + v = BigIntToBytes(big.NewInt(int64(11))) + err = tree.Update(k, v) + c.Assert(err, qt.IsNil) + + gettedKey, gettedValue, err := tree.Get(k) + c.Assert(err, qt.IsNil) + c.Check(gettedKey, qt.DeepEquals, k) + c.Check(gettedValue, qt.DeepEquals, v) + + // add more leafs to the tree to do another test + for i := 0; i < 16; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + } + + k = BigIntToBytes(big.NewInt(int64(3))) + v = BigIntToBytes(big.NewInt(int64(11))) + // check that before the Update, value for 3 is !=11 + gettedKey, gettedValue, err = tree.Get(k) + c.Assert(err, qt.IsNil) + c.Check(gettedKey, qt.DeepEquals, k) + c.Check(gettedValue, qt.Not(qt.DeepEquals), v) + c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(6))) + + err = tree.Update(k, v) + c.Assert(err, qt.IsNil) + + // check that after Update, the value for 3 is ==11 + gettedKey, gettedValue, err = tree.Get(k) + c.Assert(err, qt.IsNil) + c.Check(gettedKey, qt.DeepEquals, k) + c.Check(gettedValue, qt.DeepEquals, v) + c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11))) +} + func TestAux(t *testing.T) { c := qt.New(t) tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)