diff --git a/tree.go b/tree.go index 14a31b0..186de97 100644 --- a/tree.go +++ b/tree.go @@ -13,6 +13,7 @@ package arbo import ( "bytes" + "encoding/binary" "encoding/hex" "fmt" "io" @@ -41,8 +42,9 @@ const ( ) var ( - dbKeyRoot = []byte("root") - emptyValue = []byte{0} + dbKeyRoot = []byte("root") + dbKeyNLeafs = []byte("nleafs") + emptyValue = []byte{0} ) // Tree defines the struct that implements the MerkleTree functionalities @@ -55,6 +57,7 @@ type Tree struct { root []byte hashFunction HashFunction + emptyHash []byte } // NewTree returns a new Tree, if there is a Tree still in the given storage, it @@ -63,18 +66,23 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} t.updateAccessTime() + t.emptyHash = make([]byte, t.hashFunction.Len()) // empty + root, err := t.dbGet(dbKeyRoot) if err == db.ErrNotFound { // store new root 0 - tx, err := t.db.NewTx() + t.tx, err = t.db.NewTx() if err != nil { return nil, err } - t.root = make([]byte, t.hashFunction.Len()) // empty - if err = tx.Put(dbKeyRoot, t.root); err != nil { + t.root = t.emptyHash + if err = t.tx.Put(dbKeyRoot, t.root); err != nil { + return nil, err + } + if err = t.setNLeafs(0); err != nil { return nil, err } - if err = tx.Commit(); err != nil { + if err = t.tx.Commit(); err != nil { return nil, err } return &t, err @@ -129,6 +137,10 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { if err := t.tx.Put(dbKeyRoot, t.root); err != nil { return indexes, err } + // update nLeafs + if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil { + return indexes, err + } if err := t.tx.Commit(); err != nil { return nil, err @@ -159,6 +171,10 @@ func (t *Tree) Add(k, v []byte) error { if err := t.tx.Put(dbKeyRoot, t.root); err != nil { return err } + // update nLeafs + if err = t.incNLeafs(1); err != nil { + return err + } return t.tx.Commit() } @@ -208,8 +224,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, } var err error var currValue []byte - emptyKey := make([]byte, t.hashFunction.Len()) - if bytes.Equal(currKey, emptyKey) { + if bytes.Equal(currKey, t.emptyHash) { // empty value return currKey, emptyValue, siblings, nil } @@ -277,8 +292,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, } if oldPath[l] == newPath[l] { - emptyKey := make([]byte, t.hashFunction.Len()) - siblings = append(siblings, emptyKey) + siblings = append(siblings, t.emptyHash) siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1) if err != nil { @@ -599,9 +613,8 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, func (t *Tree) dbGet(k []byte) ([]byte, error) { // if key is empty, return empty as value - empty := make([]byte, t.hashFunction.Len()) - if bytes.Equal(k, empty) { - return empty, nil + if bytes.Equal(k, t.emptyHash) { + return t.emptyHash, nil } v, err := t.db.Get(k) @@ -614,6 +627,38 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) { return nil, db.ErrNotFound } +// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit +// after the setNLeafs call. +func (t *Tree) incNLeafs(nLeafs uint64) error { + oldNLeafs, err := t.GetNLeafs() + if err != nil { + return err + } + newNLeafs := oldNLeafs + nLeafs + return t.setNLeafs(newNLeafs) +} + +// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit +// after the setNLeafs call. +func (t *Tree) setNLeafs(nLeafs uint64) error { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, nLeafs) + if err := t.tx.Put(dbKeyNLeafs, b); err != nil { + return err + } + return nil +} + +// GetNLeafs returns the number of Leafs of the Tree. +func (t *Tree) GetNLeafs() (uint64, error) { + b, err := t.dbGet(dbKeyNLeafs) + if err != nil { + return 0, err + } + nLeafs := binary.LittleEndian.Uint64(b) + return nLeafs, nil +} + // Iterate iterates through the full Tree, executing the given function on each // node of the Tree. func (t *Tree) Iterate(f func([]byte, []byte)) error { @@ -677,9 +722,13 @@ func (t *Tree) Dump() ([]byte, error) { func (t *Tree) ImportDump(b []byte) error { t.updateAccessTime() r := bytes.NewReader(b) + count := 0 + // TODO instead of adding one by one, use AddBatch (once AddBatch is + // optimized) + var err error for { l := make([]byte, 2) - _, err := io.ReadFull(r, l) + _, err = io.ReadFull(r, l) if err == io.EOF { break } else if err != nil { @@ -699,6 +748,19 @@ func (t *Tree) ImportDump(b []byte) error { if err != nil { return err } + count++ + } + // update nLeafs (once ImportDump uses AddBatch method, this will not be + // needed) + t.tx, err = t.db.NewTx() + if err != nil { + return err + } + if err := t.incNLeafs(uint64(count)); err != nil { + return err + } + if err = t.tx.Commit(); err != nil { + return err } return nil } @@ -711,7 +773,6 @@ node [fontname=Monospace,fontsize=10,shape=box] `) nChars := 4 nEmpties := 0 - empty := make([]byte, t.hashFunction.Len()) err := t.Iterate(func(k, v []byte) { switch v[0] { case PrefixValueEmpty: @@ -729,13 +790,13 @@ node [fontname=Monospace,fontsize=10,shape=box] lStr := hex.EncodeToString(l[:nChars]) rStr := hex.EncodeToString(r[:nChars]) eStr := "" - if bytes.Equal(l, empty) { + if bytes.Equal(l, t.emptyHash) { lStr = fmt.Sprintf("empty%v", nEmpties) eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lStr) nEmpties++ } - if bytes.Equal(r, empty) { + if bytes.Equal(r, t.emptyHash) { rStr = fmt.Sprintf("empty%v", nEmpties) eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", rStr) diff --git a/tree_test.go b/tree_test.go index 53d1c15..963d2bc 100644 --- a/tree_test.go +++ b/tree_test.go @@ -325,6 +325,54 @@ func TestRWMutex(t *testing.T) { } } +func TestSetGetNLeafs(t *testing.T) { + c := qt.New(t) + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + + // 0 + tree.tx, err = tree.db.NewTx() + c.Assert(err, qt.IsNil) + + err = tree.setNLeafs(0) + c.Assert(err, qt.IsNil) + + err = tree.tx.Commit() + c.Assert(err, qt.IsNil) + + n, err := tree.GetNLeafs() + c.Assert(err, qt.IsNil) + c.Assert(n, qt.Equals, uint64(0)) + + // 1024 + tree.tx, err = tree.db.NewTx() + c.Assert(err, qt.IsNil) + + err = tree.setNLeafs(1024) + c.Assert(err, qt.IsNil) + + err = tree.tx.Commit() + c.Assert(err, qt.IsNil) + + n, err = tree.GetNLeafs() + c.Assert(err, qt.IsNil) + c.Assert(n, qt.Equals, uint64(1024)) + + // 2**64 -1 + tree.tx, err = tree.db.NewTx() + c.Assert(err, qt.IsNil) + + err = tree.setNLeafs(18446744073709551615) + c.Assert(err, qt.IsNil) + + err = tree.tx.Commit() + c.Assert(err, qt.IsNil) + + n, err = tree.GetNLeafs() + c.Assert(err, qt.IsNil) + c.Assert(n, qt.Equals, uint64(18446744073709551615)) +} + func BenchmarkAdd(b *testing.B) { // prepare inputs var ks, vs [][]byte