Browse Source

Add AddBatch CaseD

CASE D: Already populated Tree
==============================
- Use A, B, C, D as the roots of independent subtrees
- Sort the keys into buckets, one per subtree, grouping the keys that share the initial part of the path
- For each subtree, add the new leaves from its bucket

              R
             /  \
            /    \
           /      \
          *        *
         / |      / \
        /  |     /   \
       /   |    /     \
L:    A    B   C       D
     /\   /\  / \     / \
    ...  ... ... ... ... ...
master
arnaucube 3 years ago
parent
commit
890057cd82
4 changed files with 157 additions and 7 deletions
  1. +43
    -2
      addbatch.go
  2. +102
    -0
      addbatch_test.go
  3. +5
    -5
      tree.go
  4. +7
    -0
      tree_test.go

+ 43
- 2
addbatch.go

@ -168,7 +168,10 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
return nil, err return nil, err
} }
// TODO if nCPU is not a power of two, cut at the highest power of two
// under nCPU
nCPU := runtime.NumCPU() nCPU := runtime.NumCPU()
l := int(math.Log2(float64(nCPU)))
// CASE A: if nLeafs==0 (root==0) // CASE A: if nLeafs==0 (root==0)
if bytes.Equal(t.root, t.emptyHash) { if bytes.Equal(t.root, t.emptyHash) {
@ -212,7 +215,6 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
// CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold // CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold
// available parallelization, will need to be a power of 2 (2**n) // available parallelization, will need to be a power of 2 (2**n)
var excedents []kv var excedents []kv
l := int(math.Log2(float64(nCPU)))
if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold { if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold {
// TODO move to own function // TODO move to own function
// 1. go down until level L (L=log2(nBuckets)) // 1. go down until level L (L=log2(nBuckets))
@ -257,6 +259,11 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
return invalids, nil return invalids, nil
} }
// CASE D
if true { // TODO enter in CASE D if len(keysAtL)=nCPU, if not, CASE E
return t.caseD(nCPU, l, kvs)
}
// TODO store t.root into DB // TODO store t.root into DB
// TODO update NLeafs from DB // TODO update NLeafs from DB
@ -289,13 +296,47 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
return invalids, kvsNonP2, nil return invalids, kvsNonP2, nil
} }
// caseD adds kvs to an already populated tree by treating each key found at
// level l+1 as the root of an independent subtree.
//
// The key-values are split into nCPU buckets with splitInBuckets, and the
// i-th bucket is added into the subtree rooted at the i-th key returned by
// getKeysAtLevel(l + 1). Once every bucket has been processed, the resulting
// subtree roots are combined with upFromKeys and the result is stored in
// t.root.
//
// It returns the positions (kv.pos) of the key-values that could not be
// added. An error is returned when the level keys cannot be read, when the
// subtree roots cannot be combined, or when the subtree roots and the buckets
// do not line up (which would otherwise cause an out-of-range panic or
// silently drop key-values).
func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) {
	keysAtL, err := t.getKeysAtLevel(l + 1)
	if err != nil {
		return nil, err
	}
	buckets := splitInBuckets(kvs, nCPU)

	// Every subtree root needs its own bucket; indexing buckets[i] below
	// would panic otherwise.
	if len(keysAtL) > len(buckets) {
		return nil, fmt.Errorf("caseD: %d subtree roots at level %d but only %d buckets",
			len(keysAtL), l+1, len(buckets))
	}
	// A non-empty bucket without a matching subtree root would never be
	// added to the tree; report it instead of losing those key-values.
	for i := len(keysAtL); i < len(buckets); i++ {
		if len(buckets[i]) > 0 {
			return nil, fmt.Errorf("caseD: bucket %d holds %d key-values but only %d subtree roots exist at level %d",
				i, len(buckets[i]), len(keysAtL), l+1)
		}
	}

	var subRoots [][]byte
	var invalids []int
	for i, subRoot := range keysAtL {
		// The bucket tree shares the tx, db and hash function of t, and
		// starts from the subtree root found at level l+1.
		bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels, // maxLevels-l
			hashFunction: t.hashFunction, root: subRoot}
		for _, pair := range buckets[i] {
			if err := bucketTree.add(l, pair.k, pair.v); err != nil {
				// A key-value that cannot be added is reported to the
				// caller through its position instead of aborting the
				// whole batch.
				invalids = append(invalids, pair.pos)
			}
		}
		subRoots = append(subRoots, bucketTree.root)
	}

	newRoot, err := t.upFromKeys(subRoots)
	if err != nil {
		return nil, err
	}
	t.root = newRoot
	return invalids, nil
}
func splitInBuckets(kvs []kv, nBuckets int) [][]kv { func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
buckets := make([][]kv, nBuckets) buckets := make([][]kv, nBuckets)
// 1. classify the keyvalues into buckets // 1. classify the keyvalues into buckets
for i := 0; i < len(kvs); i++ { for i := 0; i < len(kvs); i++ {
pair := kvs[i] pair := kvs[i]
bucketnum := keyToBucket(pair.k, nBuckets)
// bucketnum := keyToBucket(pair.k, nBuckets)
bucketnum := keyToBucket(pair.keyPath, nBuckets)
buckets[bucketnum] = append(buckets[bucketnum], pair) buckets[bucketnum] = append(buckets[bucketnum], pair)
} }
return buckets return buckets

+ 102
- 0
addbatch_test.go

@ -11,6 +11,56 @@ import (
"github.com/iden3/go-merkletree/db/memory" "github.com/iden3/go-merkletree/db/memory"
) )
// TestBatchAux builds a reference tree by adding 16 leaves one by one, then
// builds a second tree where the first 8 leaves are added one by one and the
// remaining 8 through AddBatchOpt. It checks that AddBatchOpt reports no
// invalid indexes and that both trees end up with the same root.
func TestBatchAux(t *testing.T) {
	c := qt.New(t)

	const total, preloaded = 16, 8

	// leaf returns the key (i) and the value (2*i) used for the i-th leaf.
	leaf := func(i int) ([]byte, []byte) {
		return BigIntToBytes(big.NewInt(int64(i))), BigIntToBytes(big.NewInt(int64(i * 2)))
	}

	// Reference tree: every leaf is added individually.
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	began := time.Now()
	for i := 0; i < total; i++ {
		key, value := leaf(i)
		if addErr := tree.Add(key, value); addErr != nil {
			t.Fatal(addErr)
		}
	}
	fmt.Println(time.Since(began))

	// Tree under test: partially filled before the batch is added.
	tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree2.db.Close()

	for i := 0; i < preloaded; i++ {
		key, value := leaf(i)
		if addErr := tree2.Add(key, value); addErr != nil {
			t.Fatal(addErr)
		}
	}
	// tree.PrintGraphviz(nil)
	// tree2.PrintGraphviz(nil)

	var keys, values [][]byte
	for i := preloaded; i < total; i++ {
		key, value := leaf(i)
		keys = append(keys, key)
		values = append(values, value)
	}

	began = time.Now()
	indexes, err := tree2.AddBatchOpt(keys, values)
	c.Assert(err, qt.IsNil)
	fmt.Println(time.Since(began))
	c.Check(len(indexes), qt.Equals, 0)

	// check that both trees roots are equal
	c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
}
func TestAddBatchCaseA(t *testing.T) { func TestAddBatchCaseA(t *testing.T) {
c := qt.New(t) c := qt.New(t)
@ -289,6 +339,58 @@ func TestAddBatchCaseC(t *testing.T) {
// printLeafs("t2.txt", tree2) // printLeafs("t2.txt", tree2)
} }
// TestAddBatchCaseD builds a reference tree by adding 8192 leaves one by one,
// then builds a second tree that is pre-filled with the first 900 leaves and
// receives the rest through AddBatchOpt. It checks that AddBatchOpt reports
// no invalid indexes and that both trees end up with the same root.
func TestAddBatchCaseD(t *testing.T) {
	c := qt.New(t)

	const (
		total     = 8192
		prefilled = 900 // TMP TODO use const minLeafsThreshold-1 once ready
	)

	// Reference tree: every leaf is added individually.
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	singleStart := time.Now()
	for n := 0; n < total; n++ {
		key := BigIntToBytes(big.NewInt(int64(n)))
		value := BigIntToBytes(big.NewInt(int64(n * 2)))
		if addErr := tree.Add(key, value); addErr != nil {
			t.Fatal(addErr)
		}
	}
	fmt.Println(time.Since(singleStart))

	tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree2.db.Close()

	// add the initial leafs to fill a bit the tree before calling the
	// AddBatch method
	for n := 0; n < prefilled; n++ {
		key := BigIntToBytes(big.NewInt(int64(n)))
		value := BigIntToBytes(big.NewInt(int64(n * 2)))
		if addErr := tree2.Add(key, value); addErr != nil {
			t.Fatal(addErr)
		}
	}
	// tree2.PrintGraphvizFirstNLevels(nil, 4)
	// tree2.PrintGraphviz(nil)

	// Collect the remaining leaves for the batch insertion.
	var keys, values [][]byte
	for n := prefilled; n < total; n++ {
		keys = append(keys, BigIntToBytes(big.NewInt(int64(n))))
		values = append(values, BigIntToBytes(big.NewInt(int64(n*2))))
	}

	batchStart := time.Now()
	indexes, err := tree2.AddBatchOpt(keys, values)
	c.Assert(err, qt.IsNil)
	fmt.Println(time.Since(batchStart))
	c.Check(len(indexes), qt.Equals, 0)

	// check that both trees roots are equal
	c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
}
// func printLeafs(name string, t *Tree) { // func printLeafs(name string, t *Tree) {
// w := bytes.NewBufferString("") // w := bytes.NewBufferString("")
// //

+ 5
- 5
tree.go

@ -206,7 +206,7 @@ func (t *Tree) add(fromLvl int, k, v []byte) error {
t.root = leafKey t.root = leafKey
return nil return nil
} }
root, err := t.up(leafKey, siblings, path, len(siblings)-1)
root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl)
if err != nil { if err != nil {
return err return err
} }
@ -307,10 +307,10 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
} }
// up goes up recursively updating the intermediate nodes // up goes up recursively updating the intermediate nodes
func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl int) ([]byte, error) {
func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) {
var k, v []byte var k, v []byte
var err error var err error
if path[currLvl] {
if path[currLvl+toLvl] {
k, v, err = newIntermediate(t.hashFunction, siblings[currLvl], key) k, v, err = newIntermediate(t.hashFunction, siblings[currLvl], key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -331,7 +331,7 @@ func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl int) ([]by
return k, nil return k, nil
} }
return t.up(k, siblings, path, currLvl-1)
return t.up(k, siblings, path, currLvl-1, toLvl)
} }
func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
@ -440,7 +440,7 @@ func (t *Tree) Update(k, v []byte) error {
t.root = leafKey t.root = leafKey
return t.tx.Commit() return t.tx.Commit()
} }
root, err := t.up(leafKey, siblings, path, len(siblings)-1)
root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0)
if err != nil { if err != nil {
return err return err
} }

+ 7
- 0
tree_test.go

@ -219,6 +219,13 @@ func TestAux(t *testing.T) { // TODO split in proper tests
k = BigIntToBytes(big.NewInt(int64(770))) k = BigIntToBytes(big.NewInt(int64(770)))
err = tree.Add(k, v) err = tree.Add(k, v)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
k = BigIntToBytes(big.NewInt(int64(388)))
err = tree.Add(k, v)
c.Assert(err, qt.IsNil)
k = BigIntToBytes(big.NewInt(int64(900)))
err = tree.Add(k, v)
c.Assert(err, qt.IsNil)
// //
// err = tree.PrintGraphviz(nil) // err = tree.PrintGraphviz(nil)
// c.Assert(err, qt.IsNil) // c.Assert(err, qt.IsNil)

Loading…
Cancel
Save