From 91a98bf18d8562dc44df76eb2f2cf67ccc810fbe Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sun, 25 Apr 2021 17:42:45 +0200 Subject: [PATCH] AddBatch: commit tx at end,allow batch w/ len!=2^n --- addbatch.go | 69 +++++++++++++++++++++++++++++++++++++++++------- addbatch_test.go | 62 ++++++++++++++++++++++++++++++------------- 2 files changed, 103 insertions(+), 28 deletions(-) diff --git a/addbatch.go b/addbatch.go index 032b587..4f0987b 100644 --- a/addbatch.go +++ b/addbatch.go @@ -189,6 +189,13 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { invalids = append(invalids, kvsNonP2[i].pos) } } + // store root to db + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { + return nil, err + } + if err = t.tx.Commit(); err != nil { + return nil, err + } return invalids, nil } @@ -209,6 +216,13 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { invalids = append(invalids, excedents[i].pos) } } + // store root to db + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { + return nil, err + } + if err = t.tx.Commit(); err != nil { + return nil, err + } return invalids, nil } @@ -247,6 +261,7 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { } t.root = newRoot + // add the key-values that have not been used yet var invalids []int for i := 0; i < len(excedents); i++ { // Add until the level L @@ -255,15 +270,40 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) { invalids = append(invalids, excedents[i].pos) // TODO WIP } } + // store root to db + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { + return nil, err + } + if err = t.tx.Commit(); err != nil { + return nil, err + } return invalids, nil } + keysAtL, err := t.getKeysAtLevel(l + 1) + if err != nil { + return nil, err + } // CASE D - if true { // TODO enter in CASE D if len(keysAtL)=nCPU, if not, CASE E - return t.caseD(nCPU, l, kvs) + if len(keysAtL) == nCPU { // enter in CASE D if len(keysAtL)=nCPU, if not, CASE E + invalids, err := t.caseD(nCPU, l, keysAtL, kvs) + if err != nil { + return nil, err + } + // store root to db + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { + return nil, err + } + + if err = t.tx.Commit(); err != nil { + return nil, err + } + return invalids, nil } + // CASE E: add one key of each bucket, and then do CASE D + // TODO store t.root into DB // TODO update NLeafs from DB @@ -296,11 +336,7 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) { return invalids, kvsNonP2, nil } -func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) { - keysAtL, err := t.getKeysAtLevel(l + 1) - if err != nil { - return nil, err - } +func (t *Tree) caseD(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { buckets := splitInBuckets(kvs, nCPU) subRoots := make([][]byte, nCPU) @@ -314,14 +350,13 @@ func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) { var err error txs[cpu], err = t.db.NewTx() if err != nil { - panic(err) // TODO + panic(err) // TODO WIP } - bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, // maxLevels-l + bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels - l, // maxLevels-l hashFunction: t.hashFunction, root: keysAtL[cpu]} for j := 0; j < len(buckets[cpu]); j++ { if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { - fmt.Println("failed", buckets[cpu][j].k[:4]) invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos) } } @@ -331,6 +366,13 @@ func (t *Tree) caseD(nCPU, l int, kvs []kv) ([]int, error) { } wg.Wait() + // merge buckets txs into Tree.tx + for i := 0; i < len(txs); i++ { + if err := t.tx.Add(txs[i]); err != nil { + return nil, err + } + } + newRoot, err := t.upFromKeys(subRoots) if err != nil { return nil, err @@ -477,6 +519,13 @@ func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) { } wg.Wait() + // merge buckets txs into Tree.tx + for i := 0; i < len(txs); i++ { + if err := t.tx.Add(txs[i]); err != nil { + return nil, err + } + } + newRoot, err := t.upFromKeys(subRoots) if err != nil { return nil, err diff --git a/addbatch_test.go b/addbatch_test.go index 6830445..9e1d488 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -78,7 +78,7 @@ func TestAddBatchCaseA(t *testing.T) { t.Fatal(err) } } - fmt.Println(time.Since(start)) + fmt.Println("time elapsed without CASE A: ", time.Since(start)) tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -94,7 +94,43 @@ func TestAddBatchCaseA(t *testing.T) { start = time.Now() indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) - fmt.Println(time.Since(start)) + fmt.Println("time elapsed with CASE A: ", time.Since(start)) + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) +} + +func TestAddBatchCaseANotPowerOf2(t *testing.T) { + c := qt.New(t) + + nLeafs := 1027 + + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + } + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + var keys, values [][]byte + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + keys = append(keys, k) + values = append(values, v) + } + indexes, err := tree2.AddBatchOpt(keys, values) + c.Assert(err, qt.IsNil) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal @@ -118,7 +154,7 @@ func TestAddBatchCaseB(t *testing.T) { t.Fatal(err) } } - fmt.Println(time.Since(start)) + fmt.Println("time elapsed without CASE B: ", time.Since(start)) tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -144,7 +180,7 @@ func TestAddBatchCaseB(t *testing.T) { start = time.Now() indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) - fmt.Println(time.Since(start)) + fmt.Println("time elapsed with CASE B: ", time.Since(start)) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal @@ -296,7 +332,7 @@ func TestAddBatchCaseC(t *testing.T) { t.Fatal(err) } } - fmt.Println(time.Since(start)) + fmt.Println("time elapsed without CASE C: ", time.Since(start)) tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -311,7 +347,6 @@ func TestAddBatchCaseC(t *testing.T) { t.Fatal(err) } } - // tree2.PrintGraphviz(nil) var keys, values [][]byte for i := 101; i < nLeafs; i++ { @@ -323,20 +358,11 @@ func TestAddBatchCaseC(t *testing.T) { start = time.Now() indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) - fmt.Println(time.Since(start)) + fmt.Println("time elapsed with CASE C: ", time.Since(start)) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) - - // tree.PrintGraphviz(nil) - // tree2.PrintGraphviz(nil) - // // tree.PrintGraphvizFirstNLevels(nil, 4) - // // tree2.PrintGraphvizFirstNLevels(nil, 4) - // fmt.Println("TREE") - // printLeafs("t1.txt", tree) - // fmt.Println("TREE2") - // printLeafs("t2.txt", tree2) } func TestAddBatchCaseD(t *testing.T) { @@ -356,7 +382,7 @@ func TestAddBatchCaseD(t *testing.T) { t.Fatal(err) } } - fmt.Println(time.Since(start)) + fmt.Println("time elapsed without CASE D: ", time.Since(start)) tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -382,7 +408,7 @@ func TestAddBatchCaseD(t *testing.T) { start = time.Now() indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) - fmt.Println(time.Since(start)) + fmt.Println("time elapsed with CASE D: ", time.Since(start)) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal