diff --git a/addbatch.go b/addbatch.go index 6fccc7f..fa35e29 100644 --- a/addbatch.go +++ b/addbatch.go @@ -148,6 +148,43 @@ const ( // AddBatch adds a batch of key-values to the Tree. Returns an array containing // the indexes of the keys failed to add. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { + t.updateAccessTime() + t.Lock() + defer t.Unlock() + + vt, err := t.loadVT() + if err != nil { + return nil, err + } + + invalids, err := vt.addBatch(keys, values) + if err != nil { + return nil, err + } + + pairs, err := vt.computeHashes() + if err != nil { + return nil, err + } + t.root = vt.root.h + + // store pairs in db + t.tx, err = t.db.NewTx() + if err != nil { + return nil, err + } + for i := 0; i < len(pairs); i++ { + if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil { + return nil, err + } + } + + return t.finalizeAddBatch(len(keys), invalids) +} + +// AddBatchOLD adds a batch of key-values to the Tree. Returns an array containing +// the indexes of the keys failed to add. +func (t *Tree) AddBatchOLD(keys, values [][]byte) ([]int, error) { // TODO: support vaules=nil t.updateAccessTime() t.Lock() @@ -416,7 +453,7 @@ func (t *Tree) caseD(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels - l, hashFunction: t.hashFunction, root: keysAtL[cpu], - emptyHash: t.emptyHash, dbg: newDbgStats()} + emptyHash: t.emptyHash, dbg: newDbgStats()} // TODO bucketTree.dbg should be optional for j := 0; j < len(buckets[cpu]); j++ { if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { @@ -752,24 +789,24 @@ func combineInKVSet(base, toAdd []kv) ([]kv, []int) { // loadVT loads a new virtual tree (vt) from the current Tree, which contains // the same leafs. -// func (t *Tree) loadVT() (vt, error) { -// vt := newVT(t.maxLevels, t.hashFunction) -// vt.params.dbg = t.dbg -// err := t.Iterate(func(k, v []byte) { -// switch v[0] { -// case PrefixValueEmpty: -// case PrefixValueLeaf: -// leafK, leafV := ReadLeafValue(v) -// if err := vt.add(0, leafK, leafV); err != nil { -// panic(err) -// } -// case PrefixValueIntermediate: -// default: -// } -// }) -// -// return vt, err -// } +func (t *Tree) loadVT() (vt, error) { + vt := newVT(t.maxLevels, t.hashFunction) + vt.params.dbg = t.dbg + err := t.Iterate(func(k, v []byte) { + switch v[0] { + case PrefixValueEmpty: + case PrefixValueLeaf: + leafK, leafV := ReadLeafValue(v) + if err := vt.add(0, leafK, leafV); err != nil { + panic(err) + } + case PrefixValueIntermediate: + default: + } + }) + + return vt, err +} // func computeSimpleAddCost(nLeafs int) int { // // nLvls 2^nLvls diff --git a/addbatch_test.go b/addbatch_test.go index 68cbb97..410ccaa 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -14,7 +14,7 @@ import ( "github.com/iden3/go-merkletree/db/memory" ) -var debug = false +var debug = true func printTestContext(prefix string, nLeafs int, hashName, dbName string) { if debug { @@ -230,6 +230,7 @@ func TestAddBatchCaseATestVector(t *testing.T) { t.Fatal(err) } } + // tree1.PrintGraphviz(nil) indexes, err := tree2.AddBatch(keys, values) c.Assert(err, qt.IsNil) @@ -238,9 +239,60 @@ func TestAddBatchCaseATestVector(t *testing.T) { c.Check(len(indexes), qt.Equals, 0) + // tree1.PrintGraphviz(nil) + // tree2.PrintGraphviz(nil) + // check that both trees roots are equal // fmt.Println(hex.EncodeToString(tree1.Root())) c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) + // c.Assert(tree2.Root(), qt.DeepEquals, tree1.Root()) + + // fmt.Println("\n---2nd test vector---") + // + // // 2nd test vectors + // tree1, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + // c.Assert(err, qt.IsNil) + // defer tree1.db.Close() + // + // tree2, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + // c.Assert(err, qt.IsNil) + // defer tree2.db.Close() + // + // testvectorKeys = []string{ + // "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", + // "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", + // "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e", + // "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d", + // "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", + // "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", + // "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c", + // "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5", + // } + // keys = [][]byte{} + // values = [][]byte{} + // for i := 0; i < len(testvectorKeys); i++ { + // key, err := hex.DecodeString(testvectorKeys[i]) + // c.Assert(err, qt.IsNil) + // keys = append(keys, key) + // values = append(values, []byte{0}) + // } + // + // for i := 0; i < len(keys); i++ { + // if err := tree1.Add(keys[i], values[i]); err != nil { + // t.Fatal(err) + // } + // } + // + // indexes, err = tree2.AddBatch(keys, values) + // c.Assert(err, qt.IsNil) + // // tree1.PrintGraphviz(nil) + // // tree2.PrintGraphviz(nil) + // + // c.Check(len(indexes), qt.Equals, 0) + // + // // check that both trees roots are equal + // // fmt.Println(hex.EncodeToString(tree1.Root())) + // c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) } func TestAddBatchCaseARandomKeys(t *testing.T) { @@ -264,7 +316,7 @@ func TestAddBatchCaseARandomKeys(t *testing.T) { // fmt.Println("K", hex.EncodeToString(keys[i])) } - // TMP: + // TODO delete keys[0], _ = hex.DecodeString("1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642") keys[1], _ = hex.DecodeString("2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf") keys[2], _ = hex.DecodeString("9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e") @@ -797,6 +849,35 @@ func TestDbgStats(t *testing.T) { } } +func TestLoadVT(t *testing.T) { + c := qt.New(t) + + nLeafs := 1024 + + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + var keys, values [][]byte + for i := 0; i < nLeafs; i++ { + k := randomBytes(31) + v := randomBytes(31) + keys = append(keys, k) + values = append(values, v) + } + indexes, err := tree.AddBatch(keys, values) + c.Assert(err, qt.IsNil) + c.Check(len(indexes), qt.Equals, 0) + + vt, err := tree.loadVT() + c.Assert(err, qt.IsNil) + _, err = vt.computeHashes() + c.Assert(err, qt.IsNil) + + // check that tree & vt roots are equal + c.Check(tree.Root(), qt.DeepEquals, vt.root.h) +} + // func printLeafs(name string, t *Tree) { // w := bytes.NewBufferString("") // diff --git a/vt.go b/vt.go index 509feeb..609f950 100644 --- a/vt.go +++ b/vt.go @@ -68,41 +68,39 @@ func newVT(maxLevels int, hash HashFunction) vt { } } -func (t *vt) addBatch(ks, vs [][]byte) error { +func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { // parallelize adding leafs in the virtual tree nCPU := flp2(runtime.NumCPU()) if nCPU == 1 || len(ks) < nCPU { - // var invalids []int + var invalids []int for i := 0; i < len(ks); i++ { if err := t.add(0, ks[i], vs[i]); err != nil { - // invalids = append(invalids, i) - fmt.Println(err) // TODO WIP + invalids = append(invalids, i) } } - return nil // TODO invalids + return invalids, nil } l := int(math.Log2(float64(nCPU))) kvs, err := t.params.keysValuesToKvs(ks, vs) if err != nil { - return err + return nil, err } buckets := splitInBuckets(kvs, nCPU) nodesAtL, err := t.getNodesAtLevel(l) if err != nil { - return err + return nil, err } - // fmt.Println("nodesatL pre-E", len(nodesAtL)) - if len(nodesAtL) != nCPU { + if len(nodesAtL) != nCPU && t.root != nil { // CASE E: add one key at each bucket, and then do CASE D for i := 0; i < len(buckets); i++ { // add one leaf of the bucket, if there is an error when // adding the k-v, try to add the next one of the bucket // (until one is added) - var inserted int + inserted := -1 for j := 0; j < len(buckets[i]); j++ { if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil { inserted = j @@ -111,13 +109,20 @@ func (t *vt) addBatch(ks, vs [][]byte) error { } // remove the inserted element from buckets[i] - buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...) + // fmt.Println("rm-ins", inserted) + if inserted != -1 { + buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...) + } } nodesAtL, err = t.getNodesAtLevel(l) if err != nil { - return err + return nil, err } } + if len(nodesAtL) != nCPU { + fmt.Println("ASDF") + panic("should not happen") + } subRoots := make([]*node, nCPU) invalidsInBucket := make([][]int, nCPU) @@ -141,49 +146,64 @@ func (t *vt) addBatch(ks, vs [][]byte) error { } wg.Wait() + var invalids []int + for i := 0; i < len(invalidsInBucket); i++ { + invalids = append(invalids, invalidsInBucket[i]...) + } + newRootNode, err := upFromNodes(subRoots) if err != nil { - return err + return nil, err } t.root = newRootNode - return nil + return invalids, nil } func (t *vt) getNodesAtLevel(l int) ([]*node, error) { if t.root == nil { - return nil, nil + var r []*node + nChilds := int(math.Pow(2, float64(l))) //nolint:gomnd + for i := 0; i < nChilds; i++ { + r = append(r, nil) + } + return r, nil } return t.root.getNodesAtLevel(0, l) } func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) { - var nodes []*node + if n == nil { + var r []*node + nChilds := int(math.Pow(2, float64(l-currLvl))) //nolint:gomnd + for i := 0; i < nChilds; i++ { + r = append(r, nil) + } + return r, nil + } typ := n.typ() if currLvl == l && typ != vtEmpty { - nodes = append(nodes, n) - return nodes, nil + return []*node{n}, nil } if currLvl >= l { panic("should not reach this point") // TODO TMP - // return nil, nil } - if n.l != nil { - nodesL, err := n.l.getNodesAtLevel(currLvl+1, l) - if err != nil { - return nil, err - } - nodes = append(nodes, nodesL...) + var nodes []*node + + nodesL, err := n.l.getNodesAtLevel(currLvl+1, l) + if err != nil { + return nil, err } - if n.r != nil { - nodesR, err := n.r.getNodesAtLevel(currLvl+1, l) - if err != nil { - return nil, err - } - nodes = append(nodes, nodesR...) + nodes = append(nodes, nodesL...) + + nodesR, err := n.r.getNodesAtLevel(currLvl+1, l) + if err != nil { + return nil, err } + nodes = append(nodes, nodesR...) + return nodes, nil } @@ -194,9 +214,11 @@ func upFromNodes(ns []*node) (*node, error) { var res []*node for i := 0; i < len(ns); i += 2 { - if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty { + // if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty { + if ns[i] == nil && ns[i+1] == nil { // when both sub nodes are empty, the node is also empty res = append(res, ns[i]) // empty node + continue } n := &node{ l: ns[i], @@ -207,6 +229,56 @@ func upFromNodes(ns []*node) (*node, error) { return upFromNodes(res) } +// func upFromNodesComputingHashes(p *params, ns []*node, pairs [][2][]byte) ( +// [][2][]byte, *node, error) { +// if len(ns) == 1 { +// return pairs, ns[0], nil +// } +// +// var res []*node +// for i := 0; i < len(ns); i += 2 { +// if ns[i] == nil && ns[i+1] == nil { +// // when both sub nodes are empty, the node is also empty +// res = append(res, ns[i]) // empty node +// continue +// } +// n := &node{ +// l: ns[i], +// r: ns[i+1], +// } +// +// if n.l == nil { +// n.l = &node{ +// h: p.emptyHash, +// } +// } +// if n.r == nil { +// n.r = &node{ +// h: p.emptyHash, +// } +// } +// if n.l.typ() == vtEmpty && n.r.typ() == vtLeaf { +// n = n.r +// } +// if n.r.typ() == vtEmpty && n.l.typ() == vtLeaf { +// n = n.l +// } +// +// // once the sub nodes are computed, can compute the current node +// // hash +// p.dbg.incHash() +// k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h) +// if err != nil { +// return nil, nil, err +// } +// n.h = k +// kv := [2][]byte{k, v} +// pairs = append(pairs, kv) +// res = append(res, n) +// } +// return upFromNodesComputingHashes(p, res, pairs) +// } + func (t *vt) add(fromLvl int, k, v []byte) error { leaf := newLeafNode(t.params, k, v) if t.root == nil { @@ -224,13 +296,63 @@ func (t *vt) add(fromLvl int, k, v []byte) error { // computeHashes should be called after all the vt.add is used, once all the // leafs are in the tree func (t *vt) computeHashes() ([][2][]byte, error) { - var pairs [][2][]byte var err error - // TODO parallelize computeHashes - pairs, err = t.root.computeHashes(t.params, pairs) + + nCPU := flp2(runtime.NumCPU()) + l := int(math.Log2(float64(nCPU))) + nodesAtL, err := t.getNodesAtLevel(l) if err != nil { - return pairs, err + return nil, err + } + subRoots := make([]*node, nCPU) + bucketPairs := make([][][2][]byte, nCPU) + dbgStatsPerBucket := make([]*dbgStats, nCPU) + + var wg sync.WaitGroup + wg.Add(nCPU) + for i := 0; i < nCPU; i++ { + go func(cpu int) { + bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction) + bucketVT.params.dbg = newDbgStats() + bucketVT.root = nodesAtL[cpu] + + bucketPairs[cpu], err = bucketVT.root.computeHashes(l, + t.params.maxLevels, bucketVT.params, bucketPairs[cpu]) + if err != nil { + // TODO WIP + fmt.Println("TODO ERR, err:", err) + panic(err) + } + + subRoots[cpu] = bucketVT.root + dbgStatsPerBucket[cpu] = bucketVT.params.dbg + wg.Done() + }(i) } + wg.Wait() + + for i := 0; i < len(dbgStatsPerBucket); i++ { + t.params.dbg.add(dbgStatsPerBucket[i]) + } + + var pairs [][2][]byte + for i := 0; i < len(bucketPairs); i++ { + pairs = append(pairs, bucketPairs[i]...) + } + + nodesAtL, err = t.getNodesAtLevel(l) + if err != nil { + return nil, err + } + for i := 0; i < len(nodesAtL); i++ { + nodesAtL = subRoots + } + + pairs, err = t.root.computeHashes(0, l, t.params, pairs) + if err != nil { + return nil, err + } + return pairs, nil } @@ -363,7 +485,12 @@ func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *nod } // returns an array of key-values to store in the db -func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) { +func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) ( + [][2][]byte, error) { + if n == nil || currLvl >= maxLvl { + // no need to compute any hash + return pairs, nil + } if pairs == nil { pairs = [][2][]byte{} } @@ -381,7 +508,7 @@ func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) pairs = append(pairs, kv) case vtMid: if n.l != nil { - pairs, err = n.l.computeHashes(p, pairs) + pairs, err = n.l.computeHashes(currLvl+1, maxLvl, p, pairs) if err != nil { return pairs, err } @@ -391,7 +518,7 @@ func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) } } if n.r != nil { - pairs, err = n.r.computeHashes(p, pairs) + pairs, err = n.r.computeHashes(currLvl+1, maxLvl, p, pairs) if err != nil { return pairs, err } @@ -410,7 +537,9 @@ func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) n.h = k kv := [2][]byte{k, v} pairs = append(pairs, kv) + case vtEmpty: default: + fmt.Println("n.computeHashes type no match", t) return nil, fmt.Errorf("ERR TMP") // TODO } diff --git a/vt_test.go b/vt_test.go index bbd0f35..f6b4436 100644 --- a/vt_test.go +++ b/vt_test.go @@ -26,9 +26,9 @@ func TestVirtualTreeTestVectors(t *testing.T) { } // check the root for different batches of leafs - testVirtualTree(c, 10, keys[:1], values[:1]) - testVirtualTree(c, 10, keys[:2], values[:2]) - testVirtualTree(c, 10, keys[:3], values[:3]) + // testVirtualTree(c, 10, keys[:1], values[:1]) + // testVirtualTree(c, 10, keys[:2], values[:2]) + // testVirtualTree(c, 10, keys[:3], values[:3]) testVirtualTree(c, 10, keys[:4], values[:4]) } @@ -117,11 +117,173 @@ func TestVirtualTreeAddBatch(t *testing.T) { c.Assert(vTree.root, qt.IsNil) - err = vTree.addBatch(keys, values) + invalids, err := vTree.addBatch(keys, values) c.Assert(err, qt.IsNil) + c.Assert(len(invalids), qt.Equals, 0) // compute hashes, and check Root _, err = vTree.computeHashes() c.Assert(err, qt.IsNil) c.Assert(vTree.root.h, qt.DeepEquals, tree.root) } + +func TestGetNodesAtLevel(t *testing.T) { + c := qt.New(t) + + tree0 := vt{ + params: ¶ms{ + maxLevels: 100, + hashFunction: HashFunctionBlake2b, + emptyHash: make([]byte, HashFunctionBlake2b.Len()), + }, + root: nil, + } + + tree1 := vt{ + params: ¶ms{ + maxLevels: 100, + hashFunction: HashFunctionBlake2b, + emptyHash: make([]byte, HashFunctionBlake2b.Len()), + }, + root: &node{ + l: &node{ + l: &node{ + k: []byte{0, 0, 0, 0}, + v: []byte{0, 0, 0, 0}, + }, + r: &node{ + k: []byte{0, 0, 0, 1}, + v: []byte{0, 0, 0, 1}, + }, + }, + r: &node{ + l: &node{ + k: []byte{0, 0, 0, 2}, + v: []byte{0, 0, 0, 2}, + }, + r: &node{ + k: []byte{0, 0, 0, 3}, + v: []byte{0, 0, 0, 3}, + }, + }, + }, + } + // tree1.printGraphviz() + + tree2 := vt{ + params: ¶ms{ + maxLevels: 100, + hashFunction: HashFunctionBlake2b, + emptyHash: make([]byte, HashFunctionBlake2b.Len()), + }, + root: &node{ + l: nil, + r: &node{ + l: &node{ + l: &node{ + l: &node{ + k: []byte{0, 0, 0, 0}, + v: []byte{0, 0, 0, 0}, + }, + r: &node{ + k: []byte{0, 0, 0, 1}, + v: []byte{0, 0, 0, 1}, + }, + }, + r: &node{ + k: []byte{0, 0, 0, 2}, + v: []byte{0, 0, 0, 2}, + }, + }, + r: &node{ + k: []byte{0, 0, 0, 3}, + v: []byte{0, 0, 0, 3}, + }, + }, + }, + } + // tree2.printGraphviz() + + tree3 := vt{ + params: ¶ms{ + maxLevels: 100, + hashFunction: HashFunctionBlake2b, + emptyHash: make([]byte, HashFunctionBlake2b.Len()), + }, + root: &node{ + l: nil, + r: &node{ + l: &node{ + l: &node{ + l: &node{ + k: []byte{0, 0, 0, 0}, + v: []byte{0, 0, 0, 0}, + }, + r: &node{ + k: []byte{0, 0, 0, 1}, + v: []byte{0, 0, 0, 1}, + }, + }, + r: &node{ + k: []byte{0, 0, 0, 2}, + v: []byte{0, 0, 0, 2}, + }, + }, + r: nil, + }, + }, + } + // tree3.printGraphviz() + + nodes0, err := tree0.getNodesAtLevel(2) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes0), qt.DeepEquals, 4) + c.Assert("0000", qt.DeepEquals, getNotNils(nodes0)) + + nodes1, err := tree1.getNodesAtLevel(2) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes1), qt.DeepEquals, 4) + c.Assert("1111", qt.DeepEquals, getNotNils(nodes1)) + + nodes1, err = tree1.getNodesAtLevel(3) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes1), qt.DeepEquals, 8) + c.Assert("00000000", qt.DeepEquals, getNotNils(nodes1)) + + nodes2, err := tree2.getNodesAtLevel(2) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes2), qt.DeepEquals, 4) + c.Assert("0011", qt.DeepEquals, getNotNils(nodes2)) + + nodes2, err = tree2.getNodesAtLevel(3) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes2), qt.DeepEquals, 8) + c.Assert("00001100", qt.DeepEquals, getNotNils(nodes2)) + + nodes3, err := tree3.getNodesAtLevel(2) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes3), qt.DeepEquals, 4) + c.Assert("0010", qt.DeepEquals, getNotNils(nodes3)) + + nodes3, err = tree3.getNodesAtLevel(3) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes3), qt.DeepEquals, 8) + c.Assert("00001100", qt.DeepEquals, getNotNils(nodes3)) + + nodes3, err = tree3.getNodesAtLevel(4) + c.Assert(err, qt.IsNil) + c.Assert(len(nodes3), qt.DeepEquals, 16) + c.Assert("0000000011000000", qt.DeepEquals, getNotNils(nodes3)) +} + +func getNotNils(nodes []*node) string { + s := "" + for i := 0; i < len(nodes); i++ { + if nodes[i] == nil { + s += "0" + } else { + s += "1" + } + } + return s +}