mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-15 01:41:28 +01:00
Implement VirtualTree.addBatch with cpu parallelization
This commit is contained in:
171
vt.go
171
vt.go
@@ -9,6 +9,9 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type node struct {
|
type node struct {
|
||||||
@@ -27,6 +30,24 @@ type params struct {
|
|||||||
dbg *dbgStats
|
dbg *dbgStats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
|
||||||
|
if len(ks) != len(vs) {
|
||||||
|
return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
|
||||||
|
len(ks), len(vs))
|
||||||
|
}
|
||||||
|
kvs := make([]kv, len(ks))
|
||||||
|
for i := 0; i < len(ks); i++ {
|
||||||
|
keyPath := make([]byte, p.hashFunction.Len())
|
||||||
|
copy(keyPath[:], ks[i])
|
||||||
|
kvs[i].pos = i
|
||||||
|
kvs[i].keyPath = keyPath
|
||||||
|
kvs[i].k = ks[i]
|
||||||
|
kvs[i].v = vs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return kvs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// vt stands for virtual tree. It's a tree that does not have any computed hash
|
// vt stands for virtual tree. It's a tree that does not have any computed hash
|
||||||
// while placing the leafs. Once all the leafs are placed, it computes all the
|
// while placing the leafs. Once all the leafs are placed, it computes all the
|
||||||
// hashes. In this way, each node hash is only computed one time (at the end)
|
// hashes. In this way, each node hash is only computed one time (at the end)
|
||||||
@@ -47,14 +68,144 @@ func newVT(maxLevels int, hash HashFunction) vt {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WIP
|
func (t *vt) addBatch(ks, vs [][]byte) error {
|
||||||
// func (t *vt) addBatch(fromLvl int, k, v []byte) error {
|
// parallelize adding leafs in the virtual tree
|
||||||
// // parallelize adding leafs in the virtual tree
|
nCPU := flp2(runtime.NumCPU())
|
||||||
// nCPU := flp2(runtime.NumCPU())
|
if nCPU == 1 || len(ks) < nCPU {
|
||||||
// l := int(math.Log2(float64(nCPU)))
|
// var invalids []int
|
||||||
//
|
for i := 0; i < len(ks); i++ {
|
||||||
// return nil
|
if err := t.add(0, ks[i], vs[i]); err != nil {
|
||||||
// }
|
// invalids = append(invalids, i)
|
||||||
|
fmt.Println(err) // TODO WIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil // TODO invalids
|
||||||
|
}
|
||||||
|
|
||||||
|
l := int(math.Log2(float64(nCPU)))
|
||||||
|
|
||||||
|
kvs, err := t.params.keysValuesToKvs(ks, vs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buckets := splitInBuckets(kvs, nCPU)
|
||||||
|
|
||||||
|
nodesAtL, err := t.getNodesAtLevel(l)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// fmt.Println("nodesatL pre-E", len(nodesAtL))
|
||||||
|
if len(nodesAtL) != nCPU {
|
||||||
|
// CASE E: add one key at each bucket, and then do CASE D
|
||||||
|
for i := 0; i < len(buckets); i++ {
|
||||||
|
// add one leaf of the bucket, if there is an error when
|
||||||
|
// adding the k-v, try to add the next one of the bucket
|
||||||
|
// (until one is added)
|
||||||
|
var inserted int
|
||||||
|
for j := 0; j < len(buckets[i]); j++ {
|
||||||
|
if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
|
||||||
|
inserted = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove the inserted element from buckets[i]
|
||||||
|
buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
|
||||||
|
}
|
||||||
|
nodesAtL, err = t.getNodesAtLevel(l)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
subRoots := make([]*node, nCPU)
|
||||||
|
invalidsInBucket := make([][]int, nCPU)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(nCPU)
|
||||||
|
for i := 0; i < nCPU; i++ {
|
||||||
|
go func(cpu int) {
|
||||||
|
sortKvs(buckets[cpu])
|
||||||
|
|
||||||
|
bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction)
|
||||||
|
bucketVT.root = nodesAtL[cpu]
|
||||||
|
for j := 0; j < len(buckets[cpu]); j++ {
|
||||||
|
if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
|
||||||
|
invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
subRoots[cpu] = bucketVT.root
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
newRootNode, err := upFromNodes(subRoots)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.root = newRootNode
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
|
||||||
|
if t.root == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return t.root.getNodesAtLevel(0, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
|
||||||
|
var nodes []*node
|
||||||
|
|
||||||
|
typ := n.typ()
|
||||||
|
if currLvl == l && typ != vtEmpty {
|
||||||
|
nodes = append(nodes, n)
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
if currLvl >= l {
|
||||||
|
panic("should not reach this point") // TODO TMP
|
||||||
|
// return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.l != nil {
|
||||||
|
nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, nodesL...)
|
||||||
|
}
|
||||||
|
if n.r != nil {
|
||||||
|
nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, nodesR...)
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func upFromNodes(ns []*node) (*node, error) {
|
||||||
|
if len(ns) == 1 {
|
||||||
|
return ns[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var res []*node
|
||||||
|
for i := 0; i < len(ns); i += 2 {
|
||||||
|
if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty {
|
||||||
|
// when both sub nodes are empty, the node is also empty
|
||||||
|
res = append(res, ns[i]) // empty node
|
||||||
|
}
|
||||||
|
n := &node{
|
||||||
|
l: ns[i],
|
||||||
|
r: ns[i+1],
|
||||||
|
}
|
||||||
|
res = append(res, n)
|
||||||
|
}
|
||||||
|
return upFromNodes(res)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *vt) add(fromLvl int, k, v []byte) error {
|
func (t *vt) add(fromLvl int, k, v []byte) error {
|
||||||
leaf := newLeafNode(t.params, k, v)
|
leaf := newLeafNode(t.params, k, v)
|
||||||
@@ -75,6 +226,7 @@ func (t *vt) add(fromLvl int, k, v []byte) error {
|
|||||||
func (t *vt) computeHashes() ([][2][]byte, error) {
|
func (t *vt) computeHashes() ([][2][]byte, error) {
|
||||||
var pairs [][2][]byte
|
var pairs [][2][]byte
|
||||||
var err error
|
var err error
|
||||||
|
// TODO parallelize computeHashes
|
||||||
pairs, err = t.root.computeHashes(t.params, pairs)
|
pairs, err = t.root.computeHashes(t.params, pairs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pairs, err
|
return pairs, err
|
||||||
@@ -103,6 +255,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (n *node) typ() virtualNodeType {
|
func (n *node) typ() virtualNodeType {
|
||||||
|
if n == nil {
|
||||||
|
return vtEmpty // TODO decide if return 'vtEmpty' or an error
|
||||||
|
}
|
||||||
if n.l == nil && n.r == nil && n.k != nil {
|
if n.l == nil && n.r == nil && n.k != nil {
|
||||||
return vtLeaf
|
return vtLeaf
|
||||||
}
|
}
|
||||||
|
|||||||
35
vt_test.go
35
vt_test.go
@@ -90,3 +90,38 @@ func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) {
|
|||||||
c.Assert(err, qt.IsNil)
|
c.Assert(err, qt.IsNil)
|
||||||
c.Assert(vTree.root.h, qt.DeepEquals, tree.root)
|
c.Assert(vTree.root.h, qt.DeepEquals, tree.root)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVirtualTreeAddBatch(t *testing.T) {
|
||||||
|
c := qt.New(t)
|
||||||
|
|
||||||
|
nLeafs := 2000
|
||||||
|
maxLevels := 100
|
||||||
|
|
||||||
|
keys := make([][]byte, nLeafs)
|
||||||
|
values := make([][]byte, nLeafs)
|
||||||
|
for i := 0; i < nLeafs; i++ {
|
||||||
|
keys[i] = randomBytes(32)
|
||||||
|
values[i] = randomBytes(32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normal tree, to have an expected root value
|
||||||
|
tree, err := NewTree(memory.NewMemoryStorage(), maxLevels, HashFunctionBlake2b)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
for i := 0; i < len(keys); i++ {
|
||||||
|
err := tree.Add(keys[i], values[i])
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// virtual tree
|
||||||
|
vTree := newVT(maxLevels, HashFunctionBlake2b)
|
||||||
|
|
||||||
|
c.Assert(vTree.root, qt.IsNil)
|
||||||
|
|
||||||
|
err = vTree.addBatch(keys, values)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
|
||||||
|
// compute hashes, and check Root
|
||||||
|
_, err = vTree.computeHashes()
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
c.Assert(vTree.root.h, qt.DeepEquals, tree.root)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user