Browse Source

Add AddBatch CaseC

CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold)
==============================================================================
- Use A, B, G, F as Roots of subtrees
- Do CASE B for each subtree
- Then go from L to the Root

              R
             /  \
            /    \
           /      \
          *        *
         / |      / \
        /  |     /   \
       /   |    /     \
L:    A    B   G       D
              / \
             /   \
            /     \
           C      *
                 / \
                /   \
               /     \
              ...    ... (nLeafs >= minLeafsThreshold)
master
arnaucube 3 years ago
parent
commit
a3473079de
4 changed files with 488 additions and 52 deletions
  1. +201
    -22
      addbatch.go
  2. +221
    -0
      addbatch_test.go
  3. +65
    -29
      tree.go
  4. +1
    -1
      utils.go

+ 201
- 22
addbatch.go

@ -3,6 +3,7 @@ package arbo
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math"
"sort" "sort"
) )
@ -25,11 +26,24 @@ the leafs)
- Do CASE A for the new Tree, giving the already existing key&values (leafs) - Do CASE A for the new Tree, giving the already existing key&values (leafs)
from the original Tree + the new key&values to be added from the AddBatch call from the original Tree + the new key&values to be added from the AddBatch call
R
/ \
A *
/ \
B C
R R
/ \ / \
A * / \
/ \ / \
B C * *
/ | / \
/ | / \
/ | / \
L: A B G D
/ \
/ \
/ \
C *
/ \
/ \
/ \
... ... (nLeafs < minLeafsThreshold)
CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold) CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold)
@ -54,7 +68,7 @@ L: A B G D
/ \ / \
/ \ / \
/ \ / \
D E
... ... (nLeafs >= minLeafsThreshold)
@ -123,6 +137,11 @@ Algorithm decision
*/ */
const (
minLeafsThreshold = uint64(100) // nolint:gomnd // TMP WIP this will be autocalculated
nBuckets = uint64(4) // TMP WIP this will be autocalculated from
)
// AddBatchOpt is the WIP implementation of the AddBatch method in a more // AddBatchOpt is the WIP implementation of the AddBatch method in a more
// optimized approach. // optimized approach.
func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
@ -141,44 +160,151 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
return nil, err return nil, err
} }
t.tx, err = t.db.NewTx()
t.tx, err = t.db.NewTx() // TODO add t.tx.Commit()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// if nLeafs==0 (root==0): CASE A
// CASE A: if nLeafs==0 (root==0)
if bytes.Equal(t.root, t.emptyHash) { if bytes.Equal(t.root, t.emptyHash) {
// sort keys & values by path // sort keys & values by path
sortKvs(kvs) sortKvs(kvs)
return t.buildTreeBottomUp(kvs) return t.buildTreeBottomUp(kvs)
} }
// if nLeafs<nBuckets: CASE B
// CASE B: if nLeafs<nBuckets
nLeafs, err := t.GetNLeafs() nLeafs, err := t.GetNLeafs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
minLeafsThreshold := uint64(100) // nolint:gomnd // TMP WIP
if nLeafs < minLeafsThreshold {
// get already existing keys
aKs, aVs, err := t.getLeafs()
if nLeafs < minLeafsThreshold { // CASE B
invalids, excedents, err := t.caseB(0, kvs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
aKvs, err := t.keysValuesToKvs(aKs, aVs)
// add the excedents
for i := 0; i < len(excedents); i++ {
err = t.add(0, excedents[i].k, excedents[i].v)
if err != nil {
invalids = append(invalids, excedents[i].pos)
}
}
return invalids, nil
}
// CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold
// available parallelization, will need to be a power of 2 (2**n)
var excedents []kv
l := int(math.Log2(float64(nBuckets)))
if nLeafs >= minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold {
// TODO move to own function
// 1. go down until level L (L=log2(nBuckets))
keysAtL, err := t.getKeysAtLevel(l + 1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// add already existing key-values to the inputted key-values
kvs = append(kvs, aKvs...)
// proceed with CASE A
sortKvs(kvs)
return t.buildTreeBottomUp(kvs)
buckets := splitInBuckets(kvs, nBuckets)
// 2. use keys at level L as roots of the subtrees under each one
var subRoots [][]byte
// TODO parallelize
for i := 0; i < len(keysAtL); i++ {
bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels,
hashFunction: t.hashFunction, root: keysAtL[i]}
// 3. and do CASE B for each
_, bucketExcedents, err := bucketTree.caseB(l, buckets[i])
if err != nil {
return nil, err
}
excedents = append(excedents, bucketExcedents...)
subRoots = append(subRoots, bucketTree.root)
}
// 4. go upFromKeys from the new roots of the subtrees
newRoot, err := t.upFromKeys(subRoots)
if err != nil {
return nil, err
}
t.root = newRoot
var invalids []int
for i := 0; i < len(excedents); i++ {
// Add until the level L
err = t.add(0, excedents[i].k, excedents[i].v)
if err != nil {
invalids = append(invalids, excedents[i].pos) // TODO WIP
}
}
return invalids, nil
} }
// TODO store t.root into DB
// TODO update NLeafs from DB
return nil, fmt.Errorf("UNIMPLEMENTED") return nil, fmt.Errorf("UNIMPLEMENTED")
} }
func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
// get already existing keys
aKs, aVs, err := t.getLeafs(t.root)
if err != nil {
return nil, nil, err
}
aKvs, err := t.keysValuesToKvs(aKs, aVs)
if err != nil {
return nil, nil, err
}
// add already existing key-values to the inputted key-values
kvs = append(kvs, aKvs...)
// proceed with CASE A
sortKvs(kvs)
// cutPowerOfTwo, the excedent add it as normal Tree.Add
kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
invalids, err := t.buildTreeBottomUp(kvsP2)
if err != nil {
return nil, nil, err
}
// return the excedents which will be added at the full tree at the end
return invalids, kvsNonP2, nil
}
func splitInBuckets(kvs []kv, nBuckets uint64) [][]kv {
buckets := make([][]kv, nBuckets)
// 1. classify the keyvalues into buckets
for i := 0; i < len(kvs); i++ {
pair := kvs[i]
bucketnum := keyToBucket(pair.k, int(nBuckets))
buckets[bucketnum] = append(buckets[bucketnum], pair)
}
return buckets
}
// TODO rename in a more 'real' name (calculate bucket from/for key)
func keyToBucket(k []byte, nBuckets int) int {
nLevels := int(math.Log2(float64(nBuckets)))
b := make([]int, nBuckets)
for i := 0; i < nBuckets; i++ {
b[i] = i
}
r := b
mid := len(r) / 2 //nolint:gomnd
for i := 0; i < nLevels; i++ {
if int(k[i/8]&(1<<(i%8))) != 0 {
r = r[mid:]
mid = len(r) / 2 //nolint:gomnd
} else {
r = r[:mid]
mid = len(r) / 2 //nolint:gomnd
}
}
return r[0]
}
type kv struct { type kv struct {
pos int // original position in the array pos int // original position in the array
keyPath []byte keyPath []byte
@ -241,7 +367,8 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) {
} }
*/ */
// keys & values must be sorted by path, and must be length multiple of 2
// keys & values must be sorted by path, and the array ks must be length
// multiple of 2
// TODO return index of failed keyvaules // TODO return index of failed keyvaules
func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) { func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) {
// build the leafs // build the leafs
@ -258,6 +385,7 @@ func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) {
} }
leafKeys[i] = leafKey leafKeys[i] = leafKey
} }
// TODO parallelize t.upFromKeys until level log2(nBuckets) is reached
r, err := t.upFromKeys(leafKeys) r, err := t.upFromKeys(leafKeys)
if err != nil { if err != nil {
return nil, err return nil, err
@ -266,6 +394,8 @@ func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) {
return nil, nil return nil, nil
} }
// keys & values must be sorted by path, and the array ks must be length
// multiple of 2
func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) { func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) {
if len(ks) == 1 { if len(ks) == 1 {
return ks[0], nil return ks[0], nil
@ -287,9 +417,9 @@ func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) {
return t.upFromKeys(rKs) return t.upFromKeys(rKs)
} }
func (t *Tree) getLeafs() ([][]byte, [][]byte, error) {
func (t *Tree) getLeafs(root []byte) ([][]byte, [][]byte, error) {
var ks, vs [][]byte var ks, vs [][]byte
err := t.Iterate(func(k, v []byte) {
err := t.iter(root, func(k, v []byte) {
if v[0] != PrefixValueLeaf { if v[0] != PrefixValueLeaf {
return return
} }
@ -299,3 +429,52 @@ func (t *Tree) getLeafs() ([][]byte, [][]byte, error) {
}) })
return ks, vs, err return ks, vs, err
} }
func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) {
var keys [][]byte
err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
if currLvl == l {
keys = append(keys, k)
}
if currLvl >= l {
return true // to stop the iter from going down
}
return false
})
return keys, err
}
// cutPowerOfTwo returns []kv of length that is a power of 2, and a second []kv
// with the extra elements that don't fit in a power of 2 length
func cutPowerOfTwo(kvs []kv) ([]kv, []kv) {
x := len(kvs)
if (x & (x - 1)) != 0 {
p2 := highestPowerOfTwo(x)
return kvs[:p2], kvs[p2:]
}
return kvs, nil
}
func highestPowerOfTwo(n int) int {
res := 0
for i := n; i >= 1; i-- {
if (i & (i - 1)) == 0 {
res = i
break
}
}
return res
}
// func computeSimpleAddCost(nLeafs int) int {
// // nLvls 2^nLvls
// nLvls := int(math.Log2(float64(nLeafs)))
// return nLvls * int(math.Pow(2, float64(nLvls)))
// }
//
// func computeBottomUpAddCost(nLeafs int) int {
// // 2^nLvls * 2 - 1
// nLvls := int(math.Log2(float64(nLeafs)))
// return (int(math.Pow(2, float64(nLvls))) * 2) - 1
// }

