From 890057cd826265f4457fc362b918696d3bb29dcb Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sat, 24 Apr 2021 23:35:06 +0200 Subject: [PATCH] Add AddBatch CaseD CASE D: Already populated Tree ============================== - Use A, B, C, D as subtree - Sort the Keys in Buckets that share the initial part of the path - For each subtree add there the new leafs R / \ / \ / \ * * / | / \ / | / \ / | / \ L: A B C D /\ /\ / \ / \ ... ... ... ... ... ... --- addbatch.go | 45 ++++++++++++++++++++- addbatch_test.go | 102 +++++++++++++++++++++++++++++++++++++++++++++++ tree.go | 10 ++--- tree_test.go | 7 ++++ 4 files changed, 157 insertions(+), 7 deletions(-) diff --git a/addbatch.go b/addbatch.go index 974e0b4..8e03707 100644 --- a/addbatch.go +++ b/addbatch.go @@ -168,7 +168,10 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { return nil, err } + // TODO if nCPU is not a power of two, cut at the highest power of two + // under nCPU nCPU := runtime.NumCPU() + l := int(math.Log2(float64(nCPU))) // CASE A: if nLeafs==0 (root==0) if bytes.Equal(t.root, t.emptyHash) { @@ -212,7 +215,6 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { // CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold // available parallelization, will need to be a power of 2 (2**n) var excedents []kv - l := int(math.Log2(float64(nCPU))) if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold { // TODO move to own function // 1. go down until level L (L=log2(nBuckets)) @@ -257,6 +259,11 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { return invalids, nil } + // CASE D + if true { // TODO enter in CASE D if len(keysAtL)=nCPU, if not, CASE E + return t.caseD(nCPU, l, kvs) + } + // TODO store t.root into DB // TODO update NLeafs from DB @@ -289,13 +296,47 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) { return invalids, kvsNonP2, nil } +func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) { + fmt.Println("CASE D", nCPU) + keysAtL, err := t.getKeysAtLevel(l + 1) + if err != nil { + return nil, err + } + buckets := splitInBuckets(kvs, nCPU) + + var subRoots [][]byte + var invalids []int + for i := 0; i < len(keysAtL); i++ { + bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels, // maxLevels-l + hashFunction: t.hashFunction, root: keysAtL[i]} + + for j := 0; j < len(buckets[i]); j++ { + if err = bucketTree.add(l, buckets[i][j].k, buckets[i][j].v); err != nil { + fmt.Println("failed", buckets[i][j].k[:4]) + + panic(err) + // invalids = append(invalids, buckets[i][j].pos) + } + } + subRoots = append(subRoots, bucketTree.root) + } + newRoot, err := t.upFromKeys(subRoots) + if err != nil { + return nil, err + } + t.root = newRoot + + return invalids, nil +} + func splitInBuckets(kvs []kv, nBuckets int) [][]kv { buckets := make([][]kv, nBuckets) // 1. classify the keyvalues into buckets for i := 0; i < len(kvs); i++ { pair := kvs[i] - bucketnum := keyToBucket(pair.k, nBuckets) + // bucketnum := keyToBucket(pair.k, nBuckets) + bucketnum := keyToBucket(pair.keyPath, nBuckets) buckets[bucketnum] = append(buckets[bucketnum], pair) } return buckets diff --git a/addbatch_test.go b/addbatch_test.go index 30b2a03..3c6a7cd 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -11,6 +11,56 @@ import ( "github.com/iden3/go-merkletree/db/memory" ) +func TestBatchAux(t *testing.T) { + c := qt.New(t) + + nLeafs := 16 + + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + start := time.Now() + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + } + fmt.Println(time.Since(start)) + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + for i := 0; i < 8; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree2.Add(k, v); err != nil { + t.Fatal(err) + } + } + // tree.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + + var keys, values [][]byte + for i := 8; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + keys = append(keys, k) + values = append(values, v) + } + start = time.Now() + indexes, err := tree2.AddBatchOpt(keys, values) + c.Assert(err, qt.IsNil) + fmt.Println(time.Since(start)) + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) +} + func TestAddBatchCaseA(t *testing.T) { c := qt.New(t) @@ -289,6 +339,58 @@ func TestAddBatchCaseC(t *testing.T) { // printLeafs("t2.txt", tree2) } +func TestAddBatchCaseD(t *testing.T) { + c := qt.New(t) + + nLeafs := 8192 + + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + start := time.Now() + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + } + fmt.Println(time.Since(start)) + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + // add the initial leafs to fill a bit the tree before calling the + // AddBatch method + for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold-1 once ready + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree2.Add(k, v); err != nil { + t.Fatal(err) + } + } + // tree2.PrintGraphvizFirstNLevels(nil, 4) + // tree2.PrintGraphviz(nil) + + var keys, values [][]byte + for i := 900; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + keys = append(keys, k) + values = append(values, v) + } + start = time.Now() + indexes, err := tree2.AddBatchOpt(keys, values) + c.Assert(err, qt.IsNil) + fmt.Println(time.Since(start)) + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) +} + // func printLeafs(name string, t *Tree) { // w := bytes.NewBufferString("") // diff --git a/tree.go b/tree.go index 26f7892..23bb340 100644 --- a/tree.go +++ b/tree.go @@ -206,7 +206,7 @@ func (t *Tree) add(fromLvl int, k, v []byte) error { t.root = leafKey return nil } - root, err := t.up(leafKey, siblings, path, len(siblings)-1) + root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl) if err != nil { return err } @@ -307,10 +307,10 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, } // up goes up recursively updating the intermediate nodes -func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl int) ([]byte, error) { +func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) { var k, v []byte var err error - if path[currLvl] { + if path[currLvl+toLvl] { k, v, err = newIntermediate(t.hashFunction, siblings[currLvl], key) if err != nil { return nil, err @@ -331,7 +331,7 @@ func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl int) ([]by return k, nil } - return t.up(k, siblings, path, currLvl-1) + return t.up(k, siblings, path, currLvl-1, toLvl) } func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { @@ -440,7 +440,7 @@ func (t *Tree) Update(k, v []byte) error { t.root = leafKey return t.tx.Commit() } - root, err := t.up(leafKey, siblings, path, len(siblings)-1) + root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0) if err != nil { return err } diff --git a/tree_test.go b/tree_test.go index 6d3c6c3..f1647b8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -219,6 +219,13 @@ func TestAux(t *testing.T) { // TODO split in proper tests k = BigIntToBytes(big.NewInt(int64(770))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) + + k = BigIntToBytes(big.NewInt(int64(388))) + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) + k = BigIntToBytes(big.NewInt(int64(900))) + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) // // err = tree.PrintGraphviz(nil) // c.Assert(err, qt.IsNil)