Browse Source

Update upFromNodes function for unbalanced tree

- Update upFromNodes function for unbalanced tree case
- Add AddBatchTestVector2 & 3 with some edge cases
- Add checkRoots test method, which stores the Dump of the tree to file for after-debug
master
arnaucube 3 years ago
parent
commit
2c62f31446
4 changed files with 243 additions and 25 deletions
  1. +1
    -0
      .gitignore
  2. +109
    -16
      addbatch_test.go
  3. +107
    -0
      helpers_test.go
  4. +26
    -9
      vt.go

+ 1
- 0
.gitignore

@ -0,0 +1 @@
err-dump

+ 109
- 16
addbatch_test.go

@ -72,7 +72,7 @@ func TestAddBatchTreeEmpty(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 100, HashFunctionPoseidon) tree, err := NewTree(database, 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck //nolint:errcheck
defer tree.db.Close() //nolint:errcheck
bLen := tree.HashFunction().Len() bLen := tree.HashFunction().Len()
start := time.Now() start := time.Now()
@ -89,7 +89,7 @@ func TestAddBatchTreeEmpty(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon) tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck //nolint:errcheck
defer tree2.db.Close() //nolint:errcheck
tree2.dbgInit() tree2.dbgInit()
var keys, values [][]byte var keys, values [][]byte
@ -111,7 +111,7 @@ func TestAddBatchTreeEmpty(t *testing.T) {
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
checkRoots(c, tree, tree2)
} }
func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) {
@ -152,7 +152,7 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) {
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
checkRoots(c, tree, tree2)
} }
func randomBytes(n int) []byte { func randomBytes(n int) []byte {
@ -164,7 +164,7 @@ func randomBytes(n int) []byte {
return b return b
} }
func TestAddBatchTreeEmptyTestVector(t *testing.T) {
func TestAddBatchTestVector1(t *testing.T) {
c := qt.New(t) c := qt.New(t)
database1, err := db.NewBadgerDB(c.TempDir()) database1, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
@ -203,7 +203,7 @@ func TestAddBatchTreeEmptyTestVector(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
// 2nd test vectors // 2nd test vectors
database1, err = db.NewBadgerDB(c.TempDir()) database1, err = db.NewBadgerDB(c.TempDir())
@ -247,7 +247,100 @@ func TestAddBatchTreeEmptyTestVector(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
}
func TestAddBatchTestVector2(t *testing.T) {
// test vector with unbalanced tree
c := qt.New(t)
database, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database, 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
database2, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
bLen := tree1.HashFunction().Len()
var keys, values [][]byte
// 1
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(1))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(1))))
// 2
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(2))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(2))))
// 3
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(3))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(3))))
// 5
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(5))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(5))))
for i := 0; i < len(keys); i++ {
if err := tree1.Add(keys[i], values[i]); err != nil {
t.Fatal(err)
}
}
indexes, err := tree2.AddBatch(keys, values)
c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal
checkRoots(c, tree1, tree2)
}
func TestAddBatchTestVector3(t *testing.T) {
// test vector with unbalanced tree
c := qt.New(t)
database, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database, 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck
database2, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
bLen := tree1.HashFunction().Len()
var keys, values [][]byte
// 0
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(0))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(0))))
// 3
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(3))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(3))))
// 7
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(7))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(7))))
// 135
keys = append(keys, BigIntToBytes(bLen, big.NewInt(int64(135))))
values = append(values, BigIntToBytes(bLen, big.NewInt(int64(135))))
for i := 0; i < len(keys); i++ {
if err := tree1.Add(keys[i], values[i]); err != nil {
t.Fatal(err)
}
}
indexes, err := tree2.AddBatch(keys, values)
c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal
checkRoots(c, tree1, tree2)
//
// tree1.PrintGraphvizFirstNLevels(nil, 100)
// tree2.PrintGraphvizFirstNLevels(nil, 100)
} }
func TestAddBatchTreeEmptyRandomKeys(t *testing.T) { func TestAddBatchTreeEmptyRandomKeys(t *testing.T) {
@ -283,7 +376,7 @@ func TestAddBatchTreeEmptyRandomKeys(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestAddBatchTreeNotEmptyFewLeafs(t *testing.T) { func TestAddBatchTreeNotEmptyFewLeafs(t *testing.T) {
@ -326,7 +419,7 @@ func TestAddBatchTreeNotEmptyFewLeafs(t *testing.T) {
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestAddBatchTreeNotEmptyEnoughLeafs(t *testing.T) { func TestAddBatchTreeNotEmptyEnoughLeafs(t *testing.T) {
@ -368,7 +461,7 @@ func TestAddBatchTreeNotEmptyEnoughLeafs(t *testing.T) {
} }
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestAddBatchTreeEmptyRepeatedLeafs(t *testing.T) { func TestAddBatchTreeEmptyRepeatedLeafs(t *testing.T) {
@ -407,7 +500,7 @@ func TestAddBatchTreeEmptyRepeatedLeafs(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, nRepeatedKeys) c.Check(len(indexes), qt.Equals, nRepeatedKeys)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestAddBatchTreeNotEmptyFewLeafsRepeatedLeafs(t *testing.T) { func TestAddBatchTreeNotEmptyFewLeafsRepeatedLeafs(t *testing.T) {
@ -439,7 +532,7 @@ func TestAddBatchTreeNotEmptyFewLeafsRepeatedLeafs(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Check(len(indexes), qt.Equals, initialNLeafs) c.Check(len(indexes), qt.Equals, initialNLeafs)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestSplitInBuckets(t *testing.T) { func TestSplitInBuckets(t *testing.T) {
@ -583,7 +676,7 @@ func TestAddBatchTreeNotEmpty(t *testing.T) {
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestAddBatchNotEmptyUnbalanced(t *testing.T) { func TestAddBatchNotEmptyUnbalanced(t *testing.T) {
@ -648,7 +741,7 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) {
c.Check(len(indexes), qt.Equals, 0) c.Check(len(indexes), qt.Equals, 0)
// check that both trees roots are equal // check that both trees roots are equal
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
} }
func TestFlp2(t *testing.T) { func TestFlp2(t *testing.T) {
@ -781,8 +874,8 @@ func TestDbgStats(t *testing.T) {
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
c.Assert(len(invalids), qt.Equals, 0) c.Assert(len(invalids), qt.Equals, 0)
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
c.Check(tree3.Root(), qt.DeepEquals, tree1.Root())
checkRoots(c, tree1, tree2)
checkRoots(c, tree1, tree3)
if debug { if debug {
fmt.Println("TestDbgStats") fmt.Println("TestDbgStats")

+ 107
- 0
helpers_test.go

@ -0,0 +1,107 @@
package arbo
import (
"bytes"
"io"
"io/ioutil"
"os"
"testing"
"time"
qt "github.com/frankban/quicktest"
"go.vocdoni.io/dvote/db"
)
func checkRoots(c *qt.C, tree1, tree2 *Tree) {
if !bytes.Equal(tree2.Root(), tree1.Root()) {
dir := "err-dump"
if _, err := os.Stat(dir); os.IsNotExist(err) {
err := os.Mkdir(dir, os.ModePerm)
c.Assert(err, qt.IsNil)
}
// store tree1
storeTree(c, tree1, dir+"/tree1")
// store tree2
storeTree(c, tree2, dir+"/tree2")
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
}
}
func storeTree(c *qt.C, tree *Tree, path string) {
dump, err := tree.Dump(nil)
c.Assert(err, qt.IsNil)
err = ioutil.WriteFile(path+"-"+time.Now().String()+".debug", dump, 0600)
c.Assert(err, qt.IsNil)
}
// nolint:unused
func readTree(c *qt.C, tree *Tree, path string) {
b, err := ioutil.ReadFile(path) //nolint:gosec
c.Assert(err, qt.IsNil)
err = tree.ImportDump(b)
c.Assert(err, qt.IsNil)
}
// nolint:unused
func importDumpLoopAdd(tree *Tree, b []byte) error {
r := bytes.NewReader(b)
var err error
for {
l := make([]byte, 2)
_, err = io.ReadFull(r, l)
if err == io.EOF {
break
} else if err != nil {
return err
}
k := make([]byte, l[0])
_, err = io.ReadFull(r, k)
if err != nil {
return err
}
v := make([]byte, l[1])
_, err = io.ReadFull(r, v)
if err != nil {
return err
}
err = tree.Add(k, v)
if err != nil {
return err
}
}
return nil
}
func TestReadTreeDBG(t *testing.T) {
t.Skip() // test just for debugging purposes, disabled by default
c := qt.New(t)
database1, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database1, 100, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
database2, err := db.NewBadgerDB(c.TempDir())
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
// tree1 is generated by a loop of .Add
path := "err-dump/tree1-2021-06-03 16:45:54.104449306 +0200 CEST m=+0.073874545.debug"
b, err := ioutil.ReadFile(path)
c.Assert(err, qt.IsNil)
err = importDumpLoopAdd(tree1, b)
c.Assert(err, qt.IsNil)
// tree2 is generated by .AddBatch
path = "err-dump/tree2-2021-06-03 16:45:54.104525519 +0200 CEST m=+0.073950756.debug"
readTree(c, tree2, path)
// tree1.PrintGraphvizFirstNLevels(nil, 6)
// tree2.PrintGraphvizFirstNLevels(nil, 6)
c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
}

+ 26
- 9
vt.go

@ -150,7 +150,6 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {
} }
// remove the inserted element from buckets[i] // remove the inserted element from buckets[i]
// fmt.Println("rm-ins", inserted)
if inserted != -1 { if inserted != -1 {
buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...) buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
} }
@ -253,10 +252,19 @@ func upFromNodes(ns []*node) (*node, error) {
var res []*node var res []*node
for i := 0; i < len(ns); i += 2 { for i := 0; i < len(ns); i += 2 {
if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty {
// if ns[i] == nil && ns[i+1] == nil {
// when both sub nodes are empty, the node is also empty
res = append(res, ns[i]) // empty node
if (ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty) ||
(ns[i].typ() == vtLeaf && ns[i+1].typ() == vtEmpty) {
// when both sub nodes are empty, the parent is also empty
// or
// when 1st sub node is a leaf but the 2nd is empty, the
// leaf is used as parent
res = append(res, ns[i])
continue
}
if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtLeaf {
// when 2nd sub node is a leaf but the 1st is empty, the
// leaf is used as 'parent'
res = append(res, ns[i+1])
continue continue
} }
n := &node{ n := &node{
@ -611,12 +619,21 @@ func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
} }
fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars])) fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
k := n.k
v := n.v
if len(n.k) >= nChars {
k = n.k[:nChars]
}
if len(n.v) >= nChars {
v = n.v[:nChars]
}
fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n, fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
hex.EncodeToString(n.k[:nChars]),
hex.EncodeToString(n.v[:nChars]))
hex.EncodeToString(k),
hex.EncodeToString(v))
fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n", fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
hex.EncodeToString(n.k[:nChars]),
hex.EncodeToString(n.v[:nChars]))
hex.EncodeToString(k),
hex.EncodeToString(v))
case vtMid: case vtMid:
fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n) fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)

Loading…
Cancel
Save