Browse Source

AddBatch: use the Virtual Tree for empty trees/subtrees

- AddBatch: use the Virtual Tree for cases A, B and C
- ImportDump: use AddBatch instead of adding each key/value one by one
- Reorganize the virtual tree tests and add more of them
master
arnaucube 3 years ago
parent
commit
03bb9f7447
6 changed files with 329 additions and 84 deletions
  1. +39
    -25
      addbatch.go
  2. +196
    -0
      addbatch_test.go
  3. +6
    -20
      tree.go
  4. +7
    -7
      tree_test.go
  5. +9
    -3
      vt.go
  6. +72
    -29
      vt_test.go

+ 39
- 25
addbatch.go

@ -325,7 +325,7 @@ func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) {
return nil, nil, err
}
} else {
invalids2, err = t.buildTreeBottomUpSingleThread(kvsP2)
invalids2, err = t.buildTreeBottomUpSingleThread(l, kvsP2)
if err != nil {
return nil, nil, err
}
@ -354,6 +354,9 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) {
if err != nil {
panic(err) // TODO WIP
}
if err := txs[cpu].Add(t.tx); err != nil {
panic(err) // TODO
}
bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels,
hashFunction: t.hashFunction, root: keysAtL[cpu]}
@ -567,6 +570,7 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) {
// will have the complete Tree build from bottom to up, where until the
// log2(nCPU) level it has been computed in parallel.
func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
l := int(math.Log2(float64(nCPU)))
buckets := splitInBuckets(kvs, nCPU)
subRoots := make([][]byte, nCPU)
@ -584,10 +588,13 @@ func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
if err != nil {
panic(err) // TODO
}
if err := txs[cpu].Add(t.tx); err != nil {
panic(err) // TODO
}
bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels,
hashFunction: t.hashFunction, root: t.emptyHash}
currInvalids, err := bucketTree.buildTreeBottomUpSingleThread(buckets[cpu])
currInvalids, err := bucketTree.buildTreeBottomUpSingleThread(l, buckets[cpu])
if err != nil {
panic(err) // TODO
}
@ -615,39 +622,42 @@ func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
for i := 0; i < len(invalidsInBucket); i++ {
invalids = append(invalids, invalidsInBucket[i]...)
}
return invalids, err
}
// buildTreeBottomUpSingleThread builds the tree with the given []kv from bottom
// to the root. keys & values must be sorted by path, and the array ks must be
// length multiple of 2
func (t *Tree) buildTreeBottomUpSingleThread(kvs []kv) ([]int, error) {
// to the root
func (t *Tree) buildTreeBottomUpSingleThread(l int, kvsRaw []kv) ([]int, error) {
// TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels
// would be reached and should return error
if len(kvsRaw) == 0 {
return nil, nil
}
var invalids []int
// build the leafs
leafKeys := make([][]byte, len(kvs))
for i := 0; i < len(kvs); i++ {
// TODO handle the case where Key&Value == 0
leafKey, leafValue, err := newLeafValue(t.hashFunction, kvs[i].k, kvs[i].v)
if err != nil {
// return nil, err
invalids = append(invalids, kvs[i].pos)
}
// store leafKey & leafValue to db
if err := t.tx.Put(leafKey, leafValue); err != nil {
// return nil, err
invalids = append(invalids, kvs[i].pos)
vt := newVT(t.maxLevels, t.hashFunction)
for i := 0; i < len(kvsRaw); i++ {
if err := vt.add(l, kvsRaw[i].k, kvsRaw[i].v); err != nil {
return nil, err
}
leafKeys[i] = leafKey
}
r, err := t.upFromKeys(leafKeys)
pairs, err := vt.computeHashes()
if err != nil {
return invalids, err
return nil, err
}
t.root = r
return invalids, nil
// store pairs in db
for i := 0; i < len(pairs); i++ {
if err := t.tx.Put(pairs[i][0], pairs[i][1]); err != nil {
return nil, err
}
}
// set tree.root from the virtual tree root
t.root = vt.root.h
return nil, nil // TODO invalids
}
// keys & values must be sorted by path, and the array ks must be length
@ -659,7 +669,11 @@ func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) {
var rKs [][]byte
for i := 0; i < len(ks); i += 2 {
// TODO handle the case where Key&Value == 0
if bytes.Equal(ks[i], t.emptyHash) && bytes.Equal(ks[i+1], t.emptyHash) {
// when both sub keys are empty, the key is also empty
rKs = append(rKs, t.emptyHash)
continue
}
k, v, err := newIntermediate(t.hashFunction, ks[i], ks[i+1])
if err != nil {
return nil, err

+ 196
- 0
addbatch_test.go

@ -1,6 +1,7 @@
package arbo
import (
"crypto/rand"
"encoding/hex"
"fmt"
"math/big"
@ -121,6 +122,201 @@ func TestAddBatchCaseANotPowerOf2(t *testing.T) {
c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
}
// randomBytes returns a freshly allocated slice of n bytes filled from
// crypto/rand. It panics if reading from the random source fails.
func randomBytes(n int) []byte {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	return buf
}
// TestBuildTreeBottomUpSingleThread adds a fixed set of four keys one by one
// to a reference tree, builds a second tree from the same (sorted) key-values
// with buildTreeBottomUp, and checks that both trees end up with the same
// root.
//
// NOTE(review): despite the test name, the call exercised here is
// buildTreeBottomUp with nCPU=4, not buildTreeBottomUpSingleThread directly.
func TestBuildTreeBottomUpSingleThread(t *testing.T) {
	c := qt.New(t)

	// reference tree, filled through the regular Add method
	expected, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	c.Assert(err, qt.IsNil)
	defer expected.db.Close()

	// tree under test, filled through the bottom-up builder
	built, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	c.Assert(err, qt.IsNil)
	defer built.db.Close()

	hexKeys := []string{
		"1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
		"2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
		"1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
		"d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
	}

	// every key is paired with the single-byte value {0}
	var keys, values [][]byte
	for _, hexKey := range hexKeys {
		decoded, err := hex.DecodeString(hexKey)
		c.Assert(err, qt.IsNil)
		keys = append(keys, decoded)
		values = append(values, []byte{0})
	}

	for i, k := range keys {
		if err := expected.Add(k, values[i]); err != nil {
			t.Fatal(err)
		}
	}

	sorted, err := built.keysValuesToKvs(keys, values)
	c.Assert(err, qt.IsNil)
	sortKvs(sorted)

	// the builder writes through t.tx, so a transaction is opened first
	built.tx, err = built.db.NewTx()
	c.Assert(err, qt.IsNil)

	invalids, err := built.buildTreeBottomUp(4, sorted)
	c.Assert(err, qt.IsNil)

	// no key-value is expected to be reported as invalid
	c.Check(len(invalids), qt.Equals, 0)

	// both construction methods must produce the same root
	c.Check(built.Root(), qt.DeepEquals, expected.Root())
}
// TestAddBatchCaseATestVector adds a fixed set of four keys one by one to a
// reference tree, adds the same key-values to a second (empty) tree through
// AddBatch, and checks that both roots match.
//
// NOTE(review): the original comment on the vector read "leafs in 2nd level
// subtrees: [ 6, 0, 1, 1]", which sums to 8 and so appears to describe the
// disabled 8-key vector kept at the end of this test rather than the active
// 4-key one — confirm.
func TestAddBatchCaseATestVector(t *testing.T) {
	c := qt.New(t)

	// reference tree, filled through the regular Add method
	expected, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	c.Assert(err, qt.IsNil)
	defer expected.db.Close()

	// tree under test, filled through AddBatch
	batched, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	c.Assert(err, qt.IsNil)
	defer batched.db.Close()

	hexKeys := []string{
		"1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
		"2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
		"1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
		"d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
	}

	// every key is paired with the single-byte value {0}
	var keys, values [][]byte
	for _, hexKey := range hexKeys {
		decoded, err := hex.DecodeString(hexKey)
		c.Assert(err, qt.IsNil)
		keys = append(keys, decoded)
		values = append(values, []byte{0})
	}

	for i, k := range keys {
		if err := expected.Add(k, values[i]); err != nil {
			t.Fatal(err)
		}
	}

	invalids, err := batched.AddBatch(keys, values)
	c.Assert(err, qt.IsNil)

	// no key-value is expected to be reported as invalid
	c.Check(len(invalids), qt.Equals, 0)

	// both insertion methods must produce the same root
	c.Check(batched.Root(), qt.DeepEquals, expected.Root())

	// Disabled (work in progress): same check with an 8-key vector.
	//
	// tree1, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	// c.Assert(err, qt.IsNil)
	// defer tree1.db.Close()
	//
	// tree2, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	// c.Assert(err, qt.IsNil)
	// defer tree2.db.Close()
	//
	// // leafs in 2nd level subtrees: [ 6, 0, 1, 1]
	// testvectorKeys = []string{
	// 	"1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
	// 	"2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
	// 	"9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e",
	// 	"9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d",
	// 	"1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
	// 	"d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
	// 	"3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c",
	// 	"5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5",
	// }
	// keys = [][]byte{}
	// values = [][]byte{}
	// for i := 0; i < len(testvectorKeys); i++ {
	// 	key, err := hex.DecodeString(testvectorKeys[i])
	// 	c.Assert(err, qt.IsNil)
	// 	keys = append(keys, key)
	// 	values = append(values, []byte{0})
	// }
	//
	// for i := 0; i < len(keys); i++ {
	// 	if err := tree1.Add(keys[i], values[i]); err != nil {
	// 		t.Fatal(err)
	// 	}
	// }
	//
	// indexes, err = tree2.AddBatch(keys, values)
	// c.Assert(err, qt.IsNil)
	// tree1.PrintGraphviz(nil)
	// tree2.PrintGraphviz(nil)
	//
	// c.Check(len(indexes), qt.Equals, 0)
	//
	// // check that both trees roots are equal
	// // c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
}
// TestAddBatchCaseARandomKeys adds 8 key-values one by one to a reference
// tree, adds the same key-values to a second (empty) tree through AddBatch,
// and checks that both roots match.
//
// NOTE(review): the keys are first drawn at random but are then all
// overwritten with fixed keys (marked as temporary in the original), so the
// test currently runs on a deterministic key set.
func TestAddBatchCaseARandomKeys(t *testing.T) {
	c := qt.New(t)
	const nLeafs = 8

	// reference tree, filled through the regular Add method
	expected, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	c.Assert(err, qt.IsNil)
	defer expected.db.Close()

	// tree under test, filled through AddBatch
	batched, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b)
	c.Assert(err, qt.IsNil)
	defer batched.db.Close()

	// random 32-byte keys, each paired with the single-byte value {0}
	keys := make([][]byte, nLeafs)
	values := make([][]byte, nLeafs)
	for i := range keys {
		keys[i] = randomBytes(32)
		values[i] = []byte{0}
	}

	// TMP: replace the random keys with a fixed set
	fixedKeys := []string{
		"1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
		"2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
		"9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e",
		"9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d",
		"1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
		"d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
		"3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c",
		"5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5",
	}
	for i, hexKey := range fixedKeys {
		// decode error deliberately ignored, as in the original: the
		// inputs are hard-coded hex literals
		keys[i], _ = hex.DecodeString(hexKey)
	}

	for i, k := range keys {
		if err := expected.Add(k, values[i]); err != nil {
			t.Fatal(err)
		}
	}

	invalids, err := batched.AddBatch(keys, values)
	c.Assert(err, qt.IsNil)

	// no key-value is expected to be reported as invalid
	c.Check(len(invalids), qt.Equals, 0)

	// both insertion methods must produce the same root
	c.Check(batched.Root(), qt.DeepEquals, expected.Root())
}
func TestAddBatchCaseB(t *testing.T) {
c := qt.New(t)

+ 6
- 20
tree.go

@ -445,7 +445,7 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) {
}
s := PackSiblings(t.hashFunction, siblings)
return value, s, nil
return leafV, s, nil
}
// PackSiblings packs the siblings into a byte array.
@ -711,10 +711,8 @@ func (t *Tree) Dump() ([]byte, error) {
func (t *Tree) ImportDump(b []byte) error {
t.updateAccessTime()
r := bytes.NewReader(b)
count := 0
// TODO instead of adding one by one, use AddBatch (once AddBatch is
// optimized)
var err error
var keys, values [][]byte
for {
l := make([]byte, 2)
_, err = io.ReadFull(r, l)
@ -733,22 +731,10 @@ func (t *Tree) ImportDump(b []byte) error {
if err != nil {
return err
}
err = t.Add(k, v)
if err != nil {
return err
}
count++
}
// update nLeafs (once ImportDump uses AddBatch method, this will not be
// needed)
t.tx, err = t.db.NewTx()
if err != nil {
return err
}
if err := t.incNLeafs(count); err != nil {
return err
keys = append(keys, k)
values = append(values, v)
}
if err = t.tx.Commit(); err != nil {
if _, err = t.AddBatch(keys, values); err != nil {
return err
}
return nil
@ -767,7 +753,7 @@ func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) e
fmt.Fprintf(w, `digraph hierarchy {
node [fontname=Monospace,fontsize=10,shape=box]
`)
nChars := 4
nChars := 4 // TODO move to global constant
nEmpties := 0
err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
if currLvl == untilLvl {

+ 7
- 7
tree_test.go

@ -14,13 +14,13 @@ func TestAddTestVectors(t *testing.T) {
c := qt.New(t)
// Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
testVectorsPoseidon := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"13578938674299138072471463694055224830892726234048532520316387704878000008795",
"5412393676474193513566895793055462193090331607895808993925969873307089394741",
"14204494359367183802864593755198662203838502594566452929175967972147978322084",
}
testAdd(c, HashFunctionPoseidon, testVectorsPoseidon)
// testVectorsPoseidon := []string{
// "0000000000000000000000000000000000000000000000000000000000000000",
// "13578938674299138072471463694055224830892726234048532520316387704878000008795",
// "5412393676474193513566895793055462193090331607895808993925969873307089394741",
// "14204494359367183802864593755198662203838502594566452929175967972147978322084",
// }
// testAdd(c, HashFunctionPoseidon, testVectorsPoseidon)
testVectorsSha256 := []string{
"0000000000000000000000000000000000000000000000000000000000000000",

+ 9
- 3
vt.go

@ -45,14 +45,14 @@ func newVT(maxLevels int, hash HashFunction) vt {
}
}
func (t *vt) add(k, v []byte) error {
func (t *vt) add(fromLvl int, k, v []byte) error {
leaf := newLeafNode(t.params, k, v)
if t.root == nil {
t.root = leaf
return nil
}
if err := t.root.add(t.params, 0, leaf); err != nil {
if err := t.root.add(t.params, fromLvl, leaf); err != nil {
return err
}
@ -119,6 +119,7 @@ func (n *node) add(p *params, currLvl int, leaf *node) error {
if n.r == nil {
// empty sub-node, add the leaf here
n.r = leaf
return nil
}
if err := n.r.add(p, currLvl+1, leaf); err != nil {
return err
@ -127,6 +128,7 @@ func (n *node) add(p *params, currLvl int, leaf *node) error {
if n.l == nil {
// empty sub-node, add the leaf here
n.l = leaf
return nil
}
if err := n.l.add(p, currLvl+1, leaf); err != nil {
return err
@ -134,7 +136,8 @@ func (n *node) add(p *params, currLvl int, leaf *node) error {
}
case vtLeaf:
if bytes.Equal(n.k, leaf.k) {
return fmt.Errorf("key already exists")
return fmt.Errorf("key already exists. Existing node: %s, trying to add node: %s",
hex.EncodeToString(n.k), hex.EncodeToString(leaf.k))
}
oldLeaf := &node{
@ -145,10 +148,13 @@ func (n *node) add(p *params, currLvl int, leaf *node) error {
// remove values from current node (converting it to mid node)
n.k = nil
n.v = nil
n.h = nil
n.path = nil
if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
return err
}
case vtEmpty:
panic(fmt.Errorf("EMPTY %v", n)) // TODO TMP
default:
return fmt.Errorf("ERR")
}

+ 72
- 29
vt_test.go

@ -1,51 +1,94 @@
package arbo
import (
"encoding/hex"
"math/big"
"testing"
qt "github.com/frankban/quicktest"
"github.com/iden3/go-merkletree/db/memory"
)
func TestVirtualTree(t *testing.T) {
func TestVirtualTreeTestVectors(t *testing.T) {
c := qt.New(t)
vTree := newVT(10, HashFunctionSha256)
c.Assert(vTree.root, qt.IsNil)
keys := [][]byte{
BigIntToBytes(big.NewInt(1)),
BigIntToBytes(big.NewInt(33)),
BigIntToBytes(big.NewInt(1234)),
BigIntToBytes(big.NewInt(123456789)),
}
values := [][]byte{
BigIntToBytes(big.NewInt(2)),
BigIntToBytes(big.NewInt(44)),
BigIntToBytes(big.NewInt(9876)),
BigIntToBytes(big.NewInt(987654321)),
}
k := BigIntToBytes(big.NewInt(1))
v := BigIntToBytes(big.NewInt(2))
err := vTree.add(k, v)
c.Assert(err, qt.IsNil)
// check the root for different batches of leafs
testVirtualTree(c, 10, keys[:1], values[:1])
testVirtualTree(c, 10, keys[:2], values[:2])
testVirtualTree(c, 10, keys[:3], values[:3])
testVirtualTree(c, 10, keys[:4], values[:4])
}
// check values
c.Assert(vTree.root.k, qt.DeepEquals, k)
c.Assert(vTree.root.v, qt.DeepEquals, v)
func TestVirtualTreeRandomKeys(t *testing.T) {
c := qt.New(t)
// compute hashes
pairs, err := vTree.computeHashes()
c.Assert(err, qt.IsNil)
c.Assert(len(pairs), qt.Equals, 1)
// test with hardcoded values
keys := make([][]byte, 8)
values := make([][]byte, 8)
keys[0], _ = hex.DecodeString("1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642")
keys[1], _ = hex.DecodeString("2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf")
keys[2], _ = hex.DecodeString("9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e")
keys[3], _ = hex.DecodeString("9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d")
keys[4], _ = hex.DecodeString("1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5")
keys[5], _ = hex.DecodeString("d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7")
keys[6], _ = hex.DecodeString("3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c")
keys[7], _ = hex.DecodeString("5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5")
rootBI := BytesToBigInt(vTree.root.h)
c.Assert(rootBI.String(), qt.Equals,
"46910109172468462938850740851377282682950237270676610513794735904325820156367")
// check the root for different batches of leafs
testVirtualTree(c, 10, keys[:1], values[:1])
testVirtualTree(c, 10, keys, values)
k = BigIntToBytes(big.NewInt(33))
v = BigIntToBytes(big.NewInt(44))
err = vTree.add(k, v)
c.Assert(err, qt.IsNil)
// test with random values
nLeafs := 1024
keys = make([][]byte, nLeafs)
values = make([][]byte, nLeafs)
for i := 0; i < nLeafs; i++ {
keys[i] = randomBytes(32)
values[i] = []byte{0}
}
// compute hashes
pairs, err = vTree.computeHashes()
// check the root for different batches of leafs
testVirtualTree(c, 100, keys[:1], values[:1])
testVirtualTree(c, 100, keys, values)
}
func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) {
c.Assert(len(keys), qt.Equals, len(values))
// normal tree, to have an expected root value
tree, err := NewTree(memory.NewMemoryStorage(), maxLevels, HashFunctionSha256)
c.Assert(err, qt.IsNil)
c.Assert(len(pairs), qt.Equals, 8)
for i := 0; i < len(keys); i++ {
err := tree.Add(keys[i], values[i])
c.Assert(err, qt.IsNil)
}
// virtual tree
vTree := newVT(maxLevels, HashFunctionSha256)
c.Assert(vTree.root, qt.IsNil)
// err = vTree.printGraphviz()
// c.Assert(err, qt.IsNil)
for i := 0; i < len(keys); i++ {
err := vTree.add(0, keys[i], values[i])
c.Assert(err, qt.IsNil)
}
rootBI = BytesToBigInt(vTree.root.h)
c.Assert(rootBI.String(), qt.Equals,
"59481735341404520835410489183267411392292882901306595567679529387376287440550")
// compute hashes, and check Root
_, err = vTree.computeHashes()
c.Assert(err, qt.IsNil)
c.Assert(vTree.root.h, qt.DeepEquals, tree.root)
}

Loading…
Cancel
Save