Browse Source

Update CaseB to handle repeated keys cases

- Update CaseB to handle repeated keys cases
- Add test for AddBatch/CaseB with repeated keys
- AddBatch-tests abstract code reusage
master
arnaucube 3 years ago
parent
commit
0eda440d93
2 changed files with 100 additions and 21 deletions
  1. +31
    -5
      addbatch.go
  2. +69
    -16
      addbatch_test.go

+ 31
- 5
addbatch.go

@ -309,25 +309,27 @@ func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) {
return nil, nil, err return nil, nil, err
} }
// add already existing key-values to the inputted key-values // add already existing key-values to the inputted key-values
kvs = append(kvs, aKvs...)
// kvs = append(kvs, aKvs...)
kvs, invalids := combineInKVSet(aKvs, kvs)
// proceed with CASE A // proceed with CASE A
sortKvs(kvs) sortKvs(kvs)
// cutPowerOfTwo, the excedent add it as normal Tree.Add // cutPowerOfTwo, the excedent add it as normal Tree.Add
kvsP2, kvsNonP2 := cutPowerOfTwo(kvs) kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
var invalids []int
var invalids2 []int
if nCPU > 1 { if nCPU > 1 {
invalids, err = t.buildTreeBottomUp(nCPU, kvsP2)
invalids2, err = t.buildTreeBottomUp(nCPU, kvsP2)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} else { } else {
invalids, err = t.buildTreeBottomUpSingleThread(kvsP2)
invalids2, err = t.buildTreeBottomUpSingleThread(kvsP2)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
invalids = append(invalids, invalids2...)
// return the excedents which will be added at the full tree at the end // return the excedents which will be added at the full tree at the end
return invalids, kvsNonP2, nil return invalids, kvsNonP2, nil
} }
@ -357,7 +359,7 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) {
// 3. do CASE B (with 1 cpu) for each key at level L // 3. do CASE B (with 1 cpu) for each key at level L
_, bucketExcedents, err := bucketTree.caseB(1, l, buckets[cpu]) _, bucketExcedents, err := bucketTree.caseB(1, l, buckets[cpu])
if err != nil { if err != nil {
panic(err)
panic(err) // TODO WIP
// return nil, err // return nil, err
} }
excedentsInBucket[cpu] = bucketExcedents excedentsInBucket[cpu] = bucketExcedents
@ -720,6 +722,30 @@ func highestPowerOfTwo(n int) int {
return res return res
} }
// combineInKVSet combines two kv array in one single array without repeated
// keys.
func combineInKVSet(base, toAdd []kv) ([]kv, []int) {
// TODO this is a naive version, this will be implemented in a more
// efficient way or through maps, or through sorted binary search
r := base
var invalids []int
for i := 0; i < len(toAdd); i++ {
e := false
// check if toAdd[i] exists in the base set
for j := 0; j < len(base); j++ {
if bytes.Equal(toAdd[i].k, base[j].k) {
e = true
}
}
if !e {
r = append(r, toAdd[i])
} else {
invalids = append(invalids, toAdd[i].pos)
}
}
return r, invalids
}
// func computeSimpleAddCost(nLeafs int) int { // func computeSimpleAddCost(nLeafs int) int {
// // nLvls 2^nLvls // // nLvls 2^nLvls
// nLvls := int(math.Log2(float64(nLeafs))) // nLvls := int(math.Log2(float64(nLeafs)))

+ 69
- 16
addbatch_test.go

