From d09bd605bb0e5f24f974b5ca7a2affa20b2ba51a Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sun, 23 May 2021 21:31:33 +0200 Subject: [PATCH] Implement VirtualTree.addBatch with cpu parallelization --- vt.go | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++--- vt_test.go | 35 +++++++++++ 2 files changed, 198 insertions(+), 8 deletions(-) diff --git a/vt.go b/vt.go index a8b4740..509feeb 100644 --- a/vt.go +++ b/vt.go @@ -9,6 +9,9 @@ import ( "encoding/hex" "fmt" "io" + "math" + "runtime" + "sync" ) type node struct { @@ -27,6 +30,24 @@ type params struct { dbg *dbgStats } +func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { + if len(ks) != len(vs) { + return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", + len(ks), len(vs)) + } + kvs := make([]kv, len(ks)) + for i := 0; i < len(ks); i++ { + keyPath := make([]byte, p.hashFunction.Len()) + copy(keyPath[:], ks[i]) + kvs[i].pos = i + kvs[i].keyPath = keyPath + kvs[i].k = ks[i] + kvs[i].v = vs[i] + } + + return kvs, nil +} + // vt stands for virtual tree. It's a tree that does not have any computed hash // while placing the leafs. Once all the leafs are placed, it computes all the // hashes. In this way, each node hash is only computed one time (at the end) @@ -47,14 +68,144 @@ func newVT(maxLevels int, hash HashFunction) vt { } } -// WIP -// func (t *vt) addBatch(fromLvl int, k, v []byte) error { -// // parallelize adding leafs in the virtual tree -// nCPU := flp2(runtime.NumCPU()) -// l := int(math.Log2(float64(nCPU))) -// -// return nil -// } +func (t *vt) addBatch(ks, vs [][]byte) error { + // parallelize adding leafs in the virtual tree + nCPU := flp2(runtime.NumCPU()) + if nCPU == 1 || len(ks) < nCPU { + // var invalids []int + for i := 0; i < len(ks); i++ { + if err := t.add(0, ks[i], vs[i]); err != nil { + // invalids = append(invalids, i) + fmt.Println(err) // TODO WIP + } + } + return nil // TODO invalids + } + + l := int(math.Log2(float64(nCPU))) + + kvs, err := t.params.keysValuesToKvs(ks, vs) + if err != nil { + return err + } + + buckets := splitInBuckets(kvs, nCPU) + + nodesAtL, err := t.getNodesAtLevel(l) + if err != nil { + return err + } + // fmt.Println("nodesatL pre-E", len(nodesAtL)) + if len(nodesAtL) != nCPU { + // CASE E: add one key at each bucket, and then do CASE D + for i := 0; i < len(buckets); i++ { + // add one leaf of the bucket, if there is an error when + // adding the k-v, try to add the next one of the bucket + // (until one is added) + var inserted int + for j := 0; j < len(buckets[i]); j++ { + if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil { + inserted = j + break + } + } + + // remove the inserted element from buckets[i] + buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...) + } + nodesAtL, err = t.getNodesAtLevel(l) + if err != nil { + return err + } + } + + subRoots := make([]*node, nCPU) + invalidsInBucket := make([][]int, nCPU) + + var wg sync.WaitGroup + wg.Add(nCPU) + for i := 0; i < nCPU; i++ { + go func(cpu int) { + sortKvs(buckets[cpu]) + + bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction) + bucketVT.root = nodesAtL[cpu] + for j := 0; j < len(buckets[cpu]); j++ { + if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { + invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos) + } + } + subRoots[cpu] = bucketVT.root + wg.Done() + }(i) + } + wg.Wait() + + newRootNode, err := upFromNodes(subRoots) + if err != nil { + return err + } + t.root = newRootNode + + return nil +} + +func (t *vt) getNodesAtLevel(l int) ([]*node, error) { + if t.root == nil { + return nil, nil + } + return t.root.getNodesAtLevel(0, l) +} + +func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) { + var nodes []*node + + typ := n.typ() + if currLvl == l && typ != vtEmpty { + nodes = append(nodes, n) + return nodes, nil + } + if currLvl >= l { + panic("should not reach this point") // TODO TMP + // return nil, nil + } + + if n.l != nil { + nodesL, err := n.l.getNodesAtLevel(currLvl+1, l) + if err != nil { + return nil, err + } + nodes = append(nodes, nodesL...) + } + if n.r != nil { + nodesR, err := n.r.getNodesAtLevel(currLvl+1, l) + if err != nil { + return nil, err + } + nodes = append(nodes, nodesR...) + } + return nodes, nil +} + +func upFromNodes(ns []*node) (*node, error) { + if len(ns) == 1 { + return ns[0], nil + } + + var res []*node + for i := 0; i < len(ns); i += 2 { + if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty { + // when both sub nodes are empty, the node is also empty + res = append(res, ns[i]) // empty node + } + n := &node{ + l: ns[i], + r: ns[i+1], + } + res = append(res, n) + } + return upFromNodes(res) +} func (t *vt) add(fromLvl int, k, v []byte) error { leaf := newLeafNode(t.params, k, v) @@ -75,6 +226,7 @@ func (t *vt) add(fromLvl int, k, v []byte) error { func (t *vt) computeHashes() ([][2][]byte, error) { var pairs [][2][]byte var err error + // TODO parallelize computeHashes pairs, err = t.root.computeHashes(t.params, pairs) if err != nil { return pairs, err @@ -103,6 +255,9 @@ const ( ) func (n *node) typ() virtualNodeType { + if n == nil { + return vtEmpty // TODO decide if return 'vtEmpty' or an error + } if n.l == nil && n.r == nil && n.k != nil { return vtLeaf } diff --git a/vt_test.go b/vt_test.go index 0ac8055..bbd0f35 100644 --- a/vt_test.go +++ b/vt_test.go @@ -90,3 +90,38 @@ func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { c.Assert(err, qt.IsNil) c.Assert(vTree.root.h, qt.DeepEquals, tree.root) } + +func TestVirtualTreeAddBatch(t *testing.T) { + c := qt.New(t) + + nLeafs := 2000 + maxLevels := 100 + + keys := make([][]byte, nLeafs) + values := make([][]byte, nLeafs) + for i := 0; i < nLeafs; i++ { + keys[i] = randomBytes(32) + values[i] = randomBytes(32) + } + + // normal tree, to have an expected root value + tree, err := NewTree(memory.NewMemoryStorage(), maxLevels, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + for i := 0; i < len(keys); i++ { + err := tree.Add(keys[i], values[i]) + c.Assert(err, qt.IsNil) + } + + // virtual tree + vTree := newVT(maxLevels, HashFunctionBlake2b) + + c.Assert(vTree.root, qt.IsNil) + + err = vTree.addBatch(keys, values) + c.Assert(err, qt.IsNil) + + // compute hashes, and check Root + _, err = vTree.computeHashes() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, tree.root) +}