diff --git a/addbatch.go b/addbatch.go index 5622973..bba5a5d 100644 --- a/addbatch.go +++ b/addbatch.go @@ -4,7 +4,11 @@ import ( "bytes" "fmt" "math" + "runtime" "sort" + "sync" + + "github.com/iden3/go-merkletree/db" ) /* @@ -138,8 +142,7 @@ Algorithm decision */ const ( - minLeafsThreshold = uint64(100) // nolint:gomnd // TMP WIP this will be autocalculated - nBuckets = uint64(4) // TMP WIP this will be autocalculated from + minLeafsThreshold = 100 // nolint:gomnd // TMP WIP this will be autocalculated ) // AddBatchOpt is the WIP implementation of the AddBatch method in a more @@ -165,11 +168,14 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { return nil, err } + nCPU := runtime.NumCPU() + // CASE A: if nLeafs==0 (root==0) if bytes.Equal(t.root, t.emptyHash) { - // sort keys & values by path - sortKvs(kvs) - return t.buildTreeBottomUp(kvs) + // TODO if len(kvs) is not a power of 2, cut at the bigger power + // of two under len(kvs), build the tree with that, and add + // later the excedents + return t.buildTreeBottomUp(nCPU, kvs) } // CASE B: if nLeafs=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold // available parallelization, will need to be a power of 2 (2**n) var excedents []kv - l := int(math.Log2(float64(nBuckets))) - if nLeafs >= minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold { + l := int(math.Log2(float64(nCPU))) + if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold { // TODO move to own function // 1. go down until level L (L=log2(nBuckets)) keysAtL, err := t.getKeysAtLevel(l + 1) @@ -204,7 +210,7 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { return nil, err } - buckets := splitInBuckets(kvs, nBuckets) + buckets := splitInBuckets(kvs, nCPU) // 2. use keys at level L as roots of the subtrees under each one var subRoots [][]byte @@ -264,7 +270,7 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) { // cutPowerOfTwo, the excedent add it as normal Tree.Add kvsP2, kvsNonP2 := cutPowerOfTwo(kvs) - invalids, err := t.buildTreeBottomUp(kvsP2) + invalids, err := t.buildTreeBottomUpSingleThread(kvsP2) if err != nil { return nil, nil, err } @@ -272,13 +278,13 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) { return invalids, kvsNonP2, nil } -func splitInBuckets(kvs []kv, nBuckets uint64) [][]kv { +func splitInBuckets(kvs []kv, nBuckets int) [][]kv { buckets := make([][]kv, nBuckets) // 1. classify the keyvalues into buckets for i := 0; i < len(kvs); i++ { pair := kvs[i] - bucketnum := keyToBucket(pair.k, int(nBuckets)) + bucketnum := keyToBucket(pair.k, nBuckets) buckets[bucketnum] = append(buckets[bucketnum], pair) } return buckets @@ -367,10 +373,56 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) { } */ +// buildTreeBottomUp splits the key-values into n Buckets (where n is the number +// of CPUs), in parallel builds a subtree for each bucket, once all the subtrees +// are built, uses the subtrees roots as keys for a new tree, which as result +// will have the complete Tree build from bottom to up, where until the +// log2(nCPU) level it has been computed in parallel. +func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) { + buckets := splitInBuckets(kvs, nCPU) + subRoots := make([][]byte, nCPU) + txs := make([]db.Tx, nCPU) + + var wg sync.WaitGroup + wg.Add(nCPU) + for i := 0; i < nCPU; i++ { + go func(cpu int) { + sortKvs(buckets[cpu]) + + var err error + txs[cpu], err = t.db.NewTx() + if err != nil { + panic(err) // TODO + } + bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, + hashFunction: t.hashFunction, root: t.emptyHash} + + // TODO use invalids array + _, err = bucketTree.buildTreeBottomUpSingleThread(buckets[cpu]) + if err != nil { + panic(err) // TODO + } + + subRoots[cpu] = bucketTree.root + wg.Done() + }(i) + } + wg.Wait() + newRoot, err := t.upFromKeys(subRoots) + if err != nil { + return nil, err + } + t.root = newRoot + return nil, err +} + // keys & values must be sorted by path, and the array ks must be length // multiple of 2 // TODO return index of failed keyvaules -func (t *Tree) buildTreeBottomUp(kvs []kv) ([]int, error) { +func (t *Tree) buildTreeBottomUpSingleThread(kvs []kv) ([]int, error) { + // TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels + // would be reached and should return error + // build the leafs leafKeys := make([][]byte, len(kvs)) for i := 0; i < len(kvs); i++ { diff --git a/tree.go b/tree.go index cdcb1e8..26f7892 100644 --- a/tree.go +++ b/tree.go @@ -138,7 +138,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { return indexes, err } // update nLeafs - if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil { + if err = t.incNLeafs(len(keys) - len(indexes)); err != nil { return indexes, err } @@ -629,7 +629,7 @@ func (t *Tree) dbGet(k []byte) ([]byte, error) { // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit // after the setNLeafs call. -func (t *Tree) incNLeafs(nLeafs uint64) error { +func (t *Tree) incNLeafs(nLeafs int) error { oldNLeafs, err := t.GetNLeafs() if err != nil { return err @@ -640,9 +640,9 @@ func (t *Tree) incNLeafs(nLeafs uint64) error { // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit // after the setNLeafs call. -func (t *Tree) setNLeafs(nLeafs uint64) error { +func (t *Tree) setNLeafs(nLeafs int) error { b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, nLeafs) + binary.LittleEndian.PutUint64(b, uint64(nLeafs)) if err := t.tx.Put(dbKeyNLeafs, b); err != nil { return err } @@ -650,13 +650,13 @@ func (t *Tree) setNLeafs(nLeafs uint64) error { } // GetNLeafs returns the number of Leafs of the Tree. -func (t *Tree) GetNLeafs() (uint64, error) { +func (t *Tree) GetNLeafs() (int, error) { b, err := t.dbGet(dbKeyNLeafs) if err != nil { return 0, err } nLeafs := binary.LittleEndian.Uint64(b) - return nLeafs, nil + return int(nLeafs), nil } // Iterate iterates through the full Tree, executing the given function on each @@ -776,7 +776,7 @@ func (t *Tree) ImportDump(b []byte) error { if err != nil { return err } - if err := t.incNLeafs(uint64(count)); err != nil { + if err := t.incNLeafs(count); err != nil { return err } if err = t.tx.Commit(); err != nil { diff --git a/tree_test.go b/tree_test.go index 963d2bc..6d3c6c3 100644 --- a/tree_test.go +++ b/tree_test.go @@ -342,7 +342,7 @@ func TestSetGetNLeafs(t *testing.T) { n, err := tree.GetNLeafs() c.Assert(err, qt.IsNil) - c.Assert(n, qt.Equals, uint64(0)) + c.Assert(n, qt.Equals, 0) // 1024 tree.tx, err = tree.db.NewTx() @@ -356,13 +356,16 @@ func TestSetGetNLeafs(t *testing.T) { n, err = tree.GetNLeafs() c.Assert(err, qt.IsNil) - c.Assert(n, qt.Equals, uint64(1024)) + c.Assert(n, qt.Equals, 1024) // 2**64 -1 tree.tx, err = tree.db.NewTx() c.Assert(err, qt.IsNil) - err = tree.setNLeafs(18446744073709551615) + maxUint := ^uint(0) + maxInt := int(maxUint >> 1) + + err = tree.setNLeafs(maxInt) c.Assert(err, qt.IsNil) err = tree.tx.Commit() @@ -370,7 +373,7 @@ func TestSetGetNLeafs(t *testing.T) { n, err = tree.GetNLeafs() c.Assert(err, qt.IsNil) - c.Assert(n, qt.Equals, uint64(18446744073709551615)) + c.Assert(n, qt.Equals, maxInt) } func BenchmarkAdd(b *testing.B) {