From 5f6c35e4354c3a48640b39f9e2bc3d63d2f07939 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 11 Aug 2021 19:30:31 +0200 Subject: [PATCH] Update Snapshot & Root approach Update Snapshot & Root approach to get the root always from the db, except in the cases that the tree is a snapshot, in which the root will be in memory. In this way, when a snapshot is performed and the original tree gets modifyed, the snapshot will still point to the old root. Also, the root obtained from the db, uses also the db.ReadTx, so if the root is being modifyied in the current tx (db.WriteTx), when getting the root it will be return the lastest version that is in the tx but not yet in the db. --- addbatch_test.go | 14 ++- circomproofs.go | 5 +- helpers_test.go | 18 +++- tree.go | 263 +++++++++++++++++++++++++++++++---------------- tree_test.go | 97 +++++++++-------- vt_test.go | 8 +- 6 files changed, 263 insertions(+), 142 deletions(-) diff --git a/addbatch_test.go b/addbatch_test.go index ae45417..d02fd10 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -914,7 +914,9 @@ func TestLoadVT(t *testing.T) { c.Assert(err, qt.IsNil) // check that tree & vt roots are equal - c.Check(tree.Root(), qt.DeepEquals, vt.root.h) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Check(root, qt.DeepEquals, vt.root.h) } // TestAddKeysWithEmptyValues calls AddBatch giving an array of empty values @@ -976,12 +978,14 @@ func TestAddKeysWithEmptyValues(t *testing.T) { c.Assert(existence, qt.IsTrue) // check with empty array - verif, err := CheckProof(tree.hashFunction, keys[9], []byte{}, tree.Root(), siblings) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + verif, err := CheckProof(tree.hashFunction, keys[9], []byte{}, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) // check with array with only 1 zero - verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0}, tree.Root(), siblings) + verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0}, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) @@ -989,12 +993,12 @@ func TestAddKeysWithEmptyValues(t *testing.T) { e32 := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} c.Assert(len(e32), qt.Equals, 32) - verif, err = CheckProof(tree.hashFunction, keys[9], e32, tree.Root(), siblings) + verif, err = CheckProof(tree.hashFunction, keys[9], e32, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) // check with array with value!=0 returns false at verification - verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, tree.Root(), siblings) + verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsFalse) } diff --git a/circomproofs.go b/circomproofs.go index 9a5f8e1..feb0b3a 100644 --- a/circomproofs.go +++ b/circomproofs.go @@ -63,7 +63,10 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro return nil, err } var cp CircomVerifierProof - cp.Root = t.Root() + cp.Root, err = t.Root() + if err != nil { + return nil, err + } s, err := UnpackSiblings(t.hashFunction, siblings) if err != nil { return nil, err diff --git a/helpers_test.go b/helpers_test.go index 0094a61..7382225 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -13,7 +13,11 @@ import ( ) func checkRoots(c *qt.C, tree1, tree2 *Tree) { - if !bytes.Equal(tree2.Root(), tree1.Root()) { + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + if !bytes.Equal(root2, root1) { dir := "err-dump" if _, err := os.Stat(dir); os.IsNotExist(err) { err := os.Mkdir(dir, os.ModePerm) @@ -25,7 +29,11 @@ func checkRoots(c *qt.C, tree1, tree2 *Tree) { // store tree2 storeTree(c, tree2, dir+"/tree2") - c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(root2, qt.DeepEquals, root1) } } @@ -103,5 +111,9 @@ func TestReadTreeDBG(t *testing.T) { // tree1.PrintGraphvizFirstNLevels(nil, 6) // tree2.PrintGraphvizFirstNLevels(nil, 6) - c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(root2, qt.DeepEquals, root1) } diff --git a/tree.go b/tree.go index 01d673d..5719374 100644 --- a/tree.go +++ b/tree.go @@ -63,15 +63,18 @@ var ( // ErrMaxVirtualLevel indicates when going down into the tree, the max // virtual level is reached ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached") + // ErrSnapshotNotEditable indicates when the tree is a snapshot, thus + // can not be modified + ErrSnapshotNotEditable = fmt.Errorf("snapshot tree can not be edited") ) // Tree defines the struct that implements the MerkleTree functionalities type Tree struct { sync.RWMutex - db db.Database - maxLevels int - root []byte + db db.Database + maxLevels int + snapshotRoot []byte hashFunction HashFunction // TODO in the methods that use it, check if emptyHash param is len>0 @@ -90,11 +93,10 @@ func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, err wTx := t.db.WriteTx() defer wTx.Discard() - root, err := wTx.Get(dbKeyRoot) + _, err := wTx.Get(dbKeyRoot) if err == db.ErrKeyNotFound { - // store new root 0 - t.root = t.emptyHash - if err = wTx.Set(dbKeyRoot, t.root); err != nil { + // store new root 0 (empty) + if err = wTx.Set(dbKeyRoot, t.emptyHash); err != nil { return nil, err } if err = t.setNLeafs(wTx, 0); err != nil { @@ -111,14 +113,30 @@ func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, err if err = wTx.Commit(); err != nil { return nil, err } - t.root = root return &t, nil } // Root returns the root of the Tree -func (t *Tree) Root() []byte { - // TODO get Root from db - return t.root +func (t *Tree) Root() ([]byte, error) { + rTx := t.db.ReadTx() + defer rTx.Discard() + return t.RootWithTx(rTx) +} + +// RootWithTx returns the root of the Tree using the given db.ReadTx +func (t *Tree) RootWithTx(rTx db.ReadTx) ([]byte, error) { + // if snapshotRoot is defined, means that the tree is a snapshot, and + // the root is not obtained from the db, but from the snapshotRoot + // parameter + if t.snapshotRoot != nil { + return t.snapshotRoot, nil + } + // get db root + return rTx.Get(dbKeyRoot) +} + +func (t *Tree) setRoot(wTx db.WriteTx, root []byte) error { + return wTx.Set(dbKeyRoot, root) } // HashFunction returns Tree.hashFunction @@ -126,6 +144,12 @@ func (t *Tree) HashFunction() HashFunction { return t.hashFunction } +// editable returns true if the tree is editable, and false when is not +// editable (because is a snapshot tree) +func (t *Tree) editable() bool { + return t.snapshotRoot == nil +} + // AddBatch adds a batch of key-values to the Tree. Returns an array containing // the indexes of the keys failed to add. Supports empty values as input // parameters, which is equivalent to 0 valued byte array. @@ -147,6 +171,10 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err t.Lock() defer t.Unlock() + if !t.editable() { + return nil, ErrSnapshotNotEditable + } + vt, err := t.loadVT(wTx) if err != nil { return nil, err @@ -177,7 +205,6 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err // nothing stored in the db and the error is returned return nil, err } - t.root = vt.root.h // store pairs in db for i := 0; i < len(pairs); i++ { @@ -186,8 +213,8 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err } } - // store root to db - if err := wTx.Set(dbKeyRoot, t.root); err != nil { + // store root (from the vt) to db + if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil { return nil, err } @@ -239,12 +266,21 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { t.Lock() defer t.Unlock() - err := t.add(wTx, 0, k, v) // add from level 0 + if !t.editable() { + return ErrSnapshotNotEditable + } + + root, err := t.RootWithTx(wTx) + if err != nil { + return err + } + + root, err = t.add(wTx, root, 0, k, v) // add from level 0 if err != nil { return err } // store root to db - if err := wTx.Set(dbKeyRoot, t.root); err != nil { + if err := t.setRoot(wTx, root); err != nil { return err } // update nLeafs @@ -254,39 +290,38 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { return nil } -func (t *Tree) add(wTx db.WriteTx, fromLvl int, k, v []byte) error { +func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(wTx, k, t.root, siblings, path, fromLvl, false) + _, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false) if err != nil { - return err + return nil, err } leafKey, leafValue, err := t.newLeafValue(k, v) if err != nil { - return err + return nil, err } if err := wTx.Set(leafKey, leafValue); err != nil { - return err + return nil, err } // go up to the root if len(siblings) == 0 { - t.root = leafKey - return nil + // return the leafKey as root + return leafKey, nil } - root, err := t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl) + root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl) if err != nil { - return err + return nil, err } - t.root = root - return nil + return root, nil } // down goes down to the leaf recursively @@ -521,14 +556,23 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { t.Lock() defer t.Unlock() + if !t.editable() { + return ErrSnapshotNotEditable + } + var err error keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) + root, err := t.RootWithTx(wTx) + if err != nil { + return err + } + var siblings [][]byte - _, valueAtBottom, siblings, err := t.down(wTx, k, t.root, siblings, path, 0, true) + _, valueAtBottom, siblings, err := t.down(wTx, k, root, siblings, path, 0, true) if err != nil { return err } @@ -548,17 +592,15 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { // go up to the root if len(siblings) == 0 { - t.root = leafKey - return nil + return t.setRoot(wTx, leafKey) } - root, err := t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0) + root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0) if err != nil { return err } - t.root = root // store root to db - if err := wTx.Set(dbKeyRoot, t.root); err != nil { + if err := t.setRoot(wTx, root); err != nil { return err } return nil @@ -580,10 +622,15 @@ func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) + root, err := t.RootWithTx(rTx) + if err != nil { + return nil, nil, nil, false, err + } + path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, value, siblings, err := t.down(rTx, k, t.root, siblings, path, 0, true) + _, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true) if err != nil { return nil, nil, nil, false, err } @@ -672,7 +719,10 @@ func bytesToBitmap(b []byte) []bool { return bitmap } -// Get returns the value for a given key +// Get returns the value for a given key. If the key is not found, will return +// the error ErrKeyNotFound, and in the leafK & leafV parameters will be placed +// the data found in the tree in the leaf that was on the path going to the +// input key. func (t *Tree) Get(k []byte) ([]byte, []byte, error) { rTx := t.db.ReadTx() defer rTx.Discard() @@ -681,22 +731,28 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { } // GetWithTx does the same than the Get method, but allowing to pass the -// db.ReadTx that is used. +// db.ReadTx that is used. If the key is not found, will return the error +// ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data +// found in the tree in the leaf that was on the path going to the input key. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) { keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) + root, err := t.RootWithTx(rTx) + if err != nil { + return nil, nil, err + } + path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, value, _, err := t.down(rTx, k, t.root, siblings, path, 0, true) + _, value, _, err := t.down(rTx, k, root, siblings, path, 0, true) if err != nil { return nil, nil, err } leafK, leafV := ReadLeafValue(value) if !bytes.Equal(k, leafK) { - return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s", - BytesToBigInt(k), BytesToBigInt(leafK)) + return leafK, leafV, ErrKeyNotFound } return leafK, leafV, nil @@ -775,25 +831,23 @@ func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) { return int(nLeafs), nil } -// Snapshot returns a copy of the Tree from the given root -func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) { - // TODO currently this method only changes the 'root pointer', but the - // db continues being the same. In a future iteration, once the - // db.Database interface allows to do database checkpoints, this method - // could be updated to do a full checkpoint of the database for the - // snapshot, to return a completly new independent tree containing the - // snapshot. +// Snapshot returns a read-only copy of the Tree from the given root +func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) { t.RLock() defer t.RUnlock() // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return nil, err + } } return &Tree{ db: t.db, maxLevels: t.maxLevels, - root: rootKey, + snapshotRoot: fromRoot, hashFunction: t.hashFunction, dbg: t.dbg, }, nil @@ -801,45 +855,58 @@ func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) { // Iterate iterates through the full Tree, executing the given function on each // node of the Tree. -func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error { +func (t *Tree) Iterate(fromRoot []byte, f func([]byte, []byte)) error { rTx := t.db.ReadTx() defer rTx.Discard() - return t.IterateWithTx(rTx, rootKey, f) + return t.IterateWithTx(rTx, fromRoot, f) } // IterateWithTx does the same than the Iterate method, but allowing to pass // the db.ReadTx that is used. -func (t *Tree) IterateWithTx(rTx db.ReadTx, rootKey []byte, f func([]byte, []byte)) error { +func (t *Tree) IterateWithTx(rTx db.ReadTx, fromRoot []byte, f func([]byte, []byte)) error { // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } } - return t.iter(rTx, rootKey, f) + return t.iter(rTx, fromRoot, f) } // IterateWithStop does the same than Iterate, but with int for the current // level, and a boolean parameter used by the passed function, is to indicate to // stop iterating on the branch when the method returns 'true'. -func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error { - // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() - } +func (t *Tree) IterateWithStop(fromRoot []byte, f func(int, []byte, []byte) bool) error { rTx := t.db.ReadTx() defer rTx.Discard() - return t.iterWithStop(rTx, rootKey, 0, f) + + // allow to define which root to use + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } + } + return t.iterWithStop(rTx, fromRoot, 0, f) } // IterateWithStopWithTx does the same than the IterateWithStop method, but // allowing to pass the db.ReadTx that is used. -func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, rootKey []byte, +func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, fromRoot []byte, f func(int, []byte, []byte) bool) error { // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } } - return t.iterWithStop(rTx, rootKey, 0, f) + return t.iterWithStop(rTx, fromRoot, 0, f) } func (t *Tree) iterWithStop(rTx db.ReadTx, k []byte, currLevel int, @@ -892,16 +959,20 @@ func (t *Tree) iter(rTx db.ReadTx, k []byte, f func([]byte, []byte)) error { // [ 1 byte | 1 byte | S bytes | len(v) bytes ] // [ len(k) | len(v) | key | value ] // Where S is the size of the output of the hash function used for the Tree. -func (t *Tree) Dump(rootKey []byte) ([]byte, error) { +func (t *Tree) Dump(fromRoot []byte) ([]byte, error) { // allow to define which root to use - if rootKey == nil { - rootKey = t.Root() + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return nil, err + } } // WARNING current encoding only supports key & values of 255 bytes each // (due using only 1 byte for the length headers). var b []byte - err := t.Iterate(rootKey, func(k, v []byte) { + err := t.Iterate(fromRoot, func(k, v []byte) { if v[0] != PrefixValueLeaf { return } @@ -919,6 +990,10 @@ func (t *Tree) Dump(rootKey []byte) ([]byte, error) { // ImportDump imports the leafs (that have been exported with the Dump method) // in the Tree. func (t *Tree) ImportDump(b []byte) error { + if !t.editable() { + return ErrSnapshotNotEditable + } + r := bytes.NewReader(b) var err error var keys, values [][]byte @@ -951,25 +1026,31 @@ func (t *Tree) ImportDump(b []byte) error { // Graphviz iterates across the full tree to generate a string Graphviz // representation of the tree and writes it to w -func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error { - return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels) +func (t *Tree) Graphviz(w io.Writer, fromRoot []byte) error { + return t.GraphvizFirstNLevels(w, fromRoot, t.maxLevels) } // GraphvizFirstNLevels iterates across the first NLevels of the tree to // generate a string Graphviz representation of the first NLevels of the tree // and writes it to w -func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error { +func (t *Tree) GraphvizFirstNLevels(w io.Writer, fromRoot []byte, untilLvl int) error { fmt.Fprintf(w, `digraph hierarchy { node [fontname=Monospace,fontsize=10,shape=box] `) - if rootKey == nil { - rootKey = t.Root() - } + rTx := t.db.ReadTx() defer rTx.Discard() + if fromRoot == nil { + var err error + fromRoot, err = t.RootWithTx(rTx) + if err != nil { + return err + } + } + nEmpties := 0 - err := t.iterWithStop(rTx, rootKey, 0, func(currLvl int, k, v []byte) bool { + err := t.iterWithStop(rTx, fromRoot, 0, func(currLvl int, k, v []byte) bool { if currLvl == untilLvl { return true // to stop the iter from going down } @@ -1013,28 +1094,36 @@ node [fontname=Monospace,fontsize=10,shape=box] } // PrintGraphviz prints the output of Tree.Graphviz -func (t *Tree) PrintGraphviz(rootKey []byte) error { - if rootKey == nil { - rootKey = t.Root() +func (t *Tree) PrintGraphviz(fromRoot []byte) error { + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return err + } } - return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels) + return t.PrintGraphvizFirstNLevels(fromRoot, t.maxLevels) } // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels -func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error { - if rootKey == nil { - rootKey = t.Root() +func (t *Tree) PrintGraphvizFirstNLevels(fromRoot []byte, untilLvl int) error { + if fromRoot == nil { + var err error + fromRoot, err = t.Root() + if err != nil { + return err + } } w := bytes.NewBufferString("") fmt.Fprintf(w, - "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n") - err := t.GraphvizFirstNLevels(w, rootKey, untilLvl) + "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+":\n") + err := t.GraphvizFirstNLevels(w, fromRoot, untilLvl) if err != nil { fmt.Println(w) return err } fmt.Fprintf(w, - "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n") + "End of Graphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+"\n--------\n") fmt.Println(w) return nil diff --git a/tree_test.go b/tree_test.go index 769f6fb..7cdb7e3 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,13 @@ import ( "go.vocdoni.io/dvote/db/badgerdb" ) +func checkRootBIString(c *qt.C, tree *Tree, expected string) { + root, err := tree.Root() + c.Assert(err, qt.IsNil) + rootBI := BytesToBigInt(root) + c.Check(rootBI.String(), qt.Equals, expected) +} + func TestDBTx(t *testing.T) { c := qt.New(t) @@ -57,29 +64,28 @@ func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - c.Check(hex.EncodeToString(tree.Root()), qt.Equals, testVectors[0]) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Check(hex.EncodeToString(root), qt.Equals, testVectors[0]) bLen := hashFunc.Len() err = tree.Add( BigIntToBytes(bLen, big.NewInt(1)), BigIntToBytes(bLen, big.NewInt(2))) c.Assert(err, qt.IsNil) - rootBI := BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, testVectors[1]) + checkRootBIString(c, tree, testVectors[1]) err = tree.Add( BigIntToBytes(bLen, big.NewInt(33)), BigIntToBytes(bLen, big.NewInt(44))) c.Assert(err, qt.IsNil) - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, testVectors[2]) + checkRootBIString(c, tree, testVectors[2]) err = tree.Add( BigIntToBytes(bLen, big.NewInt(1234)), BigIntToBytes(bLen, big.NewInt(9876))) c.Assert(err, qt.IsNil) - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, testVectors[3]) + checkRootBIString(c, tree, testVectors[3]) } func TestAddBatch(t *testing.T) { @@ -99,8 +105,7 @@ func TestAddBatch(t *testing.T) { } } - rootBI := BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, tree, "296519252211642170490407814696803112091039265640052570497930797516015811235") database, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) @@ -120,8 +125,7 @@ func TestAddBatch(t *testing.T) { c.Assert(err, qt.IsNil) c.Check(len(indexes), qt.Equals, 0) - rootBI = BytesToBigInt(tree2.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, tree2, "296519252211642170490407814696803112091039265640052570497930797516015811235") } @@ -156,8 +160,12 @@ func TestAddDifferentOrder(t *testing.T) { } } - c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, hex.EncodeToString(tree1.Root())) - c.Check(hex.EncodeToString(tree1.Root()), qt.Equals, + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(hex.EncodeToString(root2), qt.Equals, hex.EncodeToString(root1)) + c.Check(hex.EncodeToString(root1), qt.Equals, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f") } @@ -320,7 +328,9 @@ func TestGenProofAndVerify(t *testing.T) { c.Assert(k, qt.DeepEquals, kAux) c.Assert(existence, qt.IsTrue) - verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + verif, err := CheckProof(tree.hashFunction, k, v, root, siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) } @@ -352,8 +362,13 @@ func TestDumpAndImportDump(t *testing.T) { defer tree2.db.Close() //nolint:errcheck err = tree2.ImportDump(e) c.Assert(err, qt.IsNil) - c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) - c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, + + root1, err := tree1.Root() + c.Assert(err, qt.IsNil) + root2, err := tree2.Root() + c.Assert(err, qt.IsNil) + c.Check(root2, qt.DeepEquals, root1) + c.Check(hex.EncodeToString(root2), qt.Equals, "0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08") } @@ -458,40 +473,34 @@ func TestSnapshot(t *testing.T) { indexes, err := tree.AddBatch(keys, values) c.Assert(err, qt.IsNil) c.Check(len(indexes), qt.Equals, 0) - - rootBI := BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, tree, "13742386369878513332697380582061714160370929283209286127733983161245560237407") - tree2, err := tree.Snapshot(nil) + // do a snapshot, and expect the same root than the original tree + snapshotTree, err := tree.Snapshot(nil) c.Assert(err, qt.IsNil) - rootBI = BytesToBigInt(tree2.Root()) - c.Check(rootBI.String(), qt.Equals, + checkRootBIString(c, snapshotTree, "13742386369878513332697380582061714160370929283209286127733983161245560237407") - // add k-v to original tree - k := BigIntToBytes(bLen, big.NewInt(int64(1000))) - v := BigIntToBytes(bLen, big.NewInt(int64(1000))) - err = tree.Add(k, v) - c.Assert(err, qt.IsNil) - - // expect original tree to have new root - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, - "10747149055773881257049574592162159501044114324358186833013814735296193179713") - - // expect snapshot tree to have the old root - rootBI = BytesToBigInt(tree2.Root()) - c.Check(rootBI.String(), qt.Equals, + // check that the snapshotTree can not be updated + _, err = snapshotTree.AddBatch(keys, values) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + err = snapshotTree.Add([]byte("test"), []byte("test")) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + err = snapshotTree.Update([]byte("test"), []byte("test")) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + err = snapshotTree.ImportDump(nil) + c.Assert(err, qt.Equals, ErrSnapshotNotEditable) + + // update the original tree by adding a new key-value, and check that + // snapshotTree still has the old root, but the original tree has a new + // root + err = tree.Add([]byte("test"), []byte("test")) + c.Assert(err, qt.IsNil) + checkRootBIString(c, snapshotTree, "13742386369878513332697380582061714160370929283209286127733983161245560237407") - - err = tree2.Add(k, v) - c.Assert(err, qt.IsNil) - // after adding also the k-v into the snapshot tree, expect original - // tree to have new root - rootBI = BytesToBigInt(tree.Root()) - c.Check(rootBI.String(), qt.Equals, - "10747149055773881257049574592162159501044114324358186833013814735296193179713") + checkRootBIString(c, tree, + "1025190963769001718196479367844646783678188389989148142691917685159698888868") } func BenchmarkAdd(b *testing.B) { diff --git a/vt_test.go b/vt_test.go index 6b607e5..6eba30d 100644 --- a/vt_test.go +++ b/vt_test.go @@ -98,7 +98,9 @@ func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { // compute hashes, and check Root _, err = vTree.computeHashes() c.Assert(err, qt.IsNil) - c.Assert(vTree.root.h, qt.DeepEquals, tree.root) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, root) } func TestVirtualTreeAddBatch(t *testing.T) { @@ -136,7 +138,9 @@ func TestVirtualTreeAddBatch(t *testing.T) { // compute hashes, and check Root _, err = vTree.computeHashes() c.Assert(err, qt.IsNil) - c.Assert(vTree.root.h, qt.DeepEquals, tree.root) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, root) } func TestGetNodesAtLevel(t *testing.T) {