diff --git a/tree.go b/tree.go index 843f4f8..58c7e83 100644 --- a/tree.go +++ b/tree.go @@ -756,6 +756,30 @@ func (t *Tree) GetNLeafs() (int, error) { return int(nLeafs), nil } +// Snapshot returns a copy of the Tree from the given root +func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) { + // TODO currently this method only changes the 'root pointer', but the + // db continues being the same. In a future iteration, once the + // db.Database interface allows to do database checkpoints, this method + // could be updated to do a full checkpoint of the database for the + // snapshot, to return a completly new independent tree containing the + // snapshot. + t.RLock() + defer t.RUnlock() + + // allow to define which root to use + if rootKey == nil { + rootKey = t.Root() + } + return &Tree{ + db: t.db, + maxLevels: t.maxLevels, + root: rootKey, + hashFunction: t.hashFunction, + dbg: t.dbg, + }, nil +} + // Iterate iterates through the full Tree, executing the given function on each // node of the Tree. func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error { @@ -844,8 +868,8 @@ func (t *Tree) Dump(rootKey []byte) ([]byte, error) { return b, err } -// ImportDump imports the leafs (that have been exported with the ExportLeafs -// method) in the Tree. +// ImportDump imports the leafs (that have been exported with the Dump method) +// in the Tree. func (t *Tree) ImportDump(b []byte) error { r := bytes.NewReader(b) var err error diff --git a/tree_test.go b/tree_test.go index d8c21ae..c7dc45b 100644 --- a/tree_test.go +++ b/tree_test.go @@ -427,6 +427,61 @@ func TestSetGetNLeafs(t *testing.T) { c.Assert(n, qt.Equals, maxInt) } +func TestSnapshot(t *testing.T) { + c := qt.New(t) + database, err := db.NewBadgerDB(c.TempDir()) + c.Assert(err, qt.IsNil) + tree, err := NewTree(database, 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + + // fill the tree + bLen := tree.HashFunction().Len() + var keys, values [][]byte + for i := 0; i < 1000; i++ { + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i))) + keys = append(keys, k) + values = append(values, v) + } + indexes, err := tree.AddBatch(keys, values) + c.Assert(err, qt.IsNil) + c.Check(len(indexes), qt.Equals, 0) + + rootBI := BytesToBigInt(tree.Root()) + c.Check(rootBI.String(), qt.Equals, + "13742386369878513332697380582061714160370929283209286127733983161245560237407") + + tree2, err := tree.Snapshot(nil) + c.Assert(err, qt.IsNil) + rootBI = BytesToBigInt(tree2.Root()) + c.Check(rootBI.String(), qt.Equals, + "13742386369878513332697380582061714160370929283209286127733983161245560237407") + + // add k-v to original tree + k := BigIntToBytes(bLen, big.NewInt(int64(1000))) + v := BigIntToBytes(bLen, big.NewInt(int64(1000))) + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) + + // expect original tree to have new root + rootBI = BytesToBigInt(tree.Root()) + c.Check(rootBI.String(), qt.Equals, + "10747149055773881257049574592162159501044114324358186833013814735296193179713") + + // expect snapshot tree to have the old root + rootBI = BytesToBigInt(tree2.Root()) + c.Check(rootBI.String(), qt.Equals, + "13742386369878513332697380582061714160370929283209286127733983161245560237407") + + err = tree2.Add(k, v) + c.Assert(err, qt.IsNil) + // after adding also the k-v into the snapshot tree, expect original + // tree to have new root + rootBI = BytesToBigInt(tree.Root()) + c.Check(rootBI.String(), qt.Equals, + "10747149055773881257049574592162159501044114324358186833013814735296193179713") +} + func BenchmarkAdd(b *testing.B) { bLen := 32 // for both Poseidon & Sha256 // prepare inputs