diff --git a/addbatch_test.go b/addbatch_test.go index ae45417..d02fd10 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -914,7 +914,9 @@ func TestLoadVT(t *testing.T) { c.Assert(err, qt.IsNil) // check that tree & vt roots are equal - c.Check(tree.Root(), qt.DeepEquals, vt.root.h) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Check(root, qt.DeepEquals, vt.root.h) } // TestAddKeysWithEmptyValues calls AddBatch giving an array of empty values @@ -976,12 +978,14 @@ func TestAddKeysWithEmptyValues(t *testing.T) { c.Assert(existence, qt.IsTrue) // check with empty array - verif, err := CheckProof(tree.hashFunction, keys[9], []byte{}, tree.Root(), siblings) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + verif, err := CheckProof(tree.hashFunction, keys[9], []byte{}, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) // check with array with only 1 zero - verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0}, tree.Root(), siblings) + verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0}, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) @@ -989,12 +993,12 @@ func TestAddKeysWithEmptyValues(t *testing.T) { e32 := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} c.Assert(len(e32), qt.Equals, 32) - verif, err = CheckProof(tree.hashFunction, keys[9], e32, tree.Root(), siblings) + verif, err = CheckProof(tree.hashFunction, keys[9], e32, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) // check with array with value!=0 returns false at verification - verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, tree.Root(), siblings) + verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsFalse) } diff --git a/circomproofs.go b/circomproofs.go index 9a5f8e1..feb0b3a 100644 --- a/circomproofs.go +++ b/circomproofs.go @@ -63,7 +63,10 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro return nil, err } var cp CircomVerifierProof - cp.Root = t.Root() + cp.Root, err = t.Root() + if err != nil { + return nil, err + } s, err := UnpackSiblings(t.hashFunction, siblings) if err != nil { return nil, err diff --git a/helpers_test.go b/helpers_test.go index 0094a61..7382225 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -13,7 +13,11 @@ import ( ) func checkRoots(c *qt.C, tree1, tree2 *Tree) { - if !bytes.Equal(tree2.Root(), tree1.Root()) { + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + if !bytes.Equal(root2, root1) { dir := "err-dump" if _, err := os.Stat(dir); os.IsNotExist(err) { err := os.Mkdir(dir, os.ModePerm) @@ -25,7 +29,11 @@ func checkRoots(c *qt.C, tree1, tree2 *Tree) { // store tree2 storeTree(c, tree2, dir+"/tree2") - c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(root2, qt.DeepEquals, root1) } } @@ -103,5 +111,9 @@ func TestReadTreeDBG(t *testing.T) { // tree1.PrintGraphvizFirstNLevels(nil, 6) // tree2.PrintGraphvizFirstNLevels(nil, 6) - c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(root2, qt.DeepEquals, root1) } diff --git a/tree.go b/tree.go index 01d673d..5719374 100644 --- a/tree.go +++ b/tree.go @@ -63,15 +63,18 @@ var ( // ErrMaxVirtualLevel indicates when going down into the tree, the max // virtual level is reached ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached") + // ErrSnapshotNotEditable indicates when the tree is a snapshot, thus + // can not be modified + ErrSnapshotNotEditable = fmt.Errorf("snapshot tree can not be edited") ) // Tree defines the struct that implements the MerkleTree functionalities type Tree struct { sync.RWMutex - db db.Database - maxLevels int - root []byte + db db.Database + maxLevels int + snapshotRoot []byte hashFunction HashFunction // TODO in the methods that use it, check if emptyHash param is len>0 @@ -90,11 +93,10 @@ func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, err wTx := t.db.WriteTx() defer wTx.Discard() - root, err := wTx.Get(dbKeyRoot) + _, err := wTx.Get(dbKeyRoot) if err == db.ErrKeyNotFound { - // store new root 0 - t.root = t.emptyHash - if err = wTx.Set(dbKeyRoot, t.root); err != nil { + // store new root 0 (empty) + if err = wTx.Set(dbKeyRoot, t.emptyHash); err != nil { return nil, err } if err = t.setNLeafs(wTx, 0); err != nil { @@ -111,14 +113,30 @@ func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, err if err = wTx.Commit(); err != nil { return nil, err } - t.root = root return &t, nil } // Root returns the root of the Tree -func (t *Tree) Root() []byte { - // TODO get Root from db - return t.root +func (t *Tree) Root() ([]byte, error) { + rTx := t.db.ReadTx() + defer rTx.Discard() + return t.RootWithTx(rTx) +} + +// RootWithTx returns the root of the Tree using the given db.ReadTx +func (t *Tree) RootWithTx(rTx db.ReadTx) ([]byte, error) { + // if snapshotRoot is defined, means that the tree is a snapshot, and + // the root is not obtained from the db, but from the snapshotRoot + // parameter + if t.snapshotRoot != nil { + return t.snapshotRoot, nil + } + // get db root + return rTx.Get(dbKeyRoot) +} + +func (t *Tree) setRoot(wTx db.WriteTx, root []byte) error { + return wTx.Set(dbKeyRoot, root) } // HashFunction returns Tree.hashFunction @@ -126,6 +144,12 @@ func (t *Tree) HashFunction() HashFunction { return t.hashFunction } +// editable returns true if the tree is editable, and false when is not +// editable (because is a snapshot tree) +func (t *Tree) editable() bool { + return t.snapshotRoot == nil +} + // AddBatch adds a batch of key-values to the Tree. Returns an array containing // the indexes of the keys failed to add. Supports empty values as input // parameters, which is equivalent to 0 valued byte array. @@ -147,6 +171,10 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err t.Lock() defer t.Unlock() + if !t.editable() { + return nil, ErrSnapshotNotEditable + } + vt, err := t.loadVT(wTx) if err != nil { return nil, err @@ -177,7 +205,6 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err // nothing stored in the db and the error is returned return nil, err } - t.root = vt.root.h // store pairs in db for i := 0; i < len(pairs); i++ { @@ -186,8 +213,8 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err } } - // store root to db - if err := wTx.Set(dbKeyRoot, t.root); err != nil { + // store root (from the vt) to db + if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil { return nil, err } @@ -239,12 +266,21 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { t.Lock() defer t.Unlock() - err := t.add(wTx, 0, k, v) // add from level 0 + if !t.editable() { + return ErrSnapshotNotEditable + } + + root, err := t.RootWithTx(wTx) + if err != nil { + return err + } + + root, err = t.add(wTx, root, 0, k, v) // add from level 0 if err != nil { return err } // store root to db - if err := wTx.Set(dbKeyRoot, t.root); err != nil { + if err := t.setRoot(wTx, root); err != nil { return err } // update nLeafs @@ -254,39 +290,38 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { return nil } -func (t *Tree) add(wTx db.WriteTx, fromLvl int, k, v []byte) error { +func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(wTx, k, t.root, siblings, path, fromLvl, false) + _, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false) if err != nil { - return err + return nil, err } leafKey, leafValue, err := t.newLeafValue(k, v) if err != nil { - return err + return nil, err } if err := wTx.Set(leafKey, leafValue); err != nil { - return err + return nil, err } // go up to the root if len(siblings) == 0 { - t.root = leafKey - return nil + // return the leafKey as root + return leafKey, nil } - root, err := t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl) + root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl) if err != nil { - return err + return nil, err } - t.root = root - return nil + return root, nil } // down goes down to the leaf recursively @@ -521,14 +556,23 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { t.Lock() defer t.Unlock() + if !t.editable() { + return ErrSnapshotNotEditable + } + var err error keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) + root, err := t.RootWithTx(wTx) + if err != nil { + return err + } + var siblings [][]byte - _, valueAtBottom, siblings, err := t.down(wTx, k, t.root, siblings, path, 0, true) + _, valueAtBottom, siblings, err := t.down(wTx, k, root, siblings, path, 0, true) if err != nil { return err } @@ -548,17 +592,15 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { // go up to the root if len(siblings) == 0 { - t.root = leafKey - return nil + return t.setRoot(wTx, leafKey) } - root, err := t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0) + root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0) if err != nil { return err } - t.root = root // store root to db - if err := wTx.Set(dbKeyRoot, t.root); err != nil { + if err := t.setRoot(wTx, root); err != nil { return err } return nil @@ -580,10 +622,15 @@ func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) + root, err := t.RootWithTx(rTx) + if err != nil { + return nil, nil, nil, false, err + } + path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, value, siblings, err := t.down(rTx, k, t.root, siblings, path, 0, true) + _, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true) if err != nil { return nil, nil, nil, false, err } @@ -672,7 +719,10 @@ func bytesToBitmap(b []byte) []bool { return bitmap } -// Get returns the value for a given key +// Get returns the value for a given key. If the key is not found, will return +// the error ErrKeyNotFound, and in the leafK & leafV parameters will be placed +// the data found in the tree in the leaf that was on the path going to the +// input key. func (t *Tree) Get(k []byte) ([]byte, []byte, error) { rTx := t.db.ReadTx() defer rTx.Discard() @@ -681,22 +731,28 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { } // GetWithTx does the same than the Get method, but allowing to pass the -// db.ReadTx that is used. +// db.ReadTx that is used. If the key is not found, will return the error +// ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data +// found in the tree in the leaf that was on the path going to the input key. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) { keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) + root, err := t.RootWithTx(rTx) + if err != nil { + return nil, nil, err + } + path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, value, _, err := t.down(rTx, k, t.root, siblings, path, 0, true) + _, value, _, err := t.down(rTx, k, root, siblings, path, 0, true) if err != nil { return nil, nil, err } leafK, leafV := ReadLeafValue(value) if !bytes.Equal(k, leafK) { - return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s", - BytesToBigInt(k), BytesToBigInt(leafK)) + return leafK, leafV, ErrKeyNotFound } return leafK, leafV, nil @@ -775,25 +831,23 @@ func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) { return int(nLeafs), nil } -// Snapshot returns a copy of the Tree from the given root -func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) { - // TODO currently this method only changes the 'root pointer', but the - // db continues being the same. In a future iteration, once the - // db.Database interface allows to do database checkpoints, this method - // could be updated to do a full checkpoint of the database for the - // snapshot, to return a completly new independent tree containing the - // snapshot. +// Snapshot returns a read-only copy of the Tree from the given root +func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) { t.RLock() defer t.RUnlock() // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return nil, err + } } return &Tree{ db: t.db, maxLevels: t.maxLevels, - root: rootKey, + snapshotRoot: fromRoot, hashFunction: t.hashFunction, dbg: t.dbg, }, nil @@ -801,45 +855,58 @@ func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) { // Iterate iterates through the full Tree, executing the given function on each // node of the Tree. -func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error { +func (t *Tree) Iterate(fromRoot []byte, f func([]byte, []byte)) error { rTx := t.db.ReadTx() defer rTx.Discard() - return t.IterateWithTx(rTx, rootKey, f) + return t.IterateWithTx(rTx, fromRoot, f) } // IterateWithTx does the same than the Iterate method, but allowing to pass // the db.ReadTx that is used. -func (t *Tree) IterateWithTx(rTx db.ReadTx, rootKey []byte, f func([]byte, []byte)) error { +func (t *Tree) IterateWithTx(rTx db.ReadTx, fromRoot []byte, f func([]byte, []byte)) error { // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } } - return t.iter(rTx, rootKey, f) + return t.iter(rTx, fromRoot, f) } // IterateWithStop does the same than Iterate, but with int for the current // level, and a boolean parameter used by the passed function, is to indicate to // stop iterating on the branch when the method returns 'true'. -func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error { - // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() - } +func (t *Tree) IterateWithStop(fromRoot []byte, f func(int, []byte, []byte) bool) error { rTx := t.db.ReadTx() defer rTx.Discard() - return t.iterWithStop(rTx, rootKey, 0, f) + + // allow to define which root to use + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } + } + return t.iterWithStop(rTx, fromRoot, 0, f) } // IterateWithStopWithTx does the same than the IterateWithStop method, but // allowing to pass the db.ReadTx that is used. -func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, rootKey []byte, +func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, fromRoot []byte, f func(int, []byte, []byte) bool) error { // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } } - return t.iterWithStop(rTx, rootKey, 0, f) + return t.iterWithStop(rTx, fromRoot, 0, f) } func (t *Tree) iterWithStop(rTx db.ReadTx, k []byte, currLevel int, @@ -892,16 +959,20 @@ func (t *Tree) iter(rTx db.ReadTx, k []byte, f func([]byte, []byte)) error { // [ 1 byte | 1 byte | S bytes | len(v) bytes ] // [ len(k) | len(v) | key | value ] // Where S is the size of the output of the hash function used for the Tree. -func (t *Tree) Dump(rootKey []byte) ([]byte, error) { +func (t *Tree) Dump(fromRoot []byte) ([]byte, error) { // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return nil, err + } } // WARNING current encoding only supports key & values of 255 bytes each // (due using only 1 byte for the length headers). var b []byte - err := t.Iterate(rootKey, func(k, v []byte) { + err := t.Iterate(fromRoot, func(k, v []byte) { if v[0] != PrefixValueLeaf { return } @@ -919,6 +990,10 @@ func (t *Tree) Dump(rootKey []byte) ([]byte, error) { // ImportDump imports the leafs (that have been exported with the Dump method) // in the Tree. func (t *Tree) ImportDump(b []byte) error { + if !t.editable() { + return ErrSnapshotNotEditable + } + r := bytes.NewReader(b) var err error var keys, values [][]byte @@ -951,25 +1026,31 @@ func (t *Tree) ImportDump(b []byte) error { // Graphviz iterates across the full tree to generate a string Graphviz // representation of the tree and writes it to w -func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error { - return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels) +func (t *Tree) Graphviz(w io.Writer, fromRoot []byte) error { + return t.GraphvizFirstNLevels(w, fromRoot, t.maxLevels) } // GraphvizFirstNLevels iterates across the first NLevels of the tree to // generate a string Graphviz representation of the first NLevels of the tree // and writes it to w -func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error { +func (t *Tree) GraphvizFirstNLevels(w io.Writer, fromRoot []byte, untilLvl int) error { fmt.Fprintf(w, `digraph hierarchy { node [fontname=Monospace,fontsize=10,shape=box] `) - if rootKey == nil { - rootKey = t.Root() - } + rTx := t.db.ReadTx() defer rTx.Discard() + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } + } + nEmpties := 0 - err := t.iterWithStop(rTx, rootKey, 0, func(currLvl int, k, v []byte) bool { + err := t.iterWithStop(rTx, fromRoot, 0, func(currLvl int, k, v []byte) bool { if currLvl == untilLvl { return true // to stop the iter from going down } @@ -1013,28 +1094,36 @@ node [fontname=Monospace,fontsize=10,shape=box] } // PrintGraphviz prints the output of Tree.Graphviz -func (t *Tree) PrintGraphviz(rootKey []byte) error { - if rootKey == nil { - rootKey = t.Root() +func (t *Tree) PrintGraphviz(fromRoot []byte) error { + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return err + } } - return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels) + return t.PrintGraphvizFirstNLevels(fromRoot, t.maxLevels) } // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels -func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error { - if rootKey == nil { - rootKey = t.Root() +func (t *Tree) PrintGraphvizFirstNLevels(fromRoot []byte, untilLvl int) error { + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return err + } } w := bytes.NewBufferString("") fmt.Fprintf(w, - "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n") - err := t.GraphvizFirstNLevels(w, rootKey, untilLvl) + "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+":\n") + err := t.GraphvizFirstNLevels(w, fromRoot, untilLvl) if err != nil { fmt.Println(w) return err } fmt.Fprintf(w, - "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n") + "End of Graphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+"\n--------\n") fmt.Println(w) return nil diff --git a/tree_test.go b/tree_test.go index 769f6fb..7cdb7e3 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,13 @@ import ( "go.vocdoni.io/dvote/db/badgerdb" ) +func checkRootBIString(c *qt.C, tree *Tree, expected string) { + root, err := tree.Root() + c.Assert(err, qt.IsNil) + rootBI := BytesToBigInt(root) + c.Check(rootBI.String(), qt.Equals, expected) +} + func TestDBTx(t *testing.T) { c := qt.New(t) @@ -57,29 +64,28 @@ func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - c.Check(hex.EncodeToString(tree.Root()), qt.Equals, testVectors[0]) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Check(hex.EncodeToString(root), qt.Equals, testVectors[0]) bLen := hashFunc.Len() err = tree.Add( BigIntToBytes(bLen, big.NewInt(1)), BigIntToBytes(bLen, big.NewInt(2))) c.Assert(err, qt.IsNil) - rootBI := BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, testVectors[1]) + checkRootBIString(c, tree, testVectors[1]) err = tree.Add( BigIntToBytes(bLen, big.NewInt(33)), BigIntToBytes(bLen, big.NewInt(44))) c.Assert(err, qt.IsNil) - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, testVectors[2]) + checkRootBIString(c, tree, testVectors[2]) err = tree.Add( BigIntToBytes(bLen, big.NewInt(1234)), BigIntToBytes(bLen, big.NewInt(9876))) c.Assert(err, qt.IsNil) - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, testVectors[3]) + checkRootBIString(c, tree, testVectors[3]) } func TestAddBatch(t *testing.T) { @@ -99,8 +105,7 @@ func TestAddBatch(t *testing.T) { } } - rootBI := BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, tree, "296519252211642170490407814696803112091039265640052570497930797516015811235") database, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) @@ -120,8 +125,7 @@ func TestAddBatch(t *testing.T) { c.Assert(err, qt.IsNil) c.Check(len(indexes), qt.Equals, 0) - rootBI = BytesToBigInt(tree2.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, tree2, "296519252211642170490407814696803112091039265640052570497930797516015811235") } @@ -156,8 +160,12 @@ func TestAddDifferentOrder(t *testing.T) { } } - c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, hex.EncodeToString(tree1.Root())) - c.Check(hex.EncodeToString(tree1.Root()), qt.Equals, + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(hex.EncodeToString(root2), qt.Equals, hex.EncodeToString(root1)) + c.Check(hex.EncodeToString(root1), qt.Equals, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f") } @@ -320,7 +328,9 @@ func TestGenProofAndVerify(t *testing.T) { c.Assert(k, qt.DeepEquals, kAux) c.Assert(existence, qt.IsTrue) - verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + verif, err := CheckProof(tree.hashFunction, k, v, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) } @@ -352,8 +362,13 @@ func TestDumpAndImportDump(t *testing.T) { defer tree2.db.Close() //nolint:errcheck err = tree2.ImportDump(e) c.Assert(err, qt.IsNil) - c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) - c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, + + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(root2, qt.DeepEquals, root1) + c.Check(hex.EncodeToString(root2), qt.Equals, "0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08") } @@ -458,40 +473,34 @@ func TestSnapshot(t *testing.T) { indexes, err := tree.AddBatch(keys, values) c.Assert(err, qt.IsNil) c.Check(len(indexes), qt.Equals, 0) - - rootBI := BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, tree, "13742386369878513332697380582061714160370929283209286127733983161245560237407") - tree2, err := tree.Snapshot(nil) + // do a snapshot, and expect the same root than the original tree + snapshotTree, err := tree.Snapshot(nil) c.Assert(err, qt.IsNil) - rootBI = BytesToBigInt(tree2.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, snapshotTree, "13742386369878513332697380582061714160370929283209286127733983161245560237407") - // add k-v to original tree - k := BigIntToBytes(bLen, big.NewInt(int64(1000))) - v := BigIntToBytes(bLen, big.NewInt(int64(1000))) - err = tree.Add(k, v) - c.Assert(err, qt.IsNil) - - // expect original tree to have new root - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, - "10747149055773881257049574592162159501044114324358186833013814735296193179713") - - // expect snapshot tree to have the old root - rootBI = BytesToBigInt(tree2.Root()) - c.Check(rootBI.String(), qt.Equals, + // check that the snapshotTree can not be updated + _, err = snapshotTree.AddBatch(keys, values) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + err = snapshotTree.Add([]byte("test"), []byte("test")) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + err = snapshotTree.Update([]byte("test"), []byte("test")) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + err = snapshotTree.ImportDump(nil) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + + // update the original tree by adding a new key-value, and check that + // snapshotTree still has the old root, but the original tree has a new + // root + err = tree.Add([]byte("test"), []byte("test")) + c.Assert(err, qt.IsNil) + checkRootBIString(c, snapshotTree, "13742386369878513332697380582061714160370929283209286127733983161245560237407") - - err = tree2.Add(k, v) - c.Assert(err, qt.IsNil) - // after adding also the k-v into the snapshot tree, expect original - // tree to have new root - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, - "10747149055773881257049574592162159501044114324358186833013814735296193179713") + checkRootBIString(c, tree, + "1025190963769001718196479367844646783678188389989148142691917685159698888868") } func BenchmarkAdd(b *testing.B) { diff --git a/vt_test.go b/vt_test.go index 6b607e5..6eba30d 100644 --- a/vt_test.go +++ b/vt_test.go @@ -98,7 +98,9 @@ func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { // compute hashes, and check Root _, err = vTree.computeHashes() c.Assert(err, qt.IsNil) - c.Assert(vTree.root.h, qt.DeepEquals, tree.root) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, root) } func TestVirtualTreeAddBatch(t *testing.T) { @@ -136,7 +138,9 @@ func TestVirtualTreeAddBatch(t *testing.T) { // compute hashes, and check Root _, err = vTree.computeHashes() c.Assert(err, qt.IsNil) - c.Assert(vTree.root.h, qt.DeepEquals, tree.root) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, root) } func TestGetNodesAtLevel(t *testing.T) {