diff --git a/tree.go b/tree.go index b11aa93..a594b30 100644 --- a/tree.go +++ b/tree.go @@ -63,8 +63,8 @@ var ( // ErrMaxVirtualLevel indicates when going down into the tree, the max // virtual level is reached ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached") - // ErrSnapshotNotEditable indicates when the tree is a snapshot, thus - // can not be modified + // ErrSnapshotNotEditable indicates when the tree is a non writable + // snapshot, thus can not be modified ErrSnapshotNotEditable = fmt.Errorf("snapshot tree can not be edited") // ErrTreeNotEmpty indicates when the tree was expected to be empty and // it is not @@ -741,10 +741,10 @@ func bytesToBitmap(b []byte) []bool { return bitmap } -// Get returns the value for a given key. If the key is not found, will return -// the error ErrKeyNotFound, and in the leafK & leafV parameters will be placed -// the data found in the tree in the leaf that was on the path going to the -// input key. +// Get returns the value in the Tree for a given key. If the key is not found, +// will return the error ErrKeyNotFound, and in the leafK & leafV parameters +// will be placed the data found in the tree in the leaf that was on the path +// going to the input key. func (t *Tree) Get(k []byte) ([]byte, []byte, error) { rTx := t.db.ReadTx() defer rTx.Discard() @@ -853,6 +853,39 @@ func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) { return int(nLeafs), nil } +// SetRoot sets the root to the given root +func (t *Tree) SetRoot(root []byte) error { + wTx := t.db.WriteTx() + defer wTx.Discard() + + if err := t.SetRootWithTx(wTx, root); err != nil { + return err + } + return wTx.Commit() +} + +// SetRootWithTx sets the root to the given root using the given db.WriteTx +func (t *Tree) SetRootWithTx(wTx db.WriteTx, root []byte) error { + if !t.editable() { + return ErrSnapshotNotEditable + } + + if root == nil { + return fmt.Errorf("can not SetRoot with nil root") + } + + // check that the root exists in the db + if !bytes.Equal(root, t.emptyHash) { + if _, err := wTx.Get(root); err == ErrKeyNotFound { + return fmt.Errorf("can not SetRoot with root %x, as it does not exist in the db", root) + } else if err != nil { + return err + } + } + + return wTx.Set(dbKeyRoot, root) +} + // Snapshot returns a read-only copy of the Tree from the given root func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) { // allow to define which root to use @@ -863,6 +896,19 @@ func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) { return nil, err } } + rTx := t.db.ReadTx() + defer rTx.Discard() + // check that the root exists in the db + if !bytes.Equal(fromRoot, t.emptyHash) { + if _, err := rTx.Get(fromRoot); err == ErrKeyNotFound { + return nil, + fmt.Errorf("can not do a Snapshot with root %x, as it does not exist in the db", + fromRoot) + } else if err != nil { + return nil, err + } + } + return &Tree{ db: t.db, maxLevels: t.maxLevels, diff --git a/tree_test.go b/tree_test.go index d0ce615..c794377 100644 --- a/tree_test.go +++ b/tree_test.go @@ -514,6 +514,62 @@ func TestAddBatchFullyUsed(t *testing.T) { c.Assert(1, qt.Equals, len(invalids)) } +func TestSetRoot(t *testing.T) { + c := qt.New(t) + database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) + c.Assert(err, qt.IsNil) + tree, err := NewTree(database, 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + + expectedRoot := "13742386369878513332697380582061714160370929283209286127733983161245560237407" + + // fill the tree + bLen := tree.HashFunction().Len() + var keys, values [][]byte + for i := 0; i < 1000; i++ { + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i))) + keys = append(keys, k) + values = append(values, v) + } + indexes, err := tree.AddBatch(keys, values) + c.Assert(err, qt.IsNil) + c.Check(len(indexes), qt.Equals, 0) + checkRootBIString(c, tree, + expectedRoot) + + // add one more k-v + k := BigIntToBytes(bLen, big.NewInt(1000)) + v := BigIntToBytes(bLen, big.NewInt(1000)) + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) + checkRootBIString(c, tree, + "10747149055773881257049574592162159501044114324358186833013814735296193179713") + + // do a SetRoot, and expect the same root than the original tree + pastRootBI, ok := new(big.Int).SetString(expectedRoot, 10) + c.Assert(ok, qt.IsTrue) + pastRoot := BigIntToBytes(32, pastRootBI) + + err = tree.SetRoot(pastRoot) + c.Assert(err, qt.IsNil) + checkRootBIString(c, tree, expectedRoot) + + // check that the tree can be updated + err = tree.Add([]byte("test"), []byte("test")) + c.Assert(err, qt.IsNil) + err = tree.Update([]byte("test"), []byte("test")) + c.Assert(err, qt.IsNil) + + // check that the k-v '1000' does not exist in the new tree + _, _, err = tree.Get(k) + c.Assert(err, qt.Equals, ErrKeyNotFound) + + // check that can be set an empty root + err = tree.SetRoot(tree.emptyHash) + c.Assert(err, qt.IsNil) +} + func TestSnapshot(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})