From 0eda440d93a15809a747318fe630c63838f1b014 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sat, 8 May 2021 14:52:15 +0200 Subject: [PATCH] Update CaseB to handle repeated keys cases - Update CaseB to handle repeated keys cases - Add test for AddBatch/CaseB with repeated keys - AddBatch-tests abstract code reusage --- addbatch.go | 36 +++++++++++++++++--- addbatch_test.go | 85 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 100 insertions(+), 21 deletions(-) diff --git a/addbatch.go b/addbatch.go index fcc8a8e..d083329 100644 --- a/addbatch.go +++ b/addbatch.go @@ -309,25 +309,27 @@ func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) { return nil, nil, err } // add already existing key-values to the inputted key-values - kvs = append(kvs, aKvs...) + // kvs = append(kvs, aKvs...) + kvs, invalids := combineInKVSet(aKvs, kvs) // proceed with CASE A sortKvs(kvs) // cutPowerOfTwo, the excedent add it as normal Tree.Add kvsP2, kvsNonP2 := cutPowerOfTwo(kvs) - var invalids []int + var invalids2 []int if nCPU > 1 { - invalids, err = t.buildTreeBottomUp(nCPU, kvsP2) + invalids2, err = t.buildTreeBottomUp(nCPU, kvsP2) if err != nil { return nil, nil, err } } else { - invalids, err = t.buildTreeBottomUpSingleThread(kvsP2) + invalids2, err = t.buildTreeBottomUpSingleThread(kvsP2) if err != nil { return nil, nil, err } } + invalids = append(invalids, invalids2...) // return the excedents which will be added at the full tree at the end return invalids, kvsNonP2, nil } @@ -357,7 +359,7 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { // 3. do CASE B (with 1 cpu) for each key at level L _, bucketExcedents, err := bucketTree.caseB(1, l, buckets[cpu]) if err != nil { - panic(err) + panic(err) // TODO WIP // return nil, err } excedentsInBucket[cpu] = bucketExcedents @@ -720,6 +722,30 @@ func highestPowerOfTwo(n int) int { return res } +// combineInKVSet combines two kv array in one single array without repeated +// keys. +func combineInKVSet(base, toAdd []kv) ([]kv, []int) { + // TODO this is a naive version, this will be implemented in a more + // efficient way or through maps, or through sorted binary search + r := base + var invalids []int + for i := 0; i < len(toAdd); i++ { + e := false + // check if toAdd[i] exists in the base set + for j := 0; j < len(base); j++ { + if bytes.Equal(toAdd[i].k, base[j].k) { + e = true + } + } + if !e { + r = append(r, toAdd[i]) + } else { + invalids = append(invalids, toAdd[i].pos) + } + } + return r, invalids +} + // func computeSimpleAddCost(nLeafs int) int { // // nLvls 2^nLvls // nLvls := int(math.Log2(float64(nLeafs))) diff --git a/addbatch_test.go b/addbatch_test.go index fc70686..fdfb327 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -11,6 +11,15 @@ import ( "github.com/iden3/go-merkletree/db/memory" ) +var debug = true + +func debugTime(descr string, time1, time2 time.Duration) { + if debug { + fmt.Printf("%s was %f times faster than without AddBatch\n", + descr, float64(time1)/float64(time2)) + } +} + func testInit(c *qt.C, n int) (*Tree, *Tree) { tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -35,12 +44,6 @@ func testInit(c *qt.C, n int) (*Tree, *Tree) { return tree1, tree2 } -func ratio(t1, t2 time.Duration) float64 { - a := float64(t1) - b := float64(t2) - return (a / b) -} - func TestAddBatchCaseA(t *testing.T) { c := qt.New(t) @@ -75,8 +78,7 @@ func TestAddBatchCaseA(t *testing.T) { indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) time2 := time.Since(start) - fmt.Printf("CASE A, AddBatch was %f times faster than without AddBatch\n", - ratio(time1, time2)) + debugTime("CASE A, AddBatch", time1, time2) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal @@ -149,14 +151,68 @@ func TestAddBatchCaseB(t *testing.T) { indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) time2 := time.Since(start) - fmt.Printf("CASE B, AddBatch was %f times faster than without AddBatch\n", - ratio(time1, time2)) + debugTime("CASE B, AddBatch", time1, time2) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) } +func TestAddBatchCaseBRepeatedLeafs(t *testing.T) { + c := qt.New(t) + + nLeafs := 1024 + initialNLeafs := 99 // TMP TODO use const minLeafsThreshold-1 once ready + + tree1, tree2 := testInit(c, initialNLeafs) + + for i := initialNLeafs; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree1.Add(k, v); err != nil { + t.Fatal(err) + } + } + + // prepare the key-values to be added, including already existing keys + var keys, values [][]byte + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + keys = append(keys, k) + values = append(values, v) + } + indexes, err := tree2.AddBatchOpt(keys, values) + c.Assert(err, qt.IsNil) + c.Check(len(indexes), qt.Equals, initialNLeafs) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) +} + +func TestCombineInKVSet(t *testing.T) { + c := qt.New(t) + + var a, b, expected []kv + for i := 0; i < 10; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + kv := kv{k: k} + if i < 7 { + a = append(a, kv) + } + if i >= 4 { + b = append(b, kv) + } + expected = append(expected, kv) + } + + r, invalids := combineInKVSet(a, b) + for i := 0; i < len(r); i++ { + c.Assert(r[i].k, qt.DeepEquals, expected[i].k) + } + c.Assert(len(invalids), qt.Equals, 7-4) +} + func TestGetKeysAtLevel(t *testing.T) { c := qt.New(t) @@ -315,8 +371,7 @@ func TestAddBatchCaseC(t *testing.T) { indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) time2 := time.Since(start) - fmt.Printf("CASE C, AddBatch was %f times faster than without AddBatch\n", - ratio(time1, time2)) + debugTime("CASE C, AddBatch", time1, time2) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal @@ -353,8 +408,7 @@ func TestAddBatchCaseD(t *testing.T) { indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) time2 := time.Since(start) - fmt.Printf("CASE D, AddBatch was %f times faster than without AddBatch\n", - ratio(time1, time2)) + debugTime("CASE D, AddBatch", time1, time2) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal @@ -411,8 +465,7 @@ func TestAddBatchCaseE(t *testing.T) { indexes, err := tree2.AddBatchOpt(keys, values) c.Assert(err, qt.IsNil) time2 := time.Since(start) - fmt.Printf("CASE E, AddBatch was %f times faster than without AddBatch\n", - ratio(time1, time2)) + debugTime("CASE E, AddBatch", time1, time2) c.Check(len(indexes), qt.Equals, 0) // check that both trees roots are equal