mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-14 09:21:30 +01:00
Add Tree.emptyHash & nLeafs methods
This commit is contained in:
95
tree.go
95
tree.go
@@ -13,6 +13,7 @@ package arbo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -41,8 +42,9 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
dbKeyRoot = []byte("root")
|
||||
emptyValue = []byte{0}
|
||||
dbKeyRoot = []byte("root")
|
||||
dbKeyNLeafs = []byte("nleafs")
|
||||
emptyValue = []byte{0}
|
||||
)
|
||||
|
||||
// Tree defines the struct that implements the MerkleTree functionalities
|
||||
@@ -55,6 +57,7 @@ type Tree struct {
|
||||
root []byte
|
||||
|
||||
hashFunction HashFunction
|
||||
emptyHash []byte
|
||||
}
|
||||
|
||||
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
|
||||
@@ -63,18 +66,23 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error
|
||||
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
|
||||
t.updateAccessTime()
|
||||
|
||||
t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
|
||||
|
||||
root, err := t.dbGet(dbKeyRoot)
|
||||
if err == db.ErrNotFound {
|
||||
// store new root 0
|
||||
tx, err := t.db.NewTx()
|
||||
t.tx, err = t.db.NewTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.root = make([]byte, t.hashFunction.Len()) // empty
|
||||
if err = tx.Put(dbKeyRoot, t.root); err != nil {
|
||||
t.root = t.emptyHash
|
||||
if err = t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
if err = t.setNLeafs(0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = t.tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, err
|
||||
@@ -129,6 +137,10 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
|
||||
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||
return indexes, err
|
||||
}
|
||||
// update nLeafs
|
||||
if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil {
|
||||
return indexes, err
|
||||
}
|
||||
|
||||
if err := t.tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
@@ -159,6 +171,10 @@ func (t *Tree) Add(k, v []byte) error {
|
||||
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||
return err
|
||||
}
|
||||
// update nLeafs
|
||||
if err = t.incNLeafs(1); err != nil {
|
||||
return err
|
||||
}
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
@@ -208,8 +224,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
|
||||
}
|
||||
var err error
|
||||
var currValue []byte
|
||||
emptyKey := make([]byte, t.hashFunction.Len())
|
||||
if bytes.Equal(currKey, emptyKey) {
|
||||
if bytes.Equal(currKey, t.emptyHash) {
|
||||
// empty value
|
||||
return currKey, emptyValue, siblings, nil
|
||||
}
|
||||
@@ -277,8 +292,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
|
||||
}
|
||||
|
||||
if oldPath[l] == newPath[l] {
|
||||
emptyKey := make([]byte, t.hashFunction.Len())
|
||||
siblings = append(siblings, emptyKey)
|
||||
siblings = append(siblings, t.emptyHash)
|
||||
|
||||
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
|
||||
if err != nil {
|
||||
@@ -599,9 +613,8 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
|
||||
|
||||
func (t *Tree) dbGet(k []byte) ([]byte, error) {
|
||||
// if key is empty, return empty as value
|
||||
empty := make([]byte, t.hashFunction.Len())
|
||||
if bytes.Equal(k, empty) {
|
||||
return empty, nil
|
||||
if bytes.Equal(k, t.emptyHash) {
|
||||
return t.emptyHash, nil
|
||||
}
|
||||
|
||||
v, err := t.db.Get(k)
|
||||
@@ -614,6 +627,38 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) {
|
||||
return nil, db.ErrNotFound
|
||||
}
|
||||
|
||||
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
|
||||
// after the setNLeafs call.
|
||||
func (t *Tree) incNLeafs(nLeafs uint64) error {
|
||||
oldNLeafs, err := t.GetNLeafs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNLeafs := oldNLeafs + nLeafs
|
||||
return t.setNLeafs(newNLeafs)
|
||||
}
|
||||
|
||||
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
|
||||
// after the setNLeafs call.
|
||||
func (t *Tree) setNLeafs(nLeafs uint64) error {
|
||||
b := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(b, nLeafs)
|
||||
if err := t.tx.Put(dbKeyNLeafs, b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNLeafs returns the number of Leafs of the Tree.
|
||||
func (t *Tree) GetNLeafs() (uint64, error) {
|
||||
b, err := t.dbGet(dbKeyNLeafs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
nLeafs := binary.LittleEndian.Uint64(b)
|
||||
return nLeafs, nil
|
||||
}
|
||||
|
||||
// Iterate iterates through the full Tree, executing the given function on each
|
||||
// node of the Tree.
|
||||
func (t *Tree) Iterate(f func([]byte, []byte)) error {
|
||||
@@ -677,9 +722,13 @@ func (t *Tree) Dump() ([]byte, error) {
|
||||
func (t *Tree) ImportDump(b []byte) error {
|
||||
t.updateAccessTime()
|
||||
r := bytes.NewReader(b)
|
||||
count := 0
|
||||
// TODO instead of adding one by one, use AddBatch (once AddBatch is
|
||||
// optimized)
|
||||
var err error
|
||||
for {
|
||||
l := make([]byte, 2)
|
||||
_, err := io.ReadFull(r, l)
|
||||
_, err = io.ReadFull(r, l)
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
@@ -699,6 +748,19 @@ func (t *Tree) ImportDump(b []byte) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count++
|
||||
}
|
||||
// update nLeafs (once ImportDump uses AddBatch method, this will not be
|
||||
// needed)
|
||||
t.tx, err = t.db.NewTx()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.incNLeafs(uint64(count)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = t.tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -711,7 +773,6 @@ node [fontname=Monospace,fontsize=10,shape=box]
|
||||
`)
|
||||
nChars := 4
|
||||
nEmpties := 0
|
||||
empty := make([]byte, t.hashFunction.Len())
|
||||
err := t.Iterate(func(k, v []byte) {
|
||||
switch v[0] {
|
||||
case PrefixValueEmpty:
|
||||
@@ -729,13 +790,13 @@ node [fontname=Monospace,fontsize=10,shape=box]
|
||||
lStr := hex.EncodeToString(l[:nChars])
|
||||
rStr := hex.EncodeToString(r[:nChars])
|
||||
eStr := ""
|
||||
if bytes.Equal(l, empty) {
|
||||
if bytes.Equal(l, t.emptyHash) {
|
||||
lStr = fmt.Sprintf("empty%v", nEmpties)
|
||||
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
|
||||
lStr)
|
||||
nEmpties++
|
||||
}
|
||||
if bytes.Equal(r, empty) {
|
||||
if bytes.Equal(r, t.emptyHash) {
|
||||
rStr = fmt.Sprintf("empty%v", nEmpties)
|
||||
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
|
||||
rStr)
|
||||
|
||||
48
tree_test.go
48
tree_test.go
@@ -325,6 +325,54 @@ func TestRWMutex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetGetNLeafs(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
// 0
|
||||
tree.tx, err = tree.db.NewTx()
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
err = tree.setNLeafs(0)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
err = tree.tx.Commit()
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
n, err := tree.GetNLeafs()
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(n, qt.Equals, uint64(0))
|
||||
|
||||
// 1024
|
||||
tree.tx, err = tree.db.NewTx()
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
err = tree.setNLeafs(1024)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
err = tree.tx.Commit()
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
n, err = tree.GetNLeafs()
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(n, qt.Equals, uint64(1024))
|
||||
|
||||
// 2**64 -1
|
||||
tree.tx, err = tree.db.NewTx()
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
err = tree.setNLeafs(18446744073709551615)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
err = tree.tx.Commit()
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
n, err = tree.GetNLeafs()
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(n, qt.Equals, uint64(18446744073709551615))
|
||||
}
|
||||
|
||||
func BenchmarkAdd(b *testing.B) {
|
||||
// prepare inputs
|
||||
var ks, vs [][]byte
|
||||
|
||||
Reference in New Issue
Block a user