From a3473079deb5dddf15068a17a8444c8136fe1106 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Fri, 23 Apr 2021 16:49:31 +0200 Subject: [PATCH] Add AddBatch CaseC CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold) ============================================================================== - Use A, B, G, F as Roots of subtrees - Do CASE B for each subtree - Then go from L to the Root R / \ / \ / \ * * / | / \ / | / \ / | / \ L: A B G D / \ / \ / \ C * / \ / \ / \ ... ... (nLeafs >= minLeafsThreshold) --- addbatch.go | 223 ++++++++++++++++++++++++++++++++++++++++++----- addbatch_test.go | 221 ++++++++++++++++++++++++++++++++++++++++++++++ tree.go | 94 ++++++++++++++------ utils.go | 2 +- 4 files changed, 488 insertions(+), 52 deletions(-) diff --git a/addbatch.go b/addbatch.go index 63fa97d..5622973 100644 --- a/addbatch.go +++ b/addbatch.go @@ -3,6 +3,7 @@ package arbo import ( "bytes" "fmt" + "math" "sort" ) @@ -25,11 +26,24 @@ the leafs) - Do CASE A for the new Tree, giving the already existing key&values (leafs) from the original Tree + the new key&values to be added from the AddBatch call - R - / \ - A * - / \ - B C + + R R + / \ / \ + A * / \ + / \ / \ + B C * * + / | / \ + / | / \ + / | / \ + L: A B G D + / \ + / \ + / \ + C * + / \ + / \ + / \ + ... ... (nLeafs < minLeafsThreshold) CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold) @@ -54,7 +68,7 @@ L: A B G D / \ / \ / \ - D E + ... ... (nLeafs >= minLeafsThreshold) @@ -123,6 +137,11 @@ Algorithm decision */ +const ( + minLeafsThreshold = uint64(100) // nolint:gomnd // TMP WIP this will be autocalculated + nBuckets = uint64(4) // TMP WIP this will be autocalculated from +) + // AddBatchOpt is the WIP implementation of the AddBatch method in a more // optimized approach. 
func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { @@ -141,44 +160,151 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { return nil, err } - t.tx, err = t.db.NewTx() + t.tx, err = t.db.NewTx() // TODO add t.tx.Commit() if err != nil { return nil, err } - // if nLeafs==0 (root==0): CASE A + // CASE A: if nLeafs==0 (root==0) if bytes.Equal(t.root, t.emptyHash) { // sort keys & values by path sortKvs(kvs) return t.buildTreeBottomUp(kvs) } - // if nLeafs=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold + // available parallelization, will need to be a power of 2 (2**n) + var excedents []kv + l := int(math.Log2(float64(nBuckets))) + if nLeafs >= minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold { + // TODO move to own function + // 1. go down until level L (L=log2(nBuckets)) + keysAtL, err := t.getKeysAtLevel(l + 1) if err != nil { return nil, err } - // add already existing key-values to the inputted key-values - kvs = append(kvs, aKvs...) - // proceed with CASE A - sortKvs(kvs) - return t.buildTreeBottomUp(kvs) + + buckets := splitInBuckets(kvs, nBuckets) + + // 2. use keys at level L as roots of the subtrees under each one + var subRoots [][]byte + // TODO parallelize + for i := 0; i < len(keysAtL); i++ { + bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels, + hashFunction: t.hashFunction, root: keysAtL[i]} + + // 3. and do CASE B for each + _, bucketExcedents, err := bucketTree.caseB(l, buckets[i]) + if err != nil { + return nil, err + } + excedents = append(excedents, bucketExcedents...) + subRoots = append(subRoots, bucketTree.root) + } + // 4. 
go upFromKeys from the new roots of the subtrees + newRoot, err := t.upFromKeys(subRoots) + if err != nil { + return nil, err + } + t.root = newRoot + + var invalids []int + for i := 0; i < len(excedents); i++ { + // Add until the level L + err = t.add(0, excedents[i].k, excedents[i].v) + if err != nil { + invalids = append(invalids, excedents[i].pos) // TODO WIP + } + } + + return invalids, nil } + // TODO store t.root into DB + // TODO update NLeafs from DB + return nil, fmt.Errorf("UNIMPLEMENTED") } +func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) { + // get already existing keys + aKs, aVs, err := t.getLeafs(t.root) + if err != nil { + return nil, nil, err + } + aKvs, err := t.keysValuesToKvs(aKs, aVs) + if err != nil { + return nil, nil, err + } + // add already existing key-values to the inputted key-values + kvs = append(kvs, aKvs...) + + // proceed with CASE A + sortKvs(kvs) + + // cutPowerOfTwo, the excedent add it as normal Tree.Add + kvsP2, kvsNonP2 := cutPowerOfTwo(kvs) + invalids, err := t.buildTreeBottomUp(kvsP2) + if err != nil { + return nil, nil, err + } + // return the excedents which will be added at the full tree at the end + return invalids, kvsNonP2, nil +} + +func splitInBuckets(kvs []kv, nBuckets uint64) [][]kv { + buckets := make([][]kv, nBuckets) + // 1. 
classify the keyvalues into buckets + for i := 0; i < len(kvs); i++ { + pair := kvs[i] + + bucketnum := keyToBucket(pair.k, int(nBuckets)) + buckets[bucketnum] = append(buckets[bucketnum], pair) + } + return buckets +} + +// TODO rename in a more 'real' name (calculate bucket from/for key) +func keyToBucket(k []byte, nBuckets int) int { + nLevels := int(math.Log2(float64(nBuckets))) + b := make([]int, nBuckets) + for i := 0; i < nBuckets; i++ { + b[i] = i + } + r := b + mid := len(r) / 2 //nolint:gomnd + for i := 0; i < nLevels; i++ { + if int(k[i/8]&(1<<(i%8))) != 0 { + r = r[mid:] + mid = len(r) / 2 //nolint:gomnd + } else { + r = r[:mid] + mid = len(r) / 2 //nolint:gomnd + } + } + return r[0] +} + type kv struct { pos int // original position in the array keyPath []byte @@ -241,7 +367,8 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) { } */ -// keys & values must be sorted by path, and must be length multiple of 2 +// keys & values must be sorted by path, and the array ks must be length +// multiple of 2 // TODO return index of failed keyvaules func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) { // build the leafs @@ -258,6 +385,7 @@ func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) { } leafKeys[i] = leafKey } + // TODO parallelize t.upFromKeys until level log2(nBuckets) is reached r, err := t.upFromKeys(leafKeys) if err != nil { return nil, err @@ -266,6 +394,8 @@ func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) { return nil, nil } +// keys & values must be sorted by path, and the array ks must be length +// multiple of 2 func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) { if len(ks) == 1 { return ks[0], nil @@ -287,9 +417,9 @@ func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) { return t.upFromKeys(rKs) } -func (t *Tree) getLeafs() ([][]byte, [][]byte, error) { +func (t *Tree) getLeafs(root []byte) ([][]byte, [][]byte, error) { var ks, vs [][]byte - err := t.Iterate(func(k, v []byte) { + err := t.iter(root, 
func(k, v []byte) { if v[0] != PrefixValueLeaf { return } @@ -299,3 +429,52 @@ func (t *Tree) getLeafs() ([][]byte, [][]byte, error) { }) return ks, vs, err } + +func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) { + var keys [][]byte + err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool { + if currLvl == l { + keys = append(keys, k) + } + if currLvl >= l { + return true // to stop the iter from going down + } + return false + }) + + return keys, err +} + +// cutPowerOfTwo returns []kv of length that is a power of 2, and a second []kv +// with the extra elements that don't fit in a power of 2 length +func cutPowerOfTwo(kvs []kv) ([]kv, []kv) { + x := len(kvs) + if (x & (x - 1)) != 0 { + p2 := highestPowerOfTwo(x) + return kvs[:p2], kvs[p2:] + } + return kvs, nil +} + +func highestPowerOfTwo(n int) int { + res := 0 + for i := n; i >= 1; i-- { + if (i & (i - 1)) == 0 { + res = i + break + } + } + return res +} + +// func computeSimpleAddCost(nLeafs int) int { +// // nLvls 2^nLvls +// nLvls := int(math.Log2(float64(nLeafs))) +// return nLvls * int(math.Pow(2, float64(nLvls))) +// } +// +// func computeBottomUpAddCost(nLeafs int) int { +// // 2^nLvls * 2 - 1 +// nLvls := int(math.Log2(float64(nLeafs))) +// return (int(math.Pow(2, float64(nLvls))) * 2) - 1 +// } diff --git a/addbatch_test.go b/addbatch_test.go index a412ad0..30b2a03 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -1,6 +1,7 @@ package arbo import ( + "encoding/hex" "fmt" "math/big" "testing" @@ -99,3 +100,223 @@ func TestAddBatchCaseB(t *testing.T) { // check that both trees roots are equal c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) } + +func TestGetKeysAtLevel(t *testing.T) { + c := qt.New(t) + + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + for i := 0; i < 32; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err 
!= nil { + t.Fatal(err) + } + } + + keys, err := tree.getKeysAtLevel(2) + c.Assert(err, qt.IsNil) + expected := []string{ + "a5d5f14fce7348e40751496cf25d107d91b0bd043435b9577d778a01f8aa6111", + "e9e8dd9b28a7f81d1ff34cb5cefc0146dd848b31031a427b79bdadb62e7f6910", + } + for i := 0; i < len(keys); i++ { + c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i]) + } + + keys, err = tree.getKeysAtLevel(3) + c.Assert(err, qt.IsNil) + expected = []string{ + "9f12c13e52bca96ad4882a26558e48ab67ddd63e062b839207e893d961390f01", + "16d246dd6826ec7346c7328f11c4261facf82d4689f33263ff6e207956a77f21", + "4a22cc901c6337daa17a431fa20170684b710e5f551509511492ec24e81a8f2f", + "470d61abcbd154977bffc9a9ec5a8daff0caabcf2a25e8441f604c79daa0f82d", + } + for i := 0; i < len(keys); i++ { + c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i]) + } + + keys, err = tree.getKeysAtLevel(4) + c.Assert(err, qt.IsNil) + expected = []string{ + "7a5d1c81f7b96318012de3417e53d4f13df5b1337718651cd29d0cb0a66edd20", + "3408213e4e844bdf3355eb8781c74e31626812898c2dbe141ed6d2c92256fc1c", + "dfd8a4d0b6954a3e9f3892e655b58d456eeedf9367f27dfdd9bc2dd6a5577312", + "9e99fbec06fb2a6725997c12c4995f62725eb4cce4808523a5a5e80cca64b007", + "0befa1e070231dbf4e8ff841c05878cdec823e0c09594c24910a248b3ff5a628", + "b7131b0a15c772a57005a4dc5d0d6dd4b3414f5d9ee7408ce5e86c5ab3520e04", + "6d1abe0364077846a56bab1deb1a04883eb796b74fe531a7676a9a370f83ab21", + "4270116394bede69cf9cd72069eca018238080380bef5de75be8dcbbe968e105", + } + for i := 0; i < len(keys); i++ { + c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i]) + } +} + +func TestSplitInBuckets(t *testing.T) { + c := qt.New(t) + + nLeafs := 16 + kvs := make([]kv, nLeafs) + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + keyPath := make([]byte, 32) + copy(keyPath[:], k) + kvs[i].pos = i + kvs[i].keyPath = k + kvs[i].k = k + kvs[i].v = v + } + + // check keyToBucket results for 4 buckets 
& 8 keys + c.Assert(keyToBucket(kvs[0].k, 4), qt.Equals, 0) + c.Assert(keyToBucket(kvs[1].k, 4), qt.Equals, 2) + c.Assert(keyToBucket(kvs[2].k, 4), qt.Equals, 1) + c.Assert(keyToBucket(kvs[3].k, 4), qt.Equals, 3) + c.Assert(keyToBucket(kvs[4].k, 4), qt.Equals, 0) + c.Assert(keyToBucket(kvs[5].k, 4), qt.Equals, 2) + c.Assert(keyToBucket(kvs[6].k, 4), qt.Equals, 1) + c.Assert(keyToBucket(kvs[7].k, 4), qt.Equals, 3) + + // check keyToBucket results for 8 buckets & 8 keys + c.Assert(keyToBucket(kvs[0].k, 8), qt.Equals, 0) + c.Assert(keyToBucket(kvs[1].k, 8), qt.Equals, 4) + c.Assert(keyToBucket(kvs[2].k, 8), qt.Equals, 2) + c.Assert(keyToBucket(kvs[3].k, 8), qt.Equals, 6) + c.Assert(keyToBucket(kvs[4].k, 8), qt.Equals, 1) + c.Assert(keyToBucket(kvs[5].k, 8), qt.Equals, 5) + c.Assert(keyToBucket(kvs[6].k, 8), qt.Equals, 3) + c.Assert(keyToBucket(kvs[7].k, 8), qt.Equals, 7) + + buckets := splitInBuckets(kvs, 4) + + expected := [][]string{ + { + "00000000", // bucket 0 + "08000000", + "04000000", + "0c000000", + }, + { + "02000000", // bucket 1 + "0a000000", + "06000000", + "0e000000", + }, + { + "01000000", // bucket 2 + "09000000", + "05000000", + "0d000000", + }, + { + "03000000", // bucket 3 + "0b000000", + "07000000", + "0f000000", + }, + } + + for i := 0; i < len(buckets); i++ { + sortKvs(buckets[i]) + c.Assert(len(buckets[i]), qt.Equals, len(expected[i])) + for j := 0; j < len(buckets[i]); j++ { + c.Check(hex.EncodeToString(buckets[i][j].k[:4]), qt.Equals, expected[i][j]) + } + } +} + +func TestAddBatchCaseC(t *testing.T) { + c := qt.New(t) + + nLeafs := 1024 + + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + start := time.Now() + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + } + fmt.Println(time.Since(start)) + + tree2, err := 
NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + // add the initial leafs to fill a bit the tree before calling the + // AddBatch method + for i := 0; i < 101; i++ { // TMP TODO use const minLeafsThreshold-1 once ready + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree2.Add(k, v); err != nil { + t.Fatal(err) + } + } + // tree2.PrintGraphviz(nil) + + var keys, values [][]byte + for i := 101; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + keys = append(keys, k) + values = append(values, v) + } + start = time.Now() + indexes, err := tree2.AddBatchOpt(keys, values) + c.Assert(err, qt.IsNil) + fmt.Println(time.Since(start)) + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) + + // tree.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + // // tree.PrintGraphvizFirstNLevels(nil, 4) + // // tree2.PrintGraphvizFirstNLevels(nil, 4) + // fmt.Println("TREE") + // printLeafs("t1.txt", tree) + // fmt.Println("TREE2") + // printLeafs("t2.txt", tree2) +} + +// func printLeafs(name string, t *Tree) { +// w := bytes.NewBufferString("") +// +// err := t.Iterate(func(k, v []byte) { +// if v[0] != PrefixValueLeaf { +// return +// } +// leafK, _ := readLeafValue(v) +// fmt.Fprintf(w, hex.EncodeToString(leafK[:4])+"\n") +// }) +// if err != nil { +// panic(err) +// } +// err = ioutil.WriteFile(name, w.Bytes(), 0644) +// if err != nil { +// panic(err) +// } +// +// } + +// func TestComputeCosts(t *testing.T) { +// fmt.Println(computeSimpleAddCost(10)) +// fmt.Println(computeBottomUpAddCost(10)) +// +// fmt.Println(computeSimpleAddCost(1024)) +// fmt.Println(computeBottomUpAddCost(1024)) +// } + +// TODO test tree with nLeafs > minLeafsThreshold, but that at level L, there is +// less keys than nBuckets (so CASE C could be 
applied if first few leafs are +// added to balance the tree) diff --git a/tree.go b/tree.go index 186de97..cdcb1e8 100644 --- a/tree.go +++ b/tree.go @@ -128,7 +128,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { var indexes []int for i := 0; i < len(keys); i++ { - err = t.add(keys[i], values[i]) + err = t.add(0, keys[i], values[i]) if err != nil { indexes = append(indexes, i) } @@ -163,7 +163,7 @@ func (t *Tree) Add(k, v []byte) error { return err } - err = t.add(k, v) + err = t.add(0, k, v) // add from level 0 if err != nil { return err } @@ -178,7 +178,7 @@ func (t *Tree) Add(k, v []byte) error { return t.tx.Commit() } -func (t *Tree) add(k, v []byte) error { +func (t *Tree) add(fromLvl int, k, v []byte) error { // TODO check validity of key & value (for the Tree.HashFunction type) keyPath := make([]byte, t.hashFunction.Len()) @@ -187,7 +187,7 @@ func (t *Tree) add(k, v []byte) error { path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(k, t.root, siblings, path, 0, false) + _, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false) if err != nil { return err } @@ -217,9 +217,9 @@ func (t *Tree) add(k, v []byte) error { // down goes down to the leaf recursively func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, - path []bool, l int, getLeaf bool) ( + path []bool, currLvl int, getLeaf bool) ( []byte, []byte, [][]byte, error) { - if l > t.maxLevels-1 { + if currLvl > t.maxLevels-1 { return nil, nil, nil, fmt.Errorf("max level") } var err error @@ -254,7 +254,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, // if currKey is already used, go down until paths diverge oldPath := getPath(t.maxLevels, oldLeafKeyFull) - siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l) + siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl) if err != nil { return nil, nil, nil, err } @@ -267,16 +267,16 @@ 
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, PrefixValueLen+t.hashFunction.Len()*2, len(currValue)) } // collect siblings while going down - if path[l] { + if path[currLvl] { // right lChild, rChild := readIntermediateChilds(currValue) siblings = append(siblings, lChild) - return t.down(newKey, rChild, siblings, path, l+1, getLeaf) + return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf) } // left lChild, rChild := readIntermediateChilds(currValue) siblings = append(siblings, rChild) - return t.down(newKey, lChild, siblings, path, l+1, getLeaf) + return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf) default: return nil, nil, nil, fmt.Errorf("invalid value") } @@ -285,16 +285,16 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, // downVirtually is used when in a leaf already exists, and a new leaf which // shares the path until the existing leaf is being added func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, - newPath []bool, l int) ([][]byte, error) { + newPath []bool, currLvl int) ([][]byte, error) { var err error - if l > t.maxLevels-1 { - return nil, fmt.Errorf("max virtual level %d", l) + if currLvl > t.maxLevels-1 { + return nil, fmt.Errorf("max virtual level %d", currLvl) } - if oldPath[l] == newPath[l] { + if oldPath[currLvl] == newPath[currLvl] { siblings = append(siblings, t.emptyHash) - siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1) + siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1) if err != nil { return nil, err } @@ -307,16 +307,16 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, } // up goes up recursively updating the intermediate nodes -func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) { +func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl int) ([]byte, error) { var k, v []byte var err error - if path[l] { 
- k, v, err = newIntermediate(t.hashFunction, siblings[l], key) + if path[currLvl] { + k, v, err = newIntermediate(t.hashFunction, siblings[currLvl], key) if err != nil { return nil, err } } else { - k, v, err = newIntermediate(t.hashFunction, key, siblings[l]) + k, v, err = newIntermediate(t.hashFunction, key, siblings[currLvl]) if err != nil { return nil, err } @@ -326,12 +326,12 @@ func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, er return nil, err } - if l == 0 { + if currLvl == 0 { // reached the root return k, nil } - return t.up(k, siblings, path, l-1) + return t.up(k, siblings, path, currLvl-1) } func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { @@ -666,24 +666,36 @@ func (t *Tree) Iterate(f func([]byte, []byte)) error { return t.iter(t.root, f) } -func (t *Tree) iter(k []byte, f func([]byte, []byte)) error { +// IterateWithStop does the same as Iterate, but also passes the current +// level (int) to the given function, whose boolean return value indicates to +// stop iterating on the branch when it is 'true'. 
+func (t *Tree) IterateWithStop(f func(int, []byte, []byte) bool) error { + t.updateAccessTime() + return t.iterWithStop(t.root, 0, f) +} + +func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error { v, err := t.dbGet(k) if err != nil { return err } + currLevel++ switch v[0] { case PrefixValueEmpty: - f(k, v) + f(currLevel, k, v) case PrefixValueLeaf: - f(k, v) + f(currLevel, k, v) case PrefixValueIntermediate: - f(k, v) + stop := f(currLevel, k, v) + if stop { + return nil + } l, r := readIntermediateChilds(v) - if err = t.iter(l, f); err != nil { + if err = t.iterWithStop(l, currLevel, f); err != nil { return err } - if err = t.iter(r, f); err != nil { + if err = t.iterWithStop(r, currLevel, f); err != nil { return err } default: @@ -692,6 +704,14 @@ func (t *Tree) iter(k []byte, f func([]byte, []byte)) error { return nil } +func (t *Tree) iter(k []byte, f func([]byte, []byte)) error { + f2 := func(currLvl int, k, v []byte) bool { + f(k, v) + return false + } + return t.iterWithStop(k, 0, f2) +} + // Dump exports all the Tree leafs in a byte array of length: // [ N * (2+len(k+v)) ]. 
Where N is the number of key-values, and for each k+v: // [ 1 byte | 1 byte | S bytes | len(v) bytes ] @@ -768,12 +788,22 @@ func (t *Tree) ImportDump(b []byte) error { // Graphviz iterates across the full tree to generate a string Graphviz // representation of the tree and writes it to w func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error { + return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels) +} + +// GraphvizFirstNLevels iterates across the first NLevels of the tree to +// generate a string Graphviz representation of the first NLevels of the tree +// and writes it to w +func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error { fmt.Fprintf(w, `digraph hierarchy { node [fontname=Monospace,fontsize=10,shape=box] `) nChars := 4 nEmpties := 0 - err := t.Iterate(func(k, v []byte) { + err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool { + if currLvl == untilLvl { + return true // to stop the iter from going down + } switch v[0] { case PrefixValueEmpty: case PrefixValueLeaf: @@ -807,6 +837,7 @@ node [fontname=Monospace,fontsize=10,shape=box] fmt.Fprint(w, eStr) default: } + return false }) fmt.Fprintf(w, "}\n") return err @@ -814,13 +845,18 @@ node [fontname=Monospace,fontsize=10,shape=box] // PrintGraphviz prints the output of Tree.Graphviz func (t *Tree) PrintGraphviz(rootKey []byte) error { + return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels) +} + +// PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels +func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error { if rootKey == nil { rootKey = t.Root() } w := bytes.NewBufferString("") fmt.Fprintf(w, "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n") - err := t.Graphviz(w, nil) + err := t.GraphvizFirstNLevels(w, nil, untilLvl) if err != nil { fmt.Println(w) return err diff --git a/utils.go b/utils.go index 47358fd..de19c8f 100644 --- a/utils.go +++ b/utils.go @@ -13,7 +13,7 @@ func 
SwapEndianness(b []byte) []byte { // BigIntToBytes converts a *big.Int into a byte array in Little-Endian func BigIntToBytes(bi *big.Int) []byte { - var b [32]byte + var b [32]byte // TODO make the length depend on the tree.hashFunction.Len() copy(b[:], SwapEndianness(bi.Bytes())) return b[:] }