@ -11,6 +11,15 @@ import (
"github.com/iden3/go-merkletree/db/memory" "github.com/iden3/go-merkletree/db/memory"
) )
var debug = true
func debugTime(descr string, time1, time2 time.Duration) {
if debug {
fmt.Printf("%s was %f times faster than without AddBatch\n",
descr, float64(time1)/float64(time2))
}
}
func testInit(c *qt.C, n int) (*Tree, *Tree) { func testInit(c *qt.C, n int) (*Tree, *Tree) {
tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
@ -35,12 +44,6 @@ func testInit(c *qt.C, n int) (*Tree, *Tree) {
return tree1, tree2 return tree1, tree2
} }
func ratio(t1, t2 time.Duration) float64 {
a := float64(t1)
b := float64(t2)
return (a / b)
}
func TestAddBatchCaseA(t *testing.T) { func TestAddBatchCaseA(t *testing.T) {
c := qt.New(t) c := qt.New(t)
@ -75,8 +78,7 @@ func TestAddBatchCaseA(t *testing.T) {
indexes, err := tree2.AddBatchOpt(keys, values) indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
time2 := time.Since(start) time2 := time.Since(start)
fmt.Printf("CASE A, AddBatch was %f times faster than without AddBatch\n",
ratio(time1, time2))
debugTime("CASE A, AddBatch", time1, time2)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
@ -149,14 +151,68 @@ func TestAddBatchCaseB(t *testing.T) {
indexes, err := tree2.AddBatchOpt(keys, values) indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
time2 := time.Since(start) time2 := time.Since(start)
fmt.Printf("CASE B, AddBatch was %f times faster than without AddBatch\n",
ratio(time1, time2))
debugTime("CASE B, AddBatch", time1, time2)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
} }
func TestAddBatchCaseBRepeatedLeafs(t *testing.T) {
c := qt.New(t)
nLeafs := 1024
initialNLeafs := 99 // TMP TODO use const minLeafsThreshold-1 once ready
tree1, tree2 := testInit(c, initialNLeafs)
for i := initialNLeafs; i < nLeafs; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
if err := tree1.Add(k, v); err != nil {
t.Fatal(err)
}
}
// prepare the key-values to be added, including already existing keys
var keys, values [][]byte
for i := 0; i < nLeafs; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
keys = append(keys, k)
values = append(values, v)
}
indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, initialNLeafs)
// check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
}
func TestCombineInKVSet(t *testing.T) {
c := qt.New(t)
var a, b, expected []kv
for i := 0; i < 10; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
kv := kv{k: k}
if i < 7 {
a = append(a, kv)
}
if i >= 4 {
b = append(b, kv)
}
expected = append(expected, kv)
}
r, invalids := combineInKVSet(a, b)
for i := 0; i < len(r); i++ {
c.Assert(r[i].k, qt.DeepEquals, expected[i].k)
}
c.Assert(len(invalids), qt.Equals, 7-4)
}
func TestGetKeysAtLevel(t *testing.T) { func TestGetKeysAtLevel(t *testing.T) {
c := qt.New(t) c := qt.New(t)
@ -315,8 +371,7 @@ func TestAddBatchCaseC(t *testing.T) {
indexes, err := tree2.AddBatchOpt(keys, values) indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
time2 := time.Since(start) time2 := time.Since(start)
fmt.Printf("CASE C, AddBatch was %f times faster than without AddBatch\n",
ratio(time1, time2))
debugTime("CASE C, AddBatch", time1, time2)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
@ -353,8 +408,7 @@ func TestAddBatchCaseD(t *testing.T) {
indexes, err := tree2.AddBatchOpt(keys, values) indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
time2 := time.Since(start) time2 := time.Since(start)
fmt.Printf("CASE D, AddBatch was %f times faster than without AddBatch\n",
ratio(time1, time2))
debugTime("CASE D, AddBatch", time1, time2)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
@ -411,8 +465,7 @@ func TestAddBatchCaseE(t *testing.T) {
indexes, err := tree2.AddBatchOpt(keys, values) indexes, err := tree2.AddBatchOpt(keys, values)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
time2 := time.Since(start) time2 := time.Since(start)
fmt.Printf("CASE E, AddBatch was %f times faster than without AddBatch\n",
ratio(time1, time2))
debugTime("CASE E, AddBatch", time1, time2)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal

Loading…
Cancel
Save