+ 221
- 0
addbatch_test.go

@ -1,6 +1,7 @@
package arbo package arbo
import ( import (
"encoding/hex"
"fmt" "fmt"
"math/big" "math/big"
"testing" "testing"
@ -99,3 +100,223 @@ func TestAddBatchCaseB(t *testing.T) {
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
} }
func TestGetKeysAtLevel(t *testing.T) {
c := qt.New(t)
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree.db.Close()
for i := 0; i < 32; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
if err := tree.Add(k, v); err != nil {
t.Fatal(err)
}
}
keys, err := tree.getKeysAtLevel(2)
c.Assert(err, qt.IsNil)
expected := []string{
"a5d5f14fce7348e40751496cf25d107d91b0bd043435b9577d778a01f8aa6111",
"e9e8dd9b28a7f81d1ff34cb5cefc0146dd848b31031a427b79bdadb62e7f6910",
}
for i := 0; i < len(keys); i++ {
c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
}
keys, err = tree.getKeysAtLevel(3)
c.Assert(err, qt.IsNil)
expected = []string{
"9f12c13e52bca96ad4882a26558e48ab67ddd63e062b839207e893d961390f01",
"16d246dd6826ec7346c7328f11c4261facf82d4689f33263ff6e207956a77f21",
"4a22cc901c6337daa17a431fa20170684b710e5f551509511492ec24e81a8f2f",
"470d61abcbd154977bffc9a9ec5a8daff0caabcf2a25e8441f604c79daa0f82d",
}
for i := 0; i < len(keys); i++ {
c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
}
keys, err = tree.getKeysAtLevel(4)
c.Assert(err, qt.IsNil)
expected = []string{
"7a5d1c81f7b96318012de3417e53d4f13df5b1337718651cd29d0cb0a66edd20",
"3408213e4e844bdf3355eb8781c74e31626812898c2dbe141ed6d2c92256fc1c",
"dfd8a4d0b6954a3e9f3892e655b58d456eeedf9367f27dfdd9bc2dd6a5577312",
"9e99fbec06fb2a6725997c12c4995f62725eb4cce4808523a5a5e80cca64b007",
"0befa1e070231dbf4e8ff841c05878cdec823e0c09594c24910a248b3ff5a628",
"b7131b0a15c772a57005a4dc5d0d6dd4b3414f5d9ee7408ce5e86c5ab3520e04",
"6d1abe0364077846a56bab1deb1a04883eb796b74fe531a7676a9a370f83ab21",
"4270116394bede69cf9cd72069eca018238080380bef5de75be8dcbbe968e105",
}
for i := 0; i < len(keys); i++ {
c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
}
}
func TestSplitInBuckets(t *testing.T) {
c := qt.New(t)
nLeafs := 16
kvs := make([]kv, nLeafs)
for i := 0; i < nLeafs; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
keyPath := make([]byte, 32)
copy(keyPath[:], k)
kvs[i].pos = i
kvs[i].keyPath = k
kvs[i].k = k
kvs[i].v = v
}
// check keyToBucket results for 4 buckets & 8 keys
c.Assert(keyToBucket(kvs[0].k, 4), qt.Equals, 0)
c.Assert(keyToBucket(kvs[1].k, 4), qt.Equals, 2)
c.Assert(keyToBucket(kvs[2].k, 4), qt.Equals, 1)
c.Assert(keyToBucket(kvs[3].k, 4), qt.Equals, 3)
c.Assert(keyToBucket(kvs[4].k, 4), qt.Equals, 0)
c.Assert(keyToBucket(kvs[5].k, 4), qt.Equals, 2)
c.Assert(keyToBucket(kvs[6].k, 4), qt.Equals, 1)
c.Assert(keyToBucket(kvs[7].k, 4), qt.Equals, 3)
// check keyToBucket results for 8 buckets & 8 keys
c.Assert(keyToBucket(kvs[0].k, 8), qt.Equals, 0)
c.Assert(keyToBucket(kvs[1].k, 8), qt.Equals, 4)
c.Assert(keyToBucket(kvs[2].k, 8), qt.Equals, 2)
c.Assert(keyToBucket(kvs[3].k, 8), qt.Equals, 6)
c.Assert(keyToBucket(kvs[4].k, 8), qt.Equals, 1)
c.Assert(keyToBucket(kvs[5].k, 8), qt.Equals, 5)
c.Assert(keyToBucket(kvs[6].k, 8), qt.Equals, 3)
c.Assert(keyToBucket(kvs[7].k, 8), qt.Equals, 7)
buckets := splitInBuckets(kvs, 4)
expected := [][]string{
{
"00000000", // bucket 0
"08000000",
"04000000",
"0c000000",
},
{
"02000000", // bucket 1
"0a000000",
"06000000",
"0e000000",
},
{
"01000000", // bucket 2
"09000000",
"05000000",
"0d000000",
},
{
"03000000", // bucket 3
"0b000000",
"07000000",
"0f000000",
},
}
for i := 0; i < len(buckets); i++ {
sortKvs(buckets[i])
c.Assert(len(buckets[i]), qt.Equals, len(expected[i]))
for j := 0; j < len(buckets[i]); j++ {
c.Check(hex.EncodeToString(buckets[i][j].k[:4]), qt.Equals, expected[i][j])
}
}
}
func TestAddBatchCaseC(t *testing.T) {
c := qt.New(t)
nLeafs := 1024
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree.db.Close()
start := time.Now()
for i := 0; i < nLeafs; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
if err := tree.Add(k, v); err != nil {
t.Fatal(err)
}
}
fmt.Println(time.Since(start))
tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close()
// add the initial leafs to fill a bit the tree before calling the
// AddBatch method
for i := 0; i < 101; i++ { // TMP TODO use const minLeafsThreshold-1 once ready
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
if err := tree2.Add(k, v); err != nil {
t.Fatal(err)
}
}
// tree2.PrintGraphviz(nil)
var keys, values [][]byte
for i := 101; i < nLeafs; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
keys = append(keys, k)
values = append(values, v)
}
start = time.Now()
indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil)
fmt.Println(time.Since(start))
c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
// tree.PrintGraphviz(nil)
// tree2.PrintGraphviz(nil)
// // tree.PrintGraphvizFirstNLevels(nil, 4)
// // tree2.PrintGraphvizFirstNLevels(nil, 4)
// fmt.Println("TREE")
// printLeafs("t1.txt", tree)
// fmt.Println("TREE2")
// printLeafs("t2.txt", tree2)
}
// func printLeafs(name string, t *Tree) {
// w := bytes.NewBufferString("")
//
// err := t.Iterate(func(k, v []byte) {
// if v[0] != PrefixValueLeaf {
// return
// }
// leafK, _ := readLeafValue(v)
// fmt.Fprintf(w, hex.EncodeToString(leafK[:4])+"\n")
// })
// if err != nil {
// panic(err)
// }
// err = ioutil.WriteFile(name, w.Bytes(), 0644)
// if err != nil {
// panic(err)
// }
//
// }
// func TestComputeCosts(t *testing.T) {
// fmt.Println(computeSimpleAddCost(10))
// fmt.Println(computeBottomUpAddCost(10))
//
// fmt.Println(computeSimpleAddCost(1024))
// fmt.Println(computeBottomUpAddCost(1024))
// }
// TODO test tree with nLeafs > minLeafsThreshold, but that at level L, there is
// less keys than nBuckets (so CASE C could be applied if first few leafs are
// added to balance the tree)

