|
|
@ -13,6 +13,7 @@ package arbo |
|
|
|
|
|
|
|
import ( |
|
|
|
"bytes" |
|
|
|
"encoding/binary" |
|
|
|
"encoding/hex" |
|
|
|
"fmt" |
|
|
|
"io" |
|
|
@ -41,8 +42,9 @@ const ( |
|
|
|
) |
|
|
|
|
|
|
|
var ( |
|
|
|
dbKeyRoot = []byte("root") |
|
|
|
emptyValue = []byte{0} |
|
|
|
dbKeyRoot = []byte("root") |
|
|
|
dbKeyNLeafs = []byte("nleafs") |
|
|
|
emptyValue = []byte{0} |
|
|
|
) |
|
|
|
|
|
|
|
// Tree defines the struct that implements the MerkleTree functionalities
|
|
|
@ -55,6 +57,7 @@ type Tree struct { |
|
|
|
root []byte |
|
|
|
|
|
|
|
hashFunction HashFunction |
|
|
|
emptyHash []byte |
|
|
|
} |
|
|
|
|
|
|
|
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
|
|
|
@ -63,18 +66,23 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error |
|
|
|
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} |
|
|
|
t.updateAccessTime() |
|
|
|
|
|
|
|
t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
|
|
|
|
|
|
|
|
root, err := t.dbGet(dbKeyRoot) |
|
|
|
if err == db.ErrNotFound { |
|
|
|
// store new root 0
|
|
|
|
tx, err := t.db.NewTx() |
|
|
|
t.tx, err = t.db.NewTx() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
t.root = make([]byte, t.hashFunction.Len()) // empty
|
|
|
|
if err = tx.Put(dbKeyRoot, t.root); err != nil { |
|
|
|
t.root = t.emptyHash |
|
|
|
if err = t.tx.Put(dbKeyRoot, t.root); err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
if err = t.setNLeafs(0); err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
if err = tx.Commit(); err != nil { |
|
|
|
if err = t.tx.Commit(); err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
return &t, err |
|
|
@ -129,6 +137,10 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { |
|
|
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil { |
|
|
|
return indexes, err |
|
|
|
} |
|
|
|
// update nLeafs
|
|
|
|
if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil { |
|
|
|
return indexes, err |
|
|
|
} |
|
|
|
|
|
|
|
if err := t.tx.Commit(); err != nil { |
|
|
|
return nil, err |
|
|
@ -159,6 +171,10 @@ func (t *Tree) Add(k, v []byte) error { |
|
|
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
// update nLeafs
|
|
|
|
if err = t.incNLeafs(1); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
return t.tx.Commit() |
|
|
|
} |
|
|
|
|
|
|
@ -208,8 +224,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, |
|
|
|
} |
|
|
|
var err error |
|
|
|
var currValue []byte |
|
|
|
emptyKey := make([]byte, t.hashFunction.Len()) |
|
|
|
if bytes.Equal(currKey, emptyKey) { |
|
|
|
if bytes.Equal(currKey, t.emptyHash) { |
|
|
|
// empty value
|
|
|
|
return currKey, emptyValue, siblings, nil |
|
|
|
} |
|
|
@ -277,8 +292,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, |
|
|
|
} |
|
|
|
|
|
|
|
if oldPath[l] == newPath[l] { |
|
|
|
emptyKey := make([]byte, t.hashFunction.Len()) |
|
|
|
siblings = append(siblings, emptyKey) |
|
|
|
siblings = append(siblings, t.emptyHash) |
|
|
|
|
|
|
|
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1) |
|
|
|
if err != nil { |
|
|
@ -599,9 +613,8 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, |
|
|
|
|
|
|
|
func (t *Tree) dbGet(k []byte) ([]byte, error) { |
|
|
|
// if key is empty, return empty as value
|
|
|
|
empty := make([]byte, t.hashFunction.Len()) |
|
|
|
if bytes.Equal(k, empty) { |
|
|
|
return empty, nil |
|
|
|
if bytes.Equal(k, t.emptyHash) { |
|
|
|
return t.emptyHash, nil |
|
|
|
} |
|
|
|
|
|
|
|
v, err := t.db.Get(k) |
|
|
@ -614,6 +627,38 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) { |
|
|
|
return nil, db.ErrNotFound |
|
|
|
} |
|
|
|
|
|
|
|
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
|
|
|
|
// after the setNLeafs call.
|
|
|
|
func (t *Tree) incNLeafs(nLeafs uint64) error { |
|
|
|
oldNLeafs, err := t.GetNLeafs() |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
newNLeafs := oldNLeafs + nLeafs |
|
|
|
return t.setNLeafs(newNLeafs) |
|
|
|
} |
|
|
|
|
|
|
|
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
|
|
|
|
// after the setNLeafs call.
|
|
|
|
func (t *Tree) setNLeafs(nLeafs uint64) error { |
|
|
|
b := make([]byte, 8) |
|
|
|
binary.LittleEndian.PutUint64(b, nLeafs) |
|
|
|
if err := t.tx.Put(dbKeyNLeafs, b); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
// GetNLeafs returns the number of Leafs of the Tree.
|
|
|
|
func (t *Tree) GetNLeafs() (uint64, error) { |
|
|
|
b, err := t.dbGet(dbKeyNLeafs) |
|
|
|
if err != nil { |
|
|
|
return 0, err |
|
|
|
} |
|
|
|
nLeafs := binary.LittleEndian.Uint64(b) |
|
|
|
return nLeafs, nil |
|
|
|
} |
|
|
|
|
|
|
|
// Iterate iterates through the full Tree, executing the given function on each
|
|
|
|
// node of the Tree.
|
|
|
|
func (t *Tree) Iterate(f func([]byte, []byte)) error { |
|
|
@ -677,9 +722,13 @@ func (t *Tree) Dump() ([]byte, error) { |
|
|
|
func (t *Tree) ImportDump(b []byte) error { |
|
|
|
t.updateAccessTime() |
|
|
|
r := bytes.NewReader(b) |
|
|
|
count := 0 |
|
|
|
// TODO instead of adding one by one, use AddBatch (once AddBatch is
|
|
|
|
// optimized)
|
|
|
|
var err error |
|
|
|
for { |
|
|
|
l := make([]byte, 2) |
|
|
|
_, err := io.ReadFull(r, l) |
|
|
|
_, err = io.ReadFull(r, l) |
|
|
|
if err == io.EOF { |
|
|
|
break |
|
|
|
} else if err != nil { |
|
|
@ -699,6 +748,19 @@ func (t *Tree) ImportDump(b []byte) error { |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
count++ |
|
|
|
} |
|
|
|
// update nLeafs (once ImportDump uses AddBatch method, this will not be
|
|
|
|
// needed)
|
|
|
|
t.tx, err = t.db.NewTx() |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
if err := t.incNLeafs(uint64(count)); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
if err = t.tx.Commit(); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
return nil |
|
|
|
} |
|
|
@ -711,7 +773,6 @@ node [fontname=Monospace,fontsize=10,shape=box] |
|
|
|
`) |
|
|
|
nChars := 4 |
|
|
|
nEmpties := 0 |
|
|
|
empty := make([]byte, t.hashFunction.Len()) |
|
|
|
err := t.Iterate(func(k, v []byte) { |
|
|
|
switch v[0] { |
|
|
|
case PrefixValueEmpty: |
|
|
@ -729,13 +790,13 @@ node [fontname=Monospace,fontsize=10,shape=box] |
|
|
|
lStr := hex.EncodeToString(l[:nChars]) |
|
|
|
rStr := hex.EncodeToString(r[:nChars]) |
|
|
|
eStr := "" |
|
|
|
if bytes.Equal(l, empty) { |
|
|
|
if bytes.Equal(l, t.emptyHash) { |
|
|
|
lStr = fmt.Sprintf("empty%v", nEmpties) |
|
|
|
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", |
|
|
|
lStr) |
|
|
|
nEmpties++ |
|
|
|
} |
|
|
|
if bytes.Equal(r, empty) { |
|
|
|
if bytes.Equal(r, t.emptyHash) { |
|
|
|
rStr = fmt.Sprintf("empty%v", nEmpties) |
|
|
|
eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", |
|
|
|
rStr) |
|
|
|