From 1c2b7d68712b710c96fefdc294a1ddcd3d5d1fbe Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sat, 24 Apr 2021 23:46:23 +0200 Subject: [PATCH] Add CPU parallelization to AddBatch CaseD AddBatch in CaseD is parallelized (one goroutine per CPU) until almost the top level, dividing the needed time by almost the number of CPUs. --- addbatch.go | 44 ++++++++++++++++++++++++++++++-------------- addbatch_test.go | 8 +++----- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/addbatch.go b/addbatch.go index 8e03707..032b587 100644 --- a/addbatch.go +++ b/addbatch.go @@ -297,35 +297,51 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) { } func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) { - fmt.Println("CASE D", nCPU) keysAtL, err := t.getKeysAtLevel(l + 1) if err != nil { return nil, err } buckets := splitInBuckets(kvs, nCPU) - var subRoots [][]byte - var invalids []int - for i := 0; i < len(keysAtL); i++ { - bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels, // maxLevels-l - hashFunction: t.hashFunction, root: keysAtL[i]} - - for j := 0; j < len(buckets[i]); j++ { - if err = bucketTree.add(l, buckets[i][j].k, buckets[i][j].v); err != nil { - fmt.Println("failed", buckets[i][j].k[:4]) + subRoots := make([][]byte, nCPU) + invalidsInBucket := make([][]int, nCPU) + txs := make([]db.Tx, nCPU) - panic(err) - // invalids = append(invalids, buckets[i][j].pos) + var wg sync.WaitGroup + wg.Add(nCPU) + for i := 0; i < nCPU; i++ { + go func(cpu int) { + var err error + txs[cpu], err = t.db.NewTx() + if err != nil { + panic(err) // TODO } - } - subRoots = append(subRoots, bucketTree.root) + bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, // maxLevels-l + hashFunction: t.hashFunction, root: keysAtL[cpu]} + + for j := 0; j < len(buckets[cpu]); j++ { + if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { + fmt.Println("failed", buckets[cpu][j].k[:4]) + invalidsInBucket[cpu] = append(invalidsInBucket[cpu], 
buckets[cpu][j].pos) + } + } + subRoots[cpu] = bucketTree.root + wg.Done() + }(i) } + wg.Wait() + newRoot, err := t.upFromKeys(subRoots) if err != nil { return nil, err } t.root = newRoot + var invalids []int + for i := 0; i < len(invalidsInBucket); i++ { + invalids = append(invalids, invalidsInBucket[i]...) + } + return invalids, nil } diff --git a/addbatch_test.go b/addbatch_test.go index 3c6a7cd..6830445 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -11,7 +11,7 @@ import ( "github.com/iden3/go-merkletree/db/memory" ) -func TestBatchAux(t *testing.T) { +func TestBatchAux(t *testing.T) { // TODO TMP this test will be deleted c := qt.New(t) nLeafs := 16 @@ -342,7 +342,7 @@ func TestAddBatchCaseC(t *testing.T) { func TestAddBatchCaseD(t *testing.T) { c := qt.New(t) - nLeafs := 8192 + nLeafs := 4096 tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -364,15 +364,13 @@ func TestAddBatchCaseD(t *testing.T) { // add the initial leafs to fill a bit the tree before calling the // AddBatch method - for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold-1 once ready + for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold+1 once ready k := BigIntToBytes(big.NewInt(int64(i))) v := BigIntToBytes(big.NewInt(int64(i * 2))) if err := tree2.Add(k, v); err != nil { t.Fatal(err) } } - // tree2.PrintGraphvizFirstNLevels(nil, 4) - // tree2.PrintGraphviz(nil) var keys, values [][]byte for i := 900; i < nLeafs; i++ {