mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-14 09:21:30 +01:00
Add Tree.emptyHash & nLeafs methods
This commit is contained in:
95
tree.go
95
tree.go
@@ -13,6 +13,7 @@ package arbo
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -41,8 +42,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
dbKeyRoot = []byte("root")
|
dbKeyRoot = []byte("root")
|
||||||
emptyValue = []byte{0}
|
dbKeyNLeafs = []byte("nleafs")
|
||||||
|
emptyValue = []byte{0}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Tree defines the struct that implements the MerkleTree functionalities
|
// Tree defines the struct that implements the MerkleTree functionalities
|
||||||
@@ -55,6 +57,7 @@ type Tree struct {
|
|||||||
root []byte
|
root []byte
|
||||||
|
|
||||||
hashFunction HashFunction
|
hashFunction HashFunction
|
||||||
|
emptyHash []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
|
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
|
||||||
@@ -63,18 +66,23 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error
|
|||||||
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
|
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
|
||||||
t.updateAccessTime()
|
t.updateAccessTime()
|
||||||
|
|
||||||
|
t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
|
||||||
|
|
||||||
root, err := t.dbGet(dbKeyRoot)
|
root, err := t.dbGet(dbKeyRoot)
|
||||||
if err == db.ErrNotFound {
|
if err == db.ErrNotFound {
|
||||||
// store new root 0
|
// store new root 0
|
||||||
tx, err := t.db.NewTx()
|
t.tx, err = t.db.NewTx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t.root = make([]byte, t.hashFunction.Len()) // empty
|
t.root = t.emptyHash
|
||||||
if err = tx.Put(dbKeyRoot, t.root); err != nil {
|
if err = t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = tx.Commit(); err != nil {
|
if err = t.setNLeafs(0); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = t.tx.Commit(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &t, err
|
return &t, err
|
||||||
@@ -129,6 +137,10 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
|
|||||||
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||||
return indexes, err
|
return indexes, err
|
||||||
}
|
}
|
||||||
|
// update nLeafs
|
||||||
|
if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil {
|
||||||
|
return indexes, err
|
||||||
|
}
|
||||||
|
|
||||||
if err := t.tx.Commit(); err != nil {
|
if err := t.tx.Commit(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -159,6 +171,10 @@ func (t *Tree) Add(k, v []byte) error {
|
|||||||
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// update nLeafs
|
||||||
|
if err = t.incNLeafs(1); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return t.tx.Commit()
|
return t.tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,8 +224,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
|
|||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
var currValue []byte
|
var currValue []byte
|
||||||
emptyKey := make([]byte, t.hashFunction.Len())
|
if bytes.Equal(currKey, t.emptyHash) {
|
||||||
if bytes.Equal(currKey, emptyKey) {
|
|
||||||
// empty value
|
// empty value
|
||||||
return currKey, emptyValue, siblings, nil
|
return currKey, emptyValue, siblings, nil
|
||||||
}
|
}
|
||||||
@@ -277,8 +292,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oldPath[l] == newPath[l] {
|
if oldPath[l] == newPath[l] {
|
||||||
emptyKey := make([]byte, t.hashFunction.Len())
|
siblings = append(siblings, t.emptyHash)
|
||||||
siblings = append(siblings, emptyKey)
|
|
||||||
|
|
||||||
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
|
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -599,9 +613,8 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
|
|||||||
|
|
||||||
func (t *Tree) dbGet(k []byte) ([]byte, error) {
|
func (t *Tree) dbGet(k []byte) ([]byte, error) {
|
||||||
// if key is empty, return empty as value
|
// if key is empty, return empty as value
|
||||||
empty := make([]byte, t.hashFunction.Len())
|
if bytes.Equal(k, t.emptyHash) {
|
||||||
if bytes.Equal(k, empty) {
|
return t.emptyHash, nil
|
||||||
return empty, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := t.db.Get(k)
|
v, err := t.db.Get(k)
|
||||||
@@ -614,6 +627,38 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) {
|
|||||||
return nil, db.ErrNotFound
|
return nil, db.ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
|
||||||
|
// after the setNLeafs call.
|
||||||
|
func (t *Tree) incNLeafs(nLeafs uint64) error {
|
||||||
|
oldNLeafs, err := t.GetNLeafs()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newNLeafs := oldNLeafs + nLeafs
|
||||||
|
return t.setNLeafs(newNLeafs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
|
||||||
|
// after the setNLeafs call.
|
||||||
|
func (t *Tree) setNLeafs(nLeafs uint64) error {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
binary.LittleEndian.PutUint64(b, nLeafs)
|
||||||
|
if err := t.tx.Put(dbKeyNLeafs, b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNLeafs returns the number of Leafs of the Tree.
|
||||||
|
func (t *Tree) GetNLeafs() (uint64, error) {
|
||||||
|
b, err := t.dbGet(dbKeyNLeafs)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
nLeafs := binary.LittleEndian.Uint64(b)
|
||||||
|
return nLeafs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Iterate iterates through the full Tree, executing the given function on each
|
// Iterate iterates through the full Tree, executing the given function on each
|
||||||
// node of the Tree.
|
// node of the Tree.
|
||||||
func (t *Tree) Iterate(f func([]byte, []byte)) error {
|
func (t *Tree) Iterate(f func([]byte, []byte)) error {
|
||||||
@@ -677,9 +722,13 @@ func (t *Tree) Dump() ([]byte, error) {
|
|||||||
func (t *Tree) ImportDump(b []byte) error {
|
func (t *Tree) ImportDump(b []byte) error {
|
||||||
t.updateAccessTime()
|
t.updateAccessTime()
|
||||||
r := bytes.NewReader(b)
|
r := bytes.NewReader(b)
|
||||||
|
count := 0
|
||||||
|
// TODO instead of adding one by one, use AddBatch (once AddBatch is
|
||||||
|
// optimized)
|
||||||
|
var err error
|
||||||
for {
|
for {
|
||||||
l := make([]byte, 2)
|
l := make([]byte, 2)
|
||||||
_, err := io.ReadFull(r, l)
|
_, err = io.ReadFull(r, l)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -699,6 +748,19 @@ func (t *Tree) ImportDump(b []byte) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
// update nLeafs (once ImportDump uses AddBatch method, this will not be
|
||||||
|
// needed)
|
||||||
|
t.tx, err = t.db.NewTx()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := t.incNLeafs(uint64(count)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = t.tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -711,7 +773,6 @@ node [fontname=Monospace,fontsize=10,shape=box]
|
|||||||
`)
|
`)
|
||||||
nChars := 4
|
nChars := 4
|
||||||
nEmpties := 0
|
nEmpties := 0
|
||||||
empty := make([]byte, t.hashFunction.Len())
|
|
||||||
err := t.Iterate(func(k, v []byte) {
|
err := t.Iterate(func(k, v []byte) {
|
||||||
switch v[0] {
|
switch v[0] {
|
||||||
case PrefixValueEmpty:
|
case PrefixValueEmpty:
|
||||||
@@ -729,13 +790,13 @@ node [fontname=Monospace,fontsize=10,shape=box]
|
|||||||
lStr := hex.EncodeToString(l[:nChars])
|
lStr := hex.EncodeToString(l[:nChars])
|
||||||
rStr := hex.EncodeToString(r[:nChars])
|
rStr := hex.EncodeToString(r[:nChars])
|
||||||
eStr := ""
|
eStr := ""
|
||||||
if bytes.Equal(l, empty) {
|
if bytes.Equal(l, t.emptyHash) {
|
||||||
lStr = fmt.Sprintf("empty%v", nEmpties)
|
lStr = fmt.Sprintf("empty%v", nEmpties)
|
||||||
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
|
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
|
||||||
lStr)
|
lStr)
|
||||||
nEmpties++
|
nEmpties++
|
||||||
}
|
}
|
||||||
if bytes.Equal(r, empty) {
|
if bytes.Equal(r, t.emptyHash) {
|
||||||
rStr = fmt.Sprintf("empty%v", nEmpties)
|
rStr = fmt.Sprintf("empty%v", nEmpties)
|
||||||
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
|
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
|
||||||
rStr)
|
rStr)
|
||||||
|
|||||||
48
tree_test.go
48
tree_test.go
@@ -325,6 +325,54 @@ func TestRWMutex(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetGetNLeafs(t *testing.T) {
|
||||||
|
c := qt.New(t)
|
||||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
// 0
|
||||||
|
tree.tx, err = tree.db.NewTx()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
err = tree.setNLeafs(0)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
err = tree.tx.Commit()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
n, err := tree.GetNLeafs()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
c.Assert(n, qt.Equals, uint64(0))
|
||||||
|
|
||||||
|
// 1024
|
||||||
|
tree.tx, err = tree.db.NewTx()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
err = tree.setNLeafs(1024)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
err = tree.tx.Commit()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
n, err = tree.GetNLeafs()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
c.Assert(n, qt.Equals, uint64(1024))
|
||||||
|
|
||||||
|
// 2**64 -1
|
||||||
|
tree.tx, err = tree.db.NewTx()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
err = tree.setNLeafs(18446744073709551615)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
err = tree.tx.Commit()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
n, err = tree.GetNLeafs()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
c.Assert(n, qt.Equals, uint64(18446744073709551615))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkAdd(b *testing.B) {
|
func BenchmarkAdd(b *testing.B) {
|
||||||
// prepare inputs
|
// prepare inputs
|
||||||
var ks, vs [][]byte
|
var ks, vs [][]byte
|
||||||
|
|||||||
Reference in New Issue
Block a user