mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-18 02:51:29 +01:00
AddBatch in CaseD, is parallelized (for each CPU) until almost the top level, almost dividing the needed time by the number of CPUs.
609 lines
15 KiB
Go
609 lines
15 KiB
Go
package arbo
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"math"
|
|
"runtime"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/iden3/go-merkletree/db"
|
|
)
|
|
|
|
/*
|
|
|
|
AddBatch design
|
|
===============
|
|
|
|
|
|
CASE A: Empty Tree --> if tree is empty (root==0)
|
|
=================================================
|
|
- Build the full tree from bottom to top (from all the leaf to the root)
|
|
|
|
|
|
CASE B: ALMOST CASE A, Almost empty Tree --> if Tree has numLeafs < minLeafsThreshold
|
|
==============================================================================
|
|
- Get the Leafs (key & value) (iterate the tree from the current root getting
|
|
the leafs)
|
|
- Create a new empty Tree
|
|
- Do CASE A for the new Tree, giving the already existing key&values (leafs)
|
|
from the original Tree + the new key&values to be added from the AddBatch call
|
|
|
|
|
|
R R
|
|
/ \ / \
|
|
A * / \
|
|
/ \ / \
|
|
B C * *
|
|
/ | / \
|
|
/ | / \
|
|
/ | / \
|
|
L: A B G D
|
|
/ \
|
|
/ \
|
|
/ \
|
|
C *
|
|
/ \
|
|
/ \
|
|
/ \
|
|
... ... (nLeafs < minLeafsThreshold)
|
|
|
|
|
|
CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold)
|
|
==============================================================================
|
|
- Use A, B, G, D as Roots of subtrees
|
|
- Do CASE B for each subtree
|
|
- Then go from L to the Root
|
|
|
|
R
|
|
/ \
|
|
/ \
|
|
/ \
|
|
* *
|
|
/ | / \
|
|
/ | / \
|
|
/ | / \
|
|
L: A B G D
|
|
/ \
|
|
/ \
|
|
/ \
|
|
C *
|
|
/ \
|
|
/ \
|
|
/ \
|
|
... ... (nLeafs >= minLeafsThreshold)
|
|
|
|
|
|
|
|
CASE D: Already populated Tree
|
|
==============================
|
|
- Use A, B, C, D as subtree
|
|
- Sort the Keys in Buckets that share the initial part of the path
|
|
- For each subtree add there the new leafs
|
|
|
|
R
|
|
/ \
|
|
/ \
|
|
/ \
|
|
* *
|
|
/ | / \
|
|
/ | / \
|
|
/ | / \
|
|
L: A B C D
|
|
/\ /\ / \ / \
|
|
... ... ... ... ... ...
|
|
|
|
|
|
CASE E: Already populated Tree Unbalanced
|
|
=========================================
|
|
- Need to fill M1 and M2, and then will be able to use CASE D
|
|
- Search for M1 & M2 in the inputted Keys
|
|
- Add M1 & M2 to the Tree
|
|
- From here can use CASE D
|
|
|
|
R
|
|
/ \
|
|
/ \
|
|
/ \
|
|
* *
|
|
| \
|
|
| \
|
|
| \
|
|
L: M1 * M2 * (where M1 and M2 are empty)
|
|
/ | /
|
|
/ | /
|
|
/ | /
|
|
A * *
|
|
/ \ | \
|
|
/ \ | \
|
|
/ \ | \
|
|
B * * C
|
|
/ \ |\
|
|
... ... | \
|
|
| \
|
|
D E
|
|
|
|
|
|
|
|
Algorithm decision
|
|
==================
|
|
- if nLeafs==0 (root==0): CASE A
|
|
- if nLeafs<minLeafsThreshold: CASE B
|
|
- if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold: CASE C
|
|
- else: CASE D & CASE E
|
|
|
|
|
|
- Multiple tree.Add calls: O(n log n)
|
|
- Used in: cases D, E
|
|
- Tree from bottom to top: O(log n)
|
|
- Used in: cases A, B, C
|
|
|
|
*/
|
|
|
|
// minLeafsThreshold is the number of leafs under which AddBatchOpt takes
// CASE B (collect the existing leafs and rebuild bottom-up) instead of adding
// into the populated tree.
const minLeafsThreshold = 100 // nolint:gomnd // TMP WIP this will be autocalculated
|
|
|
|
// AddBatchOpt is the WIP implementation of the AddBatch method in a more
|
|
// optimized approach.
|
|
func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
|
|
t.updateAccessTime()
|
|
t.Lock()
|
|
defer t.Unlock()
|
|
|
|
// TODO if len(keys) is not a power of 2, add padding of empty
|
|
// keys&values. Maybe when len(keyvalues) is not a power of 2, cut at
|
|
// the biggest power of 2 under the len(keys), add those 2**n key-values
|
|
// using the AddBatch approach, and then add the remaining key-values
|
|
// using tree.Add.
|
|
|
|
kvs, err := t.keysValuesToKvs(keys, values)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
t.tx, err = t.db.NewTx() // TODO add t.tx.Commit()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// TODO if nCPU is not a power of two, cut at the highest power of two
|
|
// under nCPU
|
|
nCPU := runtime.NumCPU()
|
|
l := int(math.Log2(float64(nCPU)))
|
|
|
|
// CASE A: if nLeafs==0 (root==0)
|
|
if bytes.Equal(t.root, t.emptyHash) {
|
|
// if len(kvs) is not a power of 2, cut at the bigger power
|
|
// of two under len(kvs), build the tree with that, and add
|
|
// later the excedents
|
|
kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
|
|
invalids, err := t.buildTreeBottomUp(nCPU, kvsP2)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := 0; i < len(kvsNonP2); i++ {
|
|
err = t.add(0, kvsNonP2[i].k, kvsNonP2[i].v)
|
|
if err != nil {
|
|
invalids = append(invalids, kvsNonP2[i].pos)
|
|
}
|
|
}
|
|
return invalids, nil
|
|
}
|
|
|
|
// CASE B: if nLeafs<nBuckets
|
|
nLeafs, err := t.GetNLeafs()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if nLeafs < minLeafsThreshold { // CASE B
|
|
invalids, excedents, err := t.caseB(0, kvs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// add the excedents
|
|
for i := 0; i < len(excedents); i++ {
|
|
err = t.add(0, excedents[i].k, excedents[i].v)
|
|
if err != nil {
|
|
invalids = append(invalids, excedents[i].pos)
|
|
}
|
|
}
|
|
return invalids, nil
|
|
}
|
|
|
|
// CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold
|
|
// available parallelization, will need to be a power of 2 (2**n)
|
|
var excedents []kv
|
|
if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold {
|
|
// TODO move to own function
|
|
// 1. go down until level L (L=log2(nBuckets))
|
|
keysAtL, err := t.getKeysAtLevel(l + 1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
buckets := splitInBuckets(kvs, nCPU)
|
|
|
|
// 2. use keys at level L as roots of the subtrees under each one
|
|
var subRoots [][]byte
|
|
// TODO parallelize
|
|
for i := 0; i < len(keysAtL); i++ {
|
|
bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels,
|
|
hashFunction: t.hashFunction, root: keysAtL[i]}
|
|
|
|
// 3. and do CASE B for each
|
|
_, bucketExcedents, err := bucketTree.caseB(l, buckets[i])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
excedents = append(excedents, bucketExcedents...)
|
|
subRoots = append(subRoots, bucketTree.root)
|
|
}
|
|
// 4. go upFromKeys from the new roots of the subtrees
|
|
newRoot, err := t.upFromKeys(subRoots)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.root = newRoot
|
|
|
|
var invalids []int
|
|
for i := 0; i < len(excedents); i++ {
|
|
// Add until the level L
|
|
err = t.add(0, excedents[i].k, excedents[i].v)
|
|
if err != nil {
|
|
invalids = append(invalids, excedents[i].pos) // TODO WIP
|
|
}
|
|
}
|
|
|
|
return invalids, nil
|
|
}
|
|
|
|
// CASE D
|
|
if true { // TODO enter in CASE D if len(keysAtL)=nCPU, if not, CASE E
|
|
return t.caseD(nCPU, l, kvs)
|
|
}
|
|
|
|
// TODO store t.root into DB
|
|
// TODO update NLeafs from DB
|
|
|
|
return nil, fmt.Errorf("UNIMPLEMENTED")
|
|
}
|
|
|
|
func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
|
|
// get already existing keys
|
|
aKs, aVs, err := t.getLeafs(t.root)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
aKvs, err := t.keysValuesToKvs(aKs, aVs)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
// add already existing key-values to the inputted key-values
|
|
kvs = append(kvs, aKvs...)
|
|
|
|
// proceed with CASE A
|
|
sortKvs(kvs)
|
|
|
|
// cutPowerOfTwo, the excedent add it as normal Tree.Add
|
|
kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
|
|
invalids, err := t.buildTreeBottomUpSingleThread(kvsP2)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
// return the excedents which will be added at the full tree at the end
|
|
return invalids, kvsNonP2, nil
|
|
}
|
|
|
|
func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) {
|
|
keysAtL, err := t.getKeysAtLevel(l + 1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
buckets := splitInBuckets(kvs, nCPU)
|
|
|
|
subRoots := make([][]byte, nCPU)
|
|
invalidsInBucket := make([][]int, nCPU)
|
|
txs := make([]db.Tx, nCPU)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(nCPU)
|
|
for i := 0; i < nCPU; i++ {
|
|
go func(cpu int) {
|
|
var err error
|
|
txs[cpu], err = t.db.NewTx()
|
|
if err != nil {
|
|
panic(err) // TODO
|
|
}
|
|
bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, // maxLevels-l
|
|
hashFunction: t.hashFunction, root: keysAtL[cpu]}
|
|
|
|
for j := 0; j < len(buckets[cpu]); j++ {
|
|
if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
|
|
fmt.Println("failed", buckets[cpu][j].k[:4])
|
|
invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
|
|
}
|
|
}
|
|
subRoots[cpu] = bucketTree.root
|
|
wg.Done()
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
newRoot, err := t.upFromKeys(subRoots)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.root = newRoot
|
|
|
|
var invalids []int
|
|
for i := 0; i < len(invalidsInBucket); i++ {
|
|
invalids = append(invalids, invalidsInBucket[i]...)
|
|
}
|
|
|
|
return invalids, nil
|
|
}
|
|
|
|
func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
|
|
buckets := make([][]kv, nBuckets)
|
|
// 1. classify the keyvalues into buckets
|
|
for i := 0; i < len(kvs); i++ {
|
|
pair := kvs[i]
|
|
|
|
// bucketnum := keyToBucket(pair.k, nBuckets)
|
|
bucketnum := keyToBucket(pair.keyPath, nBuckets)
|
|
buckets[bucketnum] = append(buckets[bucketnum], pair)
|
|
}
|
|
return buckets
|
|
}
|
|
|
|
// keyToBucket returns the index (in [0, nBuckets)) of the bucket that the
// given key belongs to. It walks the first floor(log2(nBuckets)) bits of k
// (bytes from left to right, bits inside each byte from the least significant
// one), halving the range of candidate buckets at each bit: a set bit keeps
// the upper half, an unset bit keeps the lower half.
//
// k must hold at least floor(log2(nBuckets)) bits. For nBuckets <= 1 it
// returns 0.
//
// TODO rename in a more 'real' name (calculate bucket from/for key)
func keyToBucket(k []byte, nBuckets int) int {
	if nBuckets <= 1 {
		return 0
	}
	nLevels := int(math.Log2(float64(nBuckets)))

	// the candidates are the buckets [first, first+size); tracking only the
	// bounds avoids allocating a slice of nBuckets ints for every key
	first, size := 0, nBuckets
	for i := 0; i < nLevels; i++ {
		half := size / 2 //nolint:gomnd
		if k[i/8]&(1<<(i%8)) != 0 {
			first += half
			size -= half
		} else {
			size = half
		}
	}
	return first
}
|
|
|
|
// kv is a key-value pair to be added to the tree, together with the data
// needed to classify and sort it.
type kv struct {
	// pos is the original position in the array
	pos int
	// keyPath is what splitInBuckets and sortKvs look at to place the
	// key-value in a bucket and to order it
	keyPath []byte
	// k is the key
	k []byte
	// v is the value
	v []byte
}
|
|
|
|
// compareBytes compares byte slices where the bytes are compared from left to
// right and each byte is compared by bit from right to left (from the least
// significant bit). It reports whether a goes before b: at the first bit where
// they differ, a goes first if its bit is 0.
//
// Only the common length of a and b is compared bit by bit, so slices of
// different length do not cause an out of range access; if one is a prefix of
// the other, the shorter one goes first. Equal slices return false.
func compareBytes(a, b []byte) bool {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	for i := 0; i < n; i++ {
		diff := a[i] ^ b[i]
		if diff == 0 {
			continue
		}
		// isolate the lowest bit in which a[i] and b[i] differ
		lowest := diff & -diff
		return a[i]&lowest == 0
	}
	return len(a) < len(b)
}
|
|
|
|
// sortKvs sorts the kv by path
|
|
func sortKvs(kvs []kv) {
|
|
sort.Slice(kvs, func(i, j int) bool {
|
|
return compareBytes(kvs[i].keyPath, kvs[j].keyPath)
|
|
})
|
|
}
|
|
|
|
func (t *Tree) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
|
|
if len(ks) != len(vs) {
|
|
return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
|
|
len(ks), len(vs))
|
|
}
|
|
kvs := make([]kv, len(ks))
|
|
for i := 0; i < len(ks); i++ {
|
|
keyPath := make([]byte, t.hashFunction.Len())
|
|
copy(keyPath[:], ks[i])
|
|
kvs[i].pos = i
|
|
kvs[i].keyPath = ks[i]
|
|
kvs[i].k = ks[i]
|
|
kvs[i].v = vs[i]
|
|
}
|
|
|
|
return kvs, nil
|
|
}
|
|
|
|
/*
|
|
func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) {
|
|
ks := make([][]byte, len(kvs))
|
|
vs := make([][]byte, len(kvs))
|
|
for i := 0; i < len(kvs); i++ {
|
|
ks[i] = kvs[i].k
|
|
vs[i] = kvs[i].v
|
|
}
|
|
return ks, vs
|
|
}
|
|
*/
|
|
|
|
// buildTreeBottomUp splits the key-values into n Buckets (where n is the number
|
|
// of CPUs), in parallel builds a subtree for each bucket, once all the subtrees
|
|
// are built, uses the subtrees roots as keys for a new tree, which as result
|
|
// will have the complete Tree build from bottom to up, where until the
|
|
// log2(nCPU) level it has been computed in parallel.
|
|
func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
|
|
buckets := splitInBuckets(kvs, nCPU)
|
|
subRoots := make([][]byte, nCPU)
|
|
invalidsInBucket := make([][]int, nCPU)
|
|
txs := make([]db.Tx, nCPU)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(nCPU)
|
|
for i := 0; i < nCPU; i++ {
|
|
go func(cpu int) {
|
|
sortKvs(buckets[cpu])
|
|
|
|
var err error
|
|
txs[cpu], err = t.db.NewTx()
|
|
if err != nil {
|
|
panic(err) // TODO
|
|
}
|
|
bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels,
|
|
hashFunction: t.hashFunction, root: t.emptyHash}
|
|
|
|
currInvalids, err := bucketTree.buildTreeBottomUpSingleThread(buckets[cpu])
|
|
if err != nil {
|
|
panic(err) // TODO
|
|
}
|
|
invalidsInBucket[cpu] = currInvalids
|
|
subRoots[cpu] = bucketTree.root
|
|
wg.Done()
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
newRoot, err := t.upFromKeys(subRoots)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.root = newRoot
|
|
|
|
var invalids []int
|
|
for i := 0; i < len(invalidsInBucket); i++ {
|
|
invalids = append(invalids, invalidsInBucket[i]...)
|
|
}
|
|
return invalids, err
|
|
}
|
|
|
|
// buildTreeBottomUpSingleThread builds the tree with the given []kv from bottom
|
|
// to the root. keys & values must be sorted by path, and the array ks must be
|
|
// length multiple of 2
|
|
func (t *Tree) buildTreeBottomUpSingleThread(kvs []kv) ([]int, error) {
|
|
// TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels
|
|
// would be reached and should return error
|
|
|
|
var invalids []int
|
|
// build the leafs
|
|
leafKeys := make([][]byte, len(kvs))
|
|
for i := 0; i < len(kvs); i++ {
|
|
// TODO handle the case where Key&Value == 0
|
|
leafKey, leafValue, err := newLeafValue(t.hashFunction, kvs[i].k, kvs[i].v)
|
|
if err != nil {
|
|
// return nil, err
|
|
invalids = append(invalids, kvs[i].pos)
|
|
}
|
|
// store leafKey & leafValue to db
|
|
if err := t.tx.Put(leafKey, leafValue); err != nil {
|
|
// return nil, err
|
|
invalids = append(invalids, kvs[i].pos)
|
|
}
|
|
leafKeys[i] = leafKey
|
|
}
|
|
r, err := t.upFromKeys(leafKeys)
|
|
if err != nil {
|
|
return invalids, err
|
|
}
|
|
t.root = r
|
|
return invalids, nil
|
|
}
|
|
|
|
// keys & values must be sorted by path, and the array ks must be length
|
|
// multiple of 2
|
|
func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) {
|
|
if len(ks) == 1 {
|
|
return ks[0], nil
|
|
}
|
|
|
|
var rKs [][]byte
|
|
for i := 0; i < len(ks); i += 2 {
|
|
// TODO handle the case where Key&Value == 0
|
|
k, v, err := newIntermediate(t.hashFunction, ks[i], ks[i+1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// store k-v to db
|
|
if err = t.tx.Put(k, v); err != nil {
|
|
return nil, err
|
|
}
|
|
rKs = append(rKs, k)
|
|
}
|
|
return t.upFromKeys(rKs)
|
|
}
|
|
|
|
func (t *Tree) getLeafs(root []byte) ([][]byte, [][]byte, error) {
|
|
var ks, vs [][]byte
|
|
err := t.iter(root, func(k, v []byte) {
|
|
if v[0] != PrefixValueLeaf {
|
|
return
|
|
}
|
|
leafK, leafV := readLeafValue(v)
|
|
ks = append(ks, leafK)
|
|
vs = append(vs, leafV)
|
|
})
|
|
return ks, vs, err
|
|
}
|
|
|
|
func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) {
|
|
var keys [][]byte
|
|
err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
|
|
if currLvl == l {
|
|
keys = append(keys, k)
|
|
}
|
|
if currLvl >= l {
|
|
return true // to stop the iter from going down
|
|
}
|
|
return false
|
|
})
|
|
|
|
return keys, err
|
|
}
|
|
|
|
// cutPowerOfTwo returns []kv of length that is a power of 2, and a second []kv
|
|
// with the extra elements that don't fit in a power of 2 length
|
|
func cutPowerOfTwo(kvs []kv) ([]kv, []kv) {
|
|
x := len(kvs)
|
|
if (x & (x - 1)) != 0 {
|
|
p2 := highestPowerOfTwo(x)
|
|
return kvs[:p2], kvs[p2:]
|
|
}
|
|
return kvs, nil
|
|
}
|
|
|
|
// highestPowerOfTwo returns the highest power of two that is not greater than
// n, or 0 if n < 1.
func highestPowerOfTwo(n int) int {
	if n < 1 {
		return 0
	}
	res := 1
	// compare against n/2 instead of doubling first, so that res can not
	// overflow for n close to the maximum int
	for res <= n>>1 {
		res <<= 1
	}
	return res
}
|
|
|
|
// func computeSimpleAddCost(nLeafs int) int {
|
|
// // nLvls 2^nLvls
|
|
// nLvls := int(math.Log2(float64(nLeafs)))
|
|
// return nLvls * int(math.Pow(2, float64(nLvls)))
|
|
// }
|
|
//
|
|
// func computeBottomUpAddCost(nLeafs int) int {
|
|
// // 2^nLvls * 2 - 1
|
|
// nLvls := int(math.Log2(float64(nLeafs)))
|
|
// return (int(math.Pow(2, float64(nLvls))) * 2) - 1
|
|
// }
|