diff --git a/addbatch.go b/addbatch.go index 0308bbb..73c74ba 100644 --- a/addbatch.go +++ b/addbatch.go @@ -325,7 +325,7 @@ func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) { return nil, nil, err } } else { - invalids2, err = t.buildTreeBottomUpSingleThread(kvsP2) + invalids2, err = t.buildTreeBottomUpSingleThread(l, kvsP2) if err != nil { return nil, nil, err } @@ -354,6 +354,9 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { if err != nil { panic(err) // TODO WIP } + if err := txs[cpu].Add(t.tx); err != nil { + panic(err) // TODO + } bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, hashFunction: t.hashFunction, root: keysAtL[cpu]} @@ -567,6 +570,7 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) { // will have the complete Tree build from bottom to up, where until the // log2(nCPU) level it has been computed in parallel. func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) { + l := int(math.Log2(float64(nCPU))) buckets := splitInBuckets(kvs, nCPU) subRoots := make([][]byte, nCPU) @@ -584,10 +588,13 @@ func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) { if err != nil { panic(err) // TODO } + if err := txs[cpu].Add(t.tx); err != nil { + panic(err) // TODO + } bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, hashFunction: t.hashFunction, root: t.emptyHash} - currInvalids, err := bucketTree.buildTreeBottomUpSingleThread(buckets[cpu]) + currInvalids, err := bucketTree.buildTreeBottomUpSingleThread(l, buckets[cpu]) if err != nil { panic(err) // TODO } @@ -615,39 +622,42 @@ func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) { for i := 0; i < len(invalidsInBucket); i++ { invalids = append(invalids, invalidsInBucket[i]...) } + return invalids, err } // buildTreeBottomUpSingleThread builds the tree with the given []kv from bottom -// to the root. keys & values must be sorted by path, and the array ks must be -// length multiple of 2 -func (t *Tree) buildTreeBottomUpSingleThread(kvs []kv) ([]int, error) { +// to the root +func (t *Tree) buildTreeBottomUpSingleThread(l int, kvsRaw []kv) ([]int, error) { // TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels // would be reached and should return error + if len(kvsRaw) == 0 { + return nil, nil + } - var invalids []int - // build the leafs - leafKeys := make([][]byte, len(kvs)) - for i := 0; i < len(kvs); i++ { - // TODO handle the case where Key&Value == 0 - leafKey, leafValue, err := newLeafValue(t.hashFunction, kvs[i].k, kvs[i].v) - if err != nil { - // return nil, err - invalids = append(invalids, kvs[i].pos) - } - // store leafKey & leafValue to db - if err := t.tx.Put(leafKey, leafValue); err != nil { - // return nil, err - invalids = append(invalids, kvs[i].pos) + vt := newVT(t.maxLevels, t.hashFunction) + + for i := 0; i < len(kvsRaw); i++ { + if err := vt.add(l, kvsRaw[i].k, kvsRaw[i].v); err != nil { + return nil, err } - leafKeys[i] = leafKey } - r, err := t.upFromKeys(leafKeys) + + pairs, err := vt.computeHashes() if err != nil { - return invalids, err + return nil, err } - t.root = r - return invalids, nil + // store pairs in db + for i := 0; i < len(pairs); i++ { + if err := t.tx.Put(pairs[i][0], pairs[i][1]); err != nil { + return nil, err + } + } + + // set tree.root from the virtual tree root + t.root = vt.root.h + + return nil, nil // TODO invalids } // keys & values must be sorted by path, and the array ks must be length @@ -659,7 +669,11 @@ func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) { var rKs [][]byte for i := 0; i < len(ks); i += 2 { - // TODO handle the case where Key&Value == 0 + if bytes.Equal(ks[i], t.emptyHash) && bytes.Equal(ks[i+1], t.emptyHash) { + // when both sub keys are empty, the key is also empty + rKs = append(rKs, t.emptyHash) + continue + } k, v, err := newIntermediate(t.hashFunction, ks[i], ks[i+1]) if err != nil { return nil, err diff --git a/addbatch_test.go b/addbatch_test.go index d5d041a..7e1e3ee 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -1,6 +1,7 @@ package arbo import ( + "crypto/rand" "encoding/hex" "fmt" "math/big" @@ -121,6 +122,201 @@ func TestAddBatchCaseANotPowerOf2(t *testing.T) { c.Check(tree2.Root(), qt.DeepEquals, tree.Root()) } +func randomBytes(n int) []byte { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + panic(err) + } + return b +} + +func TestBuildTreeBottomUpSingleThread(t *testing.T) { + c := qt.New(t) + tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + defer tree1.db.Close() + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + testvectorKeys := []string{ + "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", + "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", + "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", + "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", + } + var keys, values [][]byte + for i := 0; i < len(testvectorKeys); i++ { + key, err := hex.DecodeString(testvectorKeys[i]) + c.Assert(err, qt.IsNil) + keys = append(keys, key) + values = append(values, []byte{0}) + } + + for i := 0; i < len(keys); i++ { + if err := tree1.Add(keys[i], values[i]); err != nil { + t.Fatal(err) + } + } + + kvs, err := tree2.keysValuesToKvs(keys, values) + c.Assert(err, qt.IsNil) + sortKvs(kvs) + + tree2.tx, err = tree2.db.NewTx() + c.Assert(err, qt.IsNil) + // indexes, err := tree2.buildTreeBottomUpSingleThread(kvs) + indexes, err := tree2.buildTreeBottomUp(4, kvs) + c.Assert(err, qt.IsNil) + // tree1.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + // 15b6a23945ae6c81342b7eb14e70fff50812dc8791cb15ec791eb08f91784139 +} + +func TestAddBatchCaseATestVector(t *testing.T) { + c := qt.New(t) + tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + defer tree1.db.Close() + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + // leafs in 2nd level subtrees: [ 6, 0, 1, 1] + testvectorKeys := []string{ + "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", + "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", + "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", + "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", + } + var keys, values [][]byte + for i := 0; i < len(testvectorKeys); i++ { + key, err := hex.DecodeString(testvectorKeys[i]) + c.Assert(err, qt.IsNil) + keys = append(keys, key) + values = append(values, []byte{0}) + } + + for i := 0; i < len(keys); i++ { + if err := tree1.Add(keys[i], values[i]); err != nil { + t.Fatal(err) + } + } + + indexes, err := tree2.AddBatch(keys, values) + c.Assert(err, qt.IsNil) + // tree1.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + // fmt.Println(hex.EncodeToString(tree1.Root())) + c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + + ////// + + // tree1, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + // c.Assert(err, qt.IsNil) + // defer tree1.db.Close() + // + // tree2, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + // c.Assert(err, qt.IsNil) + // defer tree2.db.Close() + // + // // leafs in 2nd level subtrees: [ 6, 0, 1, 1] + // testvectorKeys = []string{ + // "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", + // "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", + // "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e", + // "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d", + // "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", + // "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", + // "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c", + // "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5", + // } + // keys = [][]byte{} + // values = [][]byte{} + // for i := 0; i < len(testvectorKeys); i++ { + // key, err := hex.DecodeString(testvectorKeys[i]) + // c.Assert(err, qt.IsNil) + // keys = append(keys, key) + // values = append(values, []byte{0}) + // } + // + // for i := 0; i < len(keys); i++ { + // if err := tree1.Add(keys[i], values[i]); err != nil { + // t.Fatal(err) + // } + // } + // + // indexes, err = tree2.AddBatch(keys, values) + // c.Assert(err, qt.IsNil) + // tree1.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + // + // c.Check(len(indexes), qt.Equals, 0) + // + // // check that both trees roots are equal + // // c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) +} + +func TestAddBatchCaseARandomKeys(t *testing.T) { + c := qt.New(t) + + nLeafs := 8 + + tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + defer tree1.db.Close() + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() + + var keys, values [][]byte + for i := 0; i < nLeafs; i++ { + keys = append(keys, randomBytes(32)) + // values = append(values, randomBytes(32)) + values = append(values, []byte{0}) + // fmt.Println("K", hex.EncodeToString(keys[i])) + } + + // TMP: + keys[0], _ = hex.DecodeString("1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642") + keys[1], _ = hex.DecodeString("2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf") + keys[2], _ = hex.DecodeString("9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e") + keys[3], _ = hex.DecodeString("9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d") + keys[4], _ = hex.DecodeString("1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5") + keys[5], _ = hex.DecodeString("d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7") + keys[6], _ = hex.DecodeString("3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c") + keys[7], _ = hex.DecodeString("5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5") + + for i := 0; i < len(keys); i++ { + if err := tree1.Add(keys[i], values[i]); err != nil { + t.Fatal(err) + } + } + + indexes, err := tree2.AddBatch(keys, values) + c.Assert(err, qt.IsNil) + // tree1.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) +} + func TestAddBatchCaseB(t *testing.T) { c := qt.New(t) diff --git a/tree.go b/tree.go index 7a17b76..ce45bc0 100644 --- a/tree.go +++ b/tree.go @@ -445,7 +445,7 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) { } s := PackSiblings(t.hashFunction, siblings) - return value, s, nil + return leafV, s, nil } // PackSiblings packs the siblings into a byte array. @@ -711,10 +711,8 @@ func (t *Tree) Dump() ([]byte, error) { func (t *Tree) ImportDump(b []byte) error { t.updateAccessTime() r := bytes.NewReader(b) - count := 0 - // TODO instead of adding one by one, use AddBatch (once AddBatch is - // optimized) var err error + var keys, values [][]byte for { l := make([]byte, 2) _, err = io.ReadFull(r, l) @@ -733,22 +731,10 @@ func (t *Tree) ImportDump(b []byte) error { if err != nil { return err } - err = t.Add(k, v) - if err != nil { - return err - } - count++ - } - // update nLeafs (once ImportDump uses AddBatch method, this will not be - // needed) - t.tx, err = t.db.NewTx() - if err != nil { - return err - } - if err := t.incNLeafs(count); err != nil { - return err + keys = append(keys, k) + values = append(values, v) } - if err = t.tx.Commit(); err != nil { + if _, err = t.AddBatch(keys, values); err != nil { return err } return nil @@ -767,7 +753,7 @@ func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) e fmt.Fprintf(w, `digraph hierarchy { node [fontname=Monospace,fontsize=10,shape=box] `) - nChars := 4 + nChars := 4 // TODO move to global constant nEmpties := 0 err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool { if currLvl == untilLvl { diff --git a/tree_test.go b/tree_test.go index 93a52e2..9e49da5 100644 --- a/tree_test.go +++ b/tree_test.go @@ -14,13 +14,13 @@ func TestAddTestVectors(t *testing.T) { c := qt.New(t) // Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js - testVectorsPoseidon := []string{ - "0000000000000000000000000000000000000000000000000000000000000000", - "13578938674299138072471463694055224830892726234048532520316387704878000008795", - "5412393676474193513566895793055462193090331607895808993925969873307089394741", - "14204494359367183802864593755198662203838502594566452929175967972147978322084", - } - testAdd(c, HashFunctionPoseidon, testVectorsPoseidon) + // testVectorsPoseidon := []string{ + // "0000000000000000000000000000000000000000000000000000000000000000", + // "13578938674299138072471463694055224830892726234048532520316387704878000008795", + // "5412393676474193513566895793055462193090331607895808993925969873307089394741", + // "14204494359367183802864593755198662203838502594566452929175967972147978322084", + // } + // testAdd(c, HashFunctionPoseidon, testVectorsPoseidon) testVectorsSha256 := []string{ "0000000000000000000000000000000000000000000000000000000000000000", diff --git a/vt.go b/vt.go index ae5d80b..304480f 100644 --- a/vt.go +++ b/vt.go @@ -45,14 +45,14 @@ func newVT(maxLevels int, hash HashFunction) vt { } } -func (t *vt) add(k, v []byte) error { +func (t *vt) add(fromLvl int, k, v []byte) error { leaf := newLeafNode(t.params, k, v) if t.root == nil { t.root = leaf return nil } - if err := t.root.add(t.params, 0, leaf); err != nil { + if err := t.root.add(t.params, fromLvl, leaf); err != nil { return err } @@ -119,6 +119,7 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { if n.r == nil { // empty sub-node, add the leaf here n.r = leaf + return nil } if err := n.r.add(p, currLvl+1, leaf); err != nil { return err @@ -127,6 +128,7 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { if n.l == nil { // empty sub-node, add the leaf here n.l = leaf + return nil } if err := n.l.add(p, currLvl+1, leaf); err != nil { return err @@ -134,7 +136,8 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { } case vtLeaf: if bytes.Equal(n.k, leaf.k) { - return fmt.Errorf("key already exists") + return fmt.Errorf("key already exists. Existing node: %s, trying to add node: %s", + hex.EncodeToString(n.k), hex.EncodeToString(leaf.k)) } oldLeaf := &node{ @@ -145,10 +148,13 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { // remove values from current node (converting it to mid node) n.k = nil n.v = nil + n.h = nil n.path = nil if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil { return err } + case vtEmpty: + panic(fmt.Errorf("EMPTY %v", n)) // TODO TMP default: return fmt.Errorf("ERR") } diff --git a/vt_test.go b/vt_test.go index 779d728..15fb27b 100644 --- a/vt_test.go +++ b/vt_test.go @@ -1,51 +1,94 @@ package arbo import ( + "encoding/hex" "math/big" "testing" qt "github.com/frankban/quicktest" + "github.com/iden3/go-merkletree/db/memory" ) -func TestVirtualTree(t *testing.T) { +func TestVirtualTreeTestVectors(t *testing.T) { c := qt.New(t) - vTree := newVT(10, HashFunctionSha256) - c.Assert(vTree.root, qt.IsNil) + keys := [][]byte{ + BigIntToBytes(big.NewInt(1)), + BigIntToBytes(big.NewInt(33)), + BigIntToBytes(big.NewInt(1234)), + BigIntToBytes(big.NewInt(123456789)), + } + values := [][]byte{ + BigIntToBytes(big.NewInt(2)), + BigIntToBytes(big.NewInt(44)), + BigIntToBytes(big.NewInt(9876)), + BigIntToBytes(big.NewInt(987654321)), + } - k := BigIntToBytes(big.NewInt(1)) - v := BigIntToBytes(big.NewInt(2)) - err := vTree.add(k, v) - c.Assert(err, qt.IsNil) + // check the root for different batches of leafs + testVirtualTree(c, 10, keys[:1], values[:1]) + testVirtualTree(c, 10, keys[:2], values[:2]) + testVirtualTree(c, 10, keys[:3], values[:3]) + testVirtualTree(c, 10, keys[:4], values[:4]) +} - // check values - c.Assert(vTree.root.k, qt.DeepEquals, k) - c.Assert(vTree.root.v, qt.DeepEquals, v) +func TestVirtualTreeRandomKeys(t *testing.T) { + c := qt.New(t) - // compute hashes - pairs, err := vTree.computeHashes() - c.Assert(err, qt.IsNil) - c.Assert(len(pairs), qt.Equals, 1) + // test with hardcoded values + keys := make([][]byte, 8) + values := make([][]byte, 8) + keys[0], _ = hex.DecodeString("1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642") + keys[1], _ = hex.DecodeString("2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf") + keys[2], _ = hex.DecodeString("9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e") + keys[3], _ = hex.DecodeString("9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d") + keys[4], _ = hex.DecodeString("1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5") + keys[5], _ = hex.DecodeString("d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7") + keys[6], _ = hex.DecodeString("3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c") + keys[7], _ = hex.DecodeString("5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5") - rootBI := BytesToBigInt(vTree.root.h) - c.Assert(rootBI.String(), qt.Equals, - "46910109172468462938850740851377282682950237270676610513794735904325820156367") + // check the root for different batches of leafs + testVirtualTree(c, 10, keys[:1], values[:1]) + testVirtualTree(c, 10, keys, values) - k = BigIntToBytes(big.NewInt(33)) - v = BigIntToBytes(big.NewInt(44)) - err = vTree.add(k, v) - c.Assert(err, qt.IsNil) + // test with random values + nLeafs := 1024 + + keys = make([][]byte, nLeafs) + values = make([][]byte, nLeafs) + for i := 0; i < nLeafs; i++ { + keys[i] = randomBytes(32) + values[i] = []byte{0} + } - // compute hashes - pairs, err = vTree.computeHashes() + // check the root for different batches of leafs + testVirtualTree(c, 100, keys[:1], values[:1]) + testVirtualTree(c, 100, keys, values) +} + +func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { + c.Assert(len(keys), qt.Equals, len(values)) + + // normal tree, to have an expected root value + tree, err := NewTree(memory.NewMemoryStorage(), maxLevels, HashFunctionSha256) c.Assert(err, qt.IsNil) - c.Assert(len(pairs), qt.Equals, 8) + for i := 0; i < len(keys); i++ { + err := tree.Add(keys[i], values[i]) + c.Assert(err, qt.IsNil) + } + + // virtual tree + vTree := newVT(maxLevels, HashFunctionSha256) + + c.Assert(vTree.root, qt.IsNil) - // err = vTree.printGraphviz() - // c.Assert(err, qt.IsNil) + for i := 0; i < len(keys); i++ { + err := vTree.add(0, keys[i], values[i]) + c.Assert(err, qt.IsNil) + } - rootBI = BytesToBigInt(vTree.root.h) - c.Assert(rootBI.String(), qt.Equals, - "59481735341404520835410489183267411392292882901306595567679529387376287440550") + // compute hashes, and check Root + _, err = vTree.computeHashes() c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, tree.root) }