Browse Source

Add CPU parallelization to buildTreeBottomUp

buildTreeBottomUp splits the key-values into n buckets (where n is the
number of CPUs) and builds a subtree for each bucket in parallel. Once all
the subtrees are built, their roots are used as the keys of a new tree,
producing the complete Tree built from the bottom up, where everything
below level log2(nCPU) has been computed in parallel.

As a result, the tree construction is parallelized up to almost the top
level, dividing the construction time by roughly the number of CPUs.
master
arnaucube 3 years ago
parent
commit
a4ada7e2ee
3 changed files with 78 additions and 23 deletions
  1. +64
    -12
      addbatch.go
  2. +7
    -7
      tree.go
  3. +7
    -4
      tree_test.go

+ 64
- 12
addbatch.go

@ -4,7 +4,11 @@ import (
"bytes"
"fmt"
"math"
"runtime"
"sort"
"sync"
"github.com/iden3/go-merkletree/db"
)
/*
@ -138,8 +142,7 @@ Algorithm decision
*/
const (
minLeafsThreshold = uint64(100) // nolint:gomnd // TMP WIP this will be autocalculated
nBuckets = uint64(4) // TMP WIP this will be autocalculated from
minLeafsThreshold = 100 // nolint:gomnd // TMP WIP this will be autocalculated
)
// AddBatchOpt is the WIP implementation of the AddBatch method in a more
@ -165,11 +168,14 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
return nil, err
}
nCPU := runtime.NumCPU()
// CASE A: if nLeafs==0 (root==0)
if bytes.Equal(t.root, t.emptyHash) {
// sort keys & values by path
sortKvs(kvs)
return t.buildTreeBottomUp(kvs)
// TODO if len(kvs) is not a power of 2, cut at the bigger power
// of two under len(kvs), build the tree with that, and add
// later the excedents
return t.buildTreeBottomUp(nCPU, kvs)
}
// CASE B: if nLeafs<nBuckets
@ -195,8 +201,8 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
// CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold
// available parallelization, will need to be a power of 2 (2**n)
var excedents []kv
l := int(math.Log2(float64(nBuckets)))
if nLeafs >= minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold {
l := int(math.Log2(float64(nCPU)))
if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold {
// TODO move to own function
// 1. go down until level L (L=log2(nBuckets))
keysAtL, err := t.getKeysAtLevel(l + 1)
@ -204,7 +210,7 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
return nil, err
}
buckets := splitInBuckets(kvs, nBuckets)
buckets := splitInBuckets(kvs, nCPU)
// 2. use keys at level L as roots of the subtrees under each one
var subRoots [][]byte
@ -264,7 +270,7 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
// cutPowerOfTwo, the excedent add it as normal Tree.Add
kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
invalids, err := t.buildTreeBottomUp(kvsP2)
invalids, err := t.buildTreeBottomUpSingleThread(kvsP2)
if err != nil {
return nil, nil, err
}
@ -272,13 +278,13 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
return invalids, kvsNonP2, nil
}
func splitInBuckets(kvs []kv, nBuckets uint64) [][]kv {
func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
buckets := make([][]kv, nBuckets)
// 1. classify the keyvalues into buckets
for i := 0; i < len(kvs); i++ {
pair := kvs[i]
bucketnum := keyToBucket(pair.k, int(nBuckets))
bucketnum := keyToBucket(pair.k, nBuckets)
buckets[bucketnum] = append(buckets[bucketnum], pair)
}
return buckets
@ -367,10 +373,56 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) {
}
*/
// buildTreeBottomUp splits the key-values into nCPU buckets, builds a
// subtree for each bucket in parallel and, once all the subtrees are built,
// uses the subtree roots as the keys of a new tree. The result is the
// complete Tree built from the bottom up, where everything below level
// log2(nCPU) has been computed in parallel.
//
// It returns the indexes of the key-values that could not be added
// (currently always nil; see TODO below), or an error if creating a
// transaction or building any subtree fails.
func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
	buckets := splitInBuckets(kvs, nCPU)

	subRoots := make([][]byte, nCPU)
	txs := make([]db.Tx, nCPU)
	// collect per-goroutine errors instead of panicking inside the
	// goroutines; the first non-nil error is returned after wg.Wait().
	errs := make([]error, nCPU)

	var wg sync.WaitGroup
	wg.Add(nCPU)
	for i := 0; i < nCPU; i++ {
		go func(cpu int) {
			defer wg.Done()
			// keys & values must be sorted by path before the
			// bottom-up build
			sortKvs(buckets[cpu])

			tx, err := t.db.NewTx()
			if err != nil {
				errs[cpu] = err
				return
			}
			txs[cpu] = tx
			bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels,
				hashFunction: t.hashFunction, root: t.emptyHash}

			// TODO propagate the invalids array returned by the
			// subtree build instead of discarding it
			if _, err = bucketTree.buildTreeBottomUpSingleThread(buckets[cpu]); err != nil {
				errs[cpu] = err
				return
			}
			subRoots[cpu] = bucketTree.root
		}(i)
	}
	wg.Wait()
	for _, err := range errs {
		if err != nil {
			return nil, err
		}
	}

	// build the upper levels of the tree using the subtree roots as keys
	newRoot, err := t.upFromKeys(subRoots)
	if err != nil {
		return nil, err
	}
	t.root = newRoot
	return nil, nil
}
// keys & values must be sorted by path, and the array ks must be length
// multiple of 2
// TODO return index of failed keyvaules
func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) {
func (t *Tree) buildTreeBottomUpSingleThread(kvs []kv) ([]int, error) {
// TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels
// would be reached and should return error
// build the leafs
leafKeys := make([][]byte, len(kvs))
for i := 0; i < len(kvs); i++ {

+ 7
- 7
tree.go

@ -138,7 +138,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
return indexes, err
}
// update nLeafs
if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil {
if err = t.incNLeafs(len(keys) - len(indexes)); err != nil {
return indexes, err
}
@ -629,7 +629,7 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) {
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
// after the setNLeafs call.
func (t *Tree) incNLeafs(nLeafs uint64) error {
func (t *Tree) incNLeafs(nLeafs int) error {
oldNLeafs, err := t.GetNLeafs()
if err != nil {
return err
@ -640,9 +640,9 @@ func (t *Tree) incNLeafs(nLeafs uint64) error {
// Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
// after the setNLeafs call.
func (t *Tree) setNLeafs(nLeafs uint64) error {
func (t *Tree) setNLeafs(nLeafs int) error {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, nLeafs)
binary.LittleEndian.PutUint64(b, uint64(nLeafs))
if err := t.tx.Put(dbKeyNLeafs, b); err != nil {
return err
}
@ -650,13 +650,13 @@ func (t *Tree) setNLeafs(nLeafs uint64) error {
}
// GetNLeafs returns the number of Leafs of the Tree.
func (t *Tree) GetNLeafs() (uint64, error) {
func (t *Tree) GetNLeafs() (int, error) {
b, err := t.dbGet(dbKeyNLeafs)
if err != nil {
return 0, err
}
nLeafs := binary.LittleEndian.Uint64(b)
return nLeafs, nil
return int(nLeafs), nil
}
// Iterate iterates through the full Tree, executing the given function on each
@ -776,7 +776,7 @@ func (t *Tree) ImportDump(b []byte) error {
if err != nil {
return err
}
if err := t.incNLeafs(uint64(count)); err != nil {
if err := t.incNLeafs(count); err != nil {
return err
}
if err = t.tx.Commit(); err != nil {

+ 7
- 4
tree_test.go

@ -342,7 +342,7 @@ func TestSetGetNLeafs(t *testing.T) {
n, err := tree.GetNLeafs()
c.Assert(err, qt.IsNil)
c.Assert(n, qt.Equals, uint64(0))
c.Assert(n, qt.Equals, 0)
// 1024
tree.tx, err = tree.db.NewTx()
@ -356,13 +356,16 @@ func TestSetGetNLeafs(t *testing.T) {
n, err = tree.GetNLeafs()
c.Assert(err, qt.IsNil)
c.Assert(n, qt.Equals, uint64(1024))
c.Assert(n, qt.Equals, 1024)
// 2**64 -1
tree.tx, err = tree.db.NewTx()
c.Assert(err, qt.IsNil)
err = tree.setNLeafs(18446744073709551615)
maxUint := ^uint(0)
maxInt := int(maxUint >> 1)
err = tree.setNLeafs(maxInt)
c.Assert(err, qt.IsNil)
err = tree.tx.Commit()
@ -370,7 +373,7 @@ func TestSetGetNLeafs(t *testing.T) {
n, err = tree.GetNLeafs()
c.Assert(err, qt.IsNil)
c.Assert(n, qt.Equals, uint64(18446744073709551615))
c.Assert(n, qt.Equals, maxInt)
}
func BenchmarkAdd(b *testing.B) {

Loading…
Cancel
Save