+ 65
- 29
tree.go

@ -128,7 +128,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
var indexes []int var indexes []int
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
err = t.add(keys[i], values[i])
err = t.add(0, keys[i], values[i])
if err != nil { if err != nil {
indexes = append(indexes, i) indexes = append(indexes, i)
} }
@ -163,7 +163,7 @@ func (t *Tree) Add(k, v []byte) error {
return err return err
} }
err = t.add(k, v)
err = t.add(0, k, v) // add from level 0
if err != nil { if err != nil {
return err return err
} }
@ -178,7 +178,7 @@ func (t *Tree) Add(k, v []byte) error {
return t.tx.Commit() return t.tx.Commit()
} }
func (t *Tree) add(k, v []byte) error {
func (t *Tree) add(fromLvl int, k, v []byte) error {
// TODO check validity of key & value (for the Tree.HashFunction type) // TODO check validity of key & value (for the Tree.HashFunction type)
keyPath := make([]byte, t.hashFunction.Len()) keyPath := make([]byte, t.hashFunction.Len())
@ -187,7 +187,7 @@ func (t *Tree) add(k, v []byte) error {
path := getPath(t.maxLevels, keyPath) path := getPath(t.maxLevels, keyPath)
// go down to the leaf // go down to the leaf
var siblings [][]byte var siblings [][]byte
_, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
_, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false)
if err != nil { if err != nil {
return err return err
} }
@ -217,9 +217,9 @@ func (t *Tree) add(k, v []byte) error {
// down goes down to the leaf recursively // down goes down to the leaf recursively
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
path []bool, l int, getLeaf bool) (
path []bool, currLvl int, getLeaf bool) (
[]byte, []byte, [][]byte, error) { []byte, []byte, [][]byte, error) {
if l > t.maxLevels-1 {
if currLvl > t.maxLevels-1 {
return nil, nil, nil, fmt.Errorf("max level") return nil, nil, nil, fmt.Errorf("max level")
} }
var err error var err error
@ -254,7 +254,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
// if currKey is already used, go down until paths diverge // if currKey is already used, go down until paths diverge
oldPath := getPath(t.maxLevels, oldLeafKeyFull) oldPath := getPath(t.maxLevels, oldLeafKeyFull)
siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -267,16 +267,16 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
PrefixValueLen+t.hashFunction.Len()*2, len(currValue)) PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
} }
// collect siblings while going down // collect siblings while going down
if path[l] {
if path[currLvl] {
// right // right
lChild, rChild := readIntermediateChilds(currValue) lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, lChild) siblings = append(siblings, lChild)
return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf)
} }
// left // left
lChild, rChild := readIntermediateChilds(currValue) lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, rChild) siblings = append(siblings, rChild)
return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf)
default: default:
return nil, nil, nil, fmt.Errorf("invalid value") return nil, nil, nil, fmt.Errorf("invalid value")
} }
@ -285,16 +285,16 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
// downVirtually is used when in a leaf already exists, and a new leaf which // downVirtually is used when in a leaf already exists, and a new leaf which
// shares the path until the existing leaf is being added // shares the path until the existing leaf is being added
func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
newPath []bool, l int) ([][]byte, error) {
newPath []bool, currLvl int) ([][]byte, error) {
var err error var err error
if l > t.maxLevels-1 {
return nil, fmt.Errorf("max virtual level %d", l)
if currLvl > t.maxLevels-1 {
return nil, fmt.Errorf("max virtual level %d", currLvl)
} }
if oldPath[l] == newPath[l] {
if oldPath[currLvl] == newPath[currLvl] {
siblings = append(siblings, t.emptyHash) siblings = append(siblings, t.emptyHash)
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -307,16 +307,16 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
} }
// up goes up recursively updating the intermediate nodes // up goes up recursively updating the intermediate nodes
func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl int) ([]byte, error) {
var k, v []byte var k, v []byte
var err error var err error
if path[l] {
k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
if path[currLvl] {
k, v, err = newIntermediate(t.hashFunction, siblings[currLvl], key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
k, v, err = newIntermediate(t.hashFunction, key, siblings[currLvl])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -326,12 +326,12 @@ func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, er
return nil, err return nil, err
} }
if l == 0 {
if currLvl == 0 {
// reached the root // reached the root
return k, nil return k, nil
} }
return t.up(k, siblings, path, l-1)
return t.up(k, siblings, path, currLvl-1)
} }
func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
@ -666,24 +666,36 @@ func (t *Tree) Iterate(f func([]byte, []byte)) error {
return t.iter(t.root, f) return t.iter(t.root, f)
} }
func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
// IterateWithStop does the same than Iterate, but with int for the current
// level, and a boolean parameter used by the passed function, is to indicate to
// stop iterating on the branch when the method returns 'true'.
func (t *Tree) IterateWithStop(f func(int, []byte, []byte) bool) error {
t.updateAccessTime()
return t.iterWithStop(t.root, 0, f)
}
func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error {
v, err := t.dbGet(k) v, err := t.dbGet(k)
if err != nil { if err != nil {
return err return err
} }
currLevel++
switch v[0] { switch v[0] {
case PrefixValueEmpty: case PrefixValueEmpty:
f(k, v)
f(currLevel, k, v)
case PrefixValueLeaf: case PrefixValueLeaf:
f(k, v)
f(currLevel, k, v)
case PrefixValueIntermediate: case PrefixValueIntermediate:
f(k, v)
stop := f(currLevel, k, v)
if stop {
return nil
}
l, r := readIntermediateChilds(v) l, r := readIntermediateChilds(v)
if err = t.iter(l, f); err != nil {
if err = t.iterWithStop(l, currLevel, f); err != nil {
return err return err
} }
if err = t.iter(r, f); err != nil {
if err = t.iterWithStop(r, currLevel, f); err != nil {
return err return err
} }
default: default:
@ -692,6 +704,14 @@ func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
return nil return nil
} }
func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
f2 := func(currLvl int, k, v []byte) bool {
f(k, v)
return false
}
return t.iterWithStop(k, 0, f2)
}
// Dump exports all the Tree leafs in a byte array of length: // Dump exports all the Tree leafs in a byte array of length:
// [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v: // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
// [ 1 byte | 1 byte | S bytes | len(v) bytes ] // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
@ -768,12 +788,22 @@ func (t *Tree) ImportDump(b []byte) error {
// Graphviz iterates across the full tree to generate a string Graphviz // Graphviz iterates across the full tree to generate a string Graphviz
// representation of the tree and writes it to w // representation of the tree and writes it to w
func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error { func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
}
// GraphvizFirstNLevels iterates across the first NLevels of the tree to
// generate a string Graphviz representation of the first NLevels of the tree
// and writes it to w
func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
fmt.Fprintf(w, `digraph hierarchy { fmt.Fprintf(w, `digraph hierarchy {
node [fontname=Monospace,fontsize=10,shape=box] node [fontname=Monospace,fontsize=10,shape=box]
`) `)
nChars := 4 nChars := 4
nEmpties := 0 nEmpties := 0
err := t.Iterate(func(k, v []byte) {
err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
if currLvl == untilLvl {
return true // to stop the iter from going down
}
switch v[0] { switch v[0] {
case PrefixValueEmpty: case PrefixValueEmpty:
case PrefixValueLeaf: case PrefixValueLeaf:
@ -807,6 +837,7 @@ node [fontname=Monospace,fontsize=10,shape=box]
fmt.Fprint(w, eStr) fmt.Fprint(w, eStr)
default: default:
} }
return false
}) })
fmt.Fprintf(w, "}\n") fmt.Fprintf(w, "}\n")
return err return err
@ -814,13 +845,18 @@ node [fontname=Monospace,fontsize=10,shape=box]
// PrintGraphviz prints the output of Tree.Graphviz // PrintGraphviz prints the output of Tree.Graphviz
func (t *Tree) PrintGraphviz(rootKey []byte) error { func (t *Tree) PrintGraphviz(rootKey []byte) error {
return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
}
// PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
if rootKey == nil { if rootKey == nil {
rootKey = t.Root() rootKey = t.Root()
} }
w := bytes.NewBufferString("") w := bytes.NewBufferString("")
fmt.Fprintf(w, fmt.Fprintf(w,
"--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n") "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
err := t.Graphviz(w, nil)
err := t.GraphvizFirstNLevels(w, nil, untilLvl)
if err != nil { if err != nil {
fmt.Println(w) fmt.Println(w)
return err return err

+ 1
- 1
utils.go

@ -13,7 +13,7 @@ func SwapEndianness(b []byte) []byte {
// BigIntToBytes converts a *big.Int into a byte array in Little-Endian // BigIntToBytes converts a *big.Int into a byte array in Little-Endian
func BigIntToBytes(bi *big.Int) []byte { func BigIntToBytes(bi *big.Int) []byte {
var b [32]byte
var b [32]byte // TODO make the length depending on the tree.hashFunction.Len()
copy(b[:], SwapEndianness(bi.Bytes())) copy(b[:], SwapEndianness(bi.Bytes()))
return b[:] return b[:]
} }

Loading…
Cancel
Save