Browse Source

Implement VirtualTree.addBatch with cpu parallelization

master
arnaucube 3 years ago
parent
commit
d09bd605bb
2 changed files with 198 additions and 8 deletions
  1. +163
    -8
      vt.go
  2. +35
    -0
      vt_test.go

+ 163
- 8
vt.go

@ -9,6 +9,9 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"math"
"runtime"
"sync"
) )
type node struct { type node struct {
@ -27,6 +30,24 @@ type params struct {
dbg *dbgStats dbg *dbgStats
} }
func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
if len(ks) != len(vs) {
return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
len(ks), len(vs))
}
kvs := make([]kv, len(ks))
for i := 0; i < len(ks); i++ {
keyPath := make([]byte, p.hashFunction.Len())
copy(keyPath[:], ks[i])
kvs[i].pos = i
kvs[i].keyPath = keyPath
kvs[i].k = ks[i]
kvs[i].v = vs[i]
}
return kvs, nil
}
// vt stands for virtual tree. It's a tree that does not have any computed hash // vt stands for virtual tree. It's a tree that does not have any computed hash
// while placing the leafs. Once all the leafs are placed, it computes all the // while placing the leafs. Once all the leafs are placed, it computes all the
// hashes. In this way, each node hash is only computed one time (at the end) // hashes. In this way, each node hash is only computed one time (at the end)
@ -47,14 +68,144 @@ func newVT(maxLevels int, hash HashFunction) vt {
} }
} }
// WIP
// func (t *vt) addBatch(fromLvl int, k, v []byte) error {
// // parallelize adding leafs in the virtual tree
// nCPU := flp2(runtime.NumCPU())
// l := int(math.Log2(float64(nCPU)))
//
// return nil
// }
func (t *vt) addBatch(ks, vs [][]byte) error {
// parallelize adding leafs in the virtual tree
nCPU := flp2(runtime.NumCPU())
if nCPU == 1 || len(ks) < nCPU {
// var invalids []int
for i := 0; i < len(ks); i++ {
if err := t.add(0, ks[i], vs[i]); err != nil {
// invalids = append(invalids, i)
fmt.Println(err) // TODO WIP
}
}
return nil // TODO invalids
}
l := int(math.Log2(float64(nCPU)))
kvs, err := t.params.keysValuesToKvs(ks, vs)
if err != nil {
return err
}
buckets := splitInBuckets(kvs, nCPU)
nodesAtL, err := t.getNodesAtLevel(l)
if err != nil {
return err
}
// fmt.Println("nodesatL pre-E", len(nodesAtL))
if len(nodesAtL) != nCPU {
// CASE E: add one key at each bucket, and then do CASE D
for i := 0; i < len(buckets); i++ {
// add one leaf of the bucket, if there is an error when
// adding the k-v, try to add the next one of the bucket
// (until one is added)
var inserted int
for j := 0; j < len(buckets[i]); j++ {
if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
inserted = j
break
}
}
// remove the inserted element from buckets[i]
buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
}
nodesAtL, err = t.getNodesAtLevel(l)
if err != nil {
return err
}
}
subRoots := make([]*node, nCPU)
invalidsInBucket := make([][]int, nCPU)
var wg sync.WaitGroup
wg.Add(nCPU)
for i := 0; i < nCPU; i++ {
go func(cpu int) {
sortKvs(buckets[cpu])
bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction)
bucketVT.root = nodesAtL[cpu]
for j := 0; j < len(buckets[cpu]); j++ {
if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
}
}
subRoots[cpu] = bucketVT.root
wg.Done()
}(i)
}
wg.Wait()
newRootNode, err := upFromNodes(subRoots)
if err != nil {
return err
}
t.root = newRootNode
return nil
}
func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
if t.root == nil {
return nil, nil
}
return t.root.getNodesAtLevel(0, l)
}
func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
var nodes []*node
typ := n.typ()
if currLvl == l && typ != vtEmpty {
nodes = append(nodes, n)
return nodes, nil
}
if currLvl >= l {
panic("should not reach this point") // TODO TMP
// return nil, nil
}
if n.l != nil {
nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
if err != nil {
return nil, err
}
nodes = append(nodes, nodesL...)
}
if n.r != nil {
nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
if err != nil {
return nil, err
}
nodes = append(nodes, nodesR...)
}
return nodes, nil
}
func upFromNodes(ns []*node) (*node, error) {
if len(ns) == 1 {
return ns[0], nil
}
var res []*node
for i := 0; i < len(ns); i += 2 {
if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty {
// when both sub nodes are empty, the node is also empty
res = append(res, ns[i]) // empty node
}
n := &node{
l: ns[i],
r: ns[i+1],
}
res = append(res, n)
}
return upFromNodes(res)
}
func (t *vt) add(fromLvl int, k, v []byte) error { func (t *vt) add(fromLvl int, k, v []byte) error {
leaf := newLeafNode(t.params, k, v) leaf := newLeafNode(t.params, k, v)
@ -75,6 +226,7 @@ func (t *vt) add(fromLvl int, k, v []byte) error {
func (t *vt) computeHashes() ([][2][]byte, error) { func (t *vt) computeHashes() ([][2][]byte, error) {
var pairs [][2][]byte var pairs [][2][]byte
var err error var err error
// TODO parallelize computeHashes
pairs, err = t.root.computeHashes(t.params, pairs) pairs, err = t.root.computeHashes(t.params, pairs)
if err != nil { if err != nil {
return pairs, err return pairs, err
@ -103,6 +255,9 @@ const (
) )
func (n *node) typ() virtualNodeType { func (n *node) typ() virtualNodeType {
if n == nil {
return vtEmpty // TODO decide if return 'vtEmpty' or an error
}
if n.l == nil && n.r == nil && n.k != nil { if n.l == nil && n.r == nil && n.k != nil {
return vtLeaf return vtLeaf
} }

+ 35
- 0
vt_test.go

@ -90,3 +90,38 @@ func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Assert(vTree.root.h, qt.DeepEquals, tree.root) c.Assert(vTree.root.h, qt.DeepEquals, tree.root)
} }
func TestVirtualTreeAddBatch(t *testing.T) {
c := qt.New(t)
nLeafs := 2000
maxLevels := 100
keys := make([][]byte, nLeafs)
values := make([][]byte, nLeafs)
for i := 0; i < nLeafs; i++ {
keys[i] = randomBytes(32)
values[i] = randomBytes(32)
}
// normal tree, to have an expected root value
tree, err := NewTree(memory.NewMemoryStorage(), maxLevels, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
for i := 0; i < len(keys); i++ {
err := tree.Add(keys[i], values[i])
c.Assert(err, qt.IsNil)
}
// virtual tree
vTree := newVT(maxLevels, HashFunctionBlake2b)
c.Assert(vTree.root, qt.IsNil)
err = vTree.addBatch(keys, values)
c.Assert(err, qt.IsNil)
// compute hashes, and check Root
_, err = vTree.computeHashes()
c.Assert(err, qt.IsNil)
c.Assert(vTree.root.h, qt.DeepEquals, tree.root)
}

Loading…
Cancel
Save