Add CPU parallelization to AddBatch CaseD

AddBatch in CaseD, is parallelized (for each CPU) until almost the top
level, almost dividing the needed time by the number of CPUs.
This commit is contained in:
2021-04-24 23:46:23 +02:00
parent 890057cd82
commit 1c2b7d6871
2 changed files with 33 additions and 19 deletions

View File

@@ -297,35 +297,51 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
} }
func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) { func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) {
fmt.Println("CASE D", nCPU)
keysAtL, err := t.getKeysAtLevel(l + 1) keysAtL, err := t.getKeysAtLevel(l + 1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
buckets := splitInBuckets(kvs, nCPU) buckets := splitInBuckets(kvs, nCPU)
var subRoots [][]byte subRoots := make([][]byte, nCPU)
var invalids []int invalidsInBucket := make([][]int, nCPU)
for i := 0; i < len(keysAtL); i++ { txs := make([]db.Tx, nCPU)
bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels, // maxLevels-l
hashFunction: t.hashFunction, root: keysAtL[i]}
for j := 0; j < len(buckets[i]); j++ { var wg sync.WaitGroup
if err = bucketTree.add(l, buckets[i][j].k, buckets[i][j].v); err != nil { wg.Add(nCPU)
fmt.Println("failed", buckets[i][j].k[:4]) for i := 0; i < nCPU; i++ {
go func(cpu int) {
panic(err) var err error
// invalids = append(invalids, buckets[i][j].pos) txs[cpu], err = t.db.NewTx()
if err != nil {
panic(err) // TODO
} }
} bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, // maxLevels-l
subRoots = append(subRoots, bucketTree.root) hashFunction: t.hashFunction, root: keysAtL[cpu]}
for j := 0; j < len(buckets[cpu]); j++ {
if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
fmt.Println("failed", buckets[cpu][j].k[:4])
invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
}
}
subRoots[cpu] = bucketTree.root
wg.Done()
}(i)
} }
wg.Wait()
newRoot, err := t.upFromKeys(subRoots) newRoot, err := t.upFromKeys(subRoots)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.root = newRoot t.root = newRoot
var invalids []int
for i := 0; i < len(invalidsInBucket); i++ {
invalids = append(invalids, invalidsInBucket[i]...)
}
return invalids, nil return invalids, nil
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/iden3/go-merkletree/db/memory" "github.com/iden3/go-merkletree/db/memory"
) )
func TestBatchAux(t *testing.T) { func TestBatchAux(t *testing.T) { // TODO TMP this test will be delted
c := qt.New(t) c := qt.New(t)
nLeafs := 16 nLeafs := 16
@@ -342,7 +342,7 @@ func TestAddBatchCaseC(t *testing.T) {
func TestAddBatchCaseD(t *testing.T) { func TestAddBatchCaseD(t *testing.T) {
c := qt.New(t) c := qt.New(t)
nLeafs := 8192 nLeafs := 4096
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
@@ -364,15 +364,13 @@ func TestAddBatchCaseD(t *testing.T) {
// add the initial leafs to fill a bit the tree before calling the // add the initial leafs to fill a bit the tree before calling the
// AddBatch method // AddBatch method
for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold-1 once ready for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold+1 once ready
k := BigIntToBytes(big.NewInt(int64(i))) k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2))) v := BigIntToBytes(big.NewInt(int64(i * 2)))
if err := tree2.Add(k, v); err != nil { if err := tree2.Add(k, v); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
// tree2.PrintGraphvizFirstNLevels(nil, 4)
// tree2.PrintGraphviz(nil)
var keys, values [][]byte var keys, values [][]byte
for i := 900; i < nLeafs; i++ { for i := 900; i < nLeafs; i++ {