Browse Source

Add Tree.emptyHash & nLeafs methods

master
arnaucube 3 years ago
parent
commit
6f43980c0f
2 changed files with 126 additions and 17 deletions
  1. +78
    -17
      tree.go
  2. +48
    -0
      tree_test.go

+ 78
- 17
tree.go

@ -13,6 +13,7 @@ package arbo
import ( import (
"bytes" "bytes"
"encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
@ -41,8 +42,9 @@ const (
) )
var ( var (
dbKeyRoot = []byte("root")
emptyValue = []byte{0}
dbKeyRoot = []byte("root")
dbKeyNLeafs = []byte("nleafs")
emptyValue = []byte{0}
) )
// Tree defines the struct that implements the MerkleTree functionalities // Tree defines the struct that implements the MerkleTree functionalities
@ -55,6 +57,7 @@ type Tree struct {
root []byte root []byte
hashFunction HashFunction hashFunction HashFunction
emptyHash []byte
} }
// NewTree returns a new Tree, if there is a Tree still in the given storage, it // NewTree returns a new Tree, if there is a Tree still in the given storage, it
@ -63,18 +66,23 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
t.updateAccessTime() t.updateAccessTime()
t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
root, err := t.dbGet(dbKeyRoot) root, err := t.dbGet(dbKeyRoot)
if err == db.ErrNotFound { if err == db.ErrNotFound {
// store new root 0 // store new root 0
tx, err := t.db.NewTx()
t.tx, err = t.db.NewTx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.root = make([]byte, t.hashFunction.Len()) // empty
if err = tx.Put(dbKeyRoot, t.root); err != nil {
t.root = t.emptyHash
if err = t.tx.Put(dbKeyRoot, t.root); err != nil {
return nil, err
}
if err = t.setNLeafs(0); err != nil {
return nil, err return nil, err
} }
if err = tx.Commit(); err != nil {
if err = t.tx.Commit(); err != nil {
return nil, err return nil, err
} }
return &t, err return &t, err
@ -129,6 +137,10 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
if err := t.tx.Put(dbKeyRoot, t.root); err != nil { if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
return indexes, err return indexes, err
} }
// update nLeafs
if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil {
return indexes, err
}
if err := t.tx.Commit(); err != nil { if err := t.tx.Commit(); err != nil {
return nil, err return nil, err
@ -159,6 +171,10 @@ func (t *Tree) Add(k, v []byte) error {
if err := t.tx.Put(dbKeyRoot, t.root); err != nil { if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
return err return err
} }
// update nLeafs
if err = t.incNLeafs(1); err != nil {
return err
}
return t.tx.Commit() return t.tx.Commit()
} }
@ -208,8 +224,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
} }
var err error var err error
var currValue []byte var currValue []byte
emptyKey := make([]byte, t.hashFunction.Len())
if bytes.Equal(currKey, emptyKey) {
if bytes.Equal(currKey, t.emptyHash) {
// empty value // empty value
return currKey, emptyValue, siblings, nil return currKey, emptyValue, siblings, nil
} }
@ -277,8 +292,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
} }
if oldPath[l] == newPath[l] { if oldPath[l] == newPath[l] {
emptyKey := make([]byte, t.hashFunction.Len())
siblings = append(siblings, emptyKey)
siblings = append(siblings, t.emptyHash)
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1) siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
if err != nil { if err != nil {
@ -599,9 +613,8 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
func (t *Tree) dbGet(k []byte) ([]byte, error) { func (t *Tree) dbGet(k []byte) ([]byte, error) {
// if key is empty, return empty as value // if key is empty, return empty as value
empty := make([]byte, t.hashFunction.Len())
if bytes.Equal(k, empty) {
return empty, nil
if bytes.Equal(k, t.emptyHash) {
return t.emptyHash, nil
} }
v, err := t.db.Get(k) v, err := t.db.Get(k)
@ -614,6 +627,38 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) {
return nil, db.ErrNotFound return nil, db.ErrNotFound
} }
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
// after the setNLeafs call.
func (t *Tree) incNLeafs(nLeafs uint64) error {
oldNLeafs, err := t.GetNLeafs()
if err != nil {
return err
}
newNLeafs := oldNLeafs + nLeafs
return t.setNLeafs(newNLeafs)
}
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
// after the setNLeafs call.
func (t *Tree) setNLeafs(nLeafs uint64) error {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, nLeafs)
if err := t.tx.Put(dbKeyNLeafs, b); err != nil {
return err
}
return nil
}
// GetNLeafs returns the number of Leafs of the Tree.
func (t *Tree) GetNLeafs() (uint64, error) {
b, err := t.dbGet(dbKeyNLeafs)
if err != nil {
return 0, err
}
nLeafs := binary.LittleEndian.Uint64(b)
return nLeafs, nil
}
// Iterate iterates through the full Tree, executing the given function on each // Iterate iterates through the full Tree, executing the given function on each
// node of the Tree. // node of the Tree.
func (t *Tree) Iterate(f func([]byte, []byte)) error { func (t *Tree) Iterate(f func([]byte, []byte)) error {
@ -677,9 +722,13 @@ func (t *Tree) Dump() ([]byte, error) {
func (t *Tree) ImportDump(b []byte) error { func (t *Tree) ImportDump(b []byte) error {
t.updateAccessTime() t.updateAccessTime()
r := bytes.NewReader(b) r := bytes.NewReader(b)
count := 0
// TODO instead of adding one by one, use AddBatch (once AddBatch is
// optimized)
var err error
for { for {
l := make([]byte, 2) l := make([]byte, 2)
_, err := io.ReadFull(r, l)
_, err = io.ReadFull(r, l)
if err == io.EOF { if err == io.EOF {
break break
} else if err != nil { } else if err != nil {
@ -699,6 +748,19 @@ func (t *Tree) ImportDump(b []byte) error {
if err != nil { if err != nil {
return err return err
} }
count++
}
// update nLeafs (once ImportDump uses AddBatch method, this will not be
// needed)
t.tx, err = t.db.NewTx()
if err != nil {
return err
}
if err := t.incNLeafs(uint64(count)); err != nil {
return err
}
if err = t.tx.Commit(); err != nil {
return err
} }
return nil return nil
} }
@ -711,7 +773,6 @@ node [fontname=Monospace,fontsize=10,shape=box]
`) `)
nChars := 4 nChars := 4
nEmpties := 0 nEmpties := 0
empty := make([]byte, t.hashFunction.Len())
err := t.Iterate(func(k, v []byte) { err := t.Iterate(func(k, v []byte) {
switch v[0] { switch v[0] {
case PrefixValueEmpty: case PrefixValueEmpty:
@ -729,13 +790,13 @@ node [fontname=Monospace,fontsize=10,shape=box]
lStr := hex.EncodeToString(l[:nChars]) lStr := hex.EncodeToString(l[:nChars])
rStr := hex.EncodeToString(r[:nChars]) rStr := hex.EncodeToString(r[:nChars])
eStr := "" eStr := ""
if bytes.Equal(l, empty) {
if bytes.Equal(l, t.emptyHash) {
lStr = fmt.Sprintf("empty%v", nEmpties) lStr = fmt.Sprintf("empty%v", nEmpties)
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
lStr) lStr)
nEmpties++ nEmpties++
} }
if bytes.Equal(r, empty) {
if bytes.Equal(r, t.emptyHash) {
rStr = fmt.Sprintf("empty%v", nEmpties) rStr = fmt.Sprintf("empty%v", nEmpties)
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
rStr) rStr)

+ 48
- 0
tree_test.go

@ -325,6 +325,54 @@ func TestRWMutex(t *testing.T) {
} }
} }
func TestSetGetNLeafs(t *testing.T) {
c := qt.New(t)
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
// 0
tree.tx, err = tree.db.NewTx()
c.Assert(err, qt.IsNil)
err = tree.setNLeafs(0)
c.Assert(err, qt.IsNil)
err = tree.tx.Commit()
c.Assert(err, qt.IsNil)
n, err := tree.GetNLeafs()
c.Assert(err, qt.IsNil)
c.Assert(n, qt.Equals, uint64(0))
// 1024
tree.tx, err = tree.db.NewTx()
c.Assert(err, qt.IsNil)
err = tree.setNLeafs(1024)
c.Assert(err, qt.IsNil)
err = tree.tx.Commit()
c.Assert(err, qt.IsNil)
n, err = tree.GetNLeafs()
c.Assert(err, qt.IsNil)
c.Assert(n, qt.Equals, uint64(1024))
// 2**64 -1
tree.tx, err = tree.db.NewTx()
c.Assert(err, qt.IsNil)
err = tree.setNLeafs(18446744073709551615)
c.Assert(err, qt.IsNil)
err = tree.tx.Commit()
c.Assert(err, qt.IsNil)
n, err = tree.GetNLeafs()
c.Assert(err, qt.IsNil)
c.Assert(n, qt.Equals, uint64(18446744073709551615))
}
func BenchmarkAdd(b *testing.B) { func BenchmarkAdd(b *testing.B) {
// prepare inputs // prepare inputs
var ks, vs [][]byte var ks, vs [][]byte

Loading…
Cancel
Save