diff --git a/addbatch.go b/addbatch.go deleted file mode 100644 index fa35e29..0000000 --- a/addbatch.go +++ /dev/null @@ -1,821 +0,0 @@ -package arbo - -import ( - "bytes" - "fmt" - "math" - "runtime" - "sort" - "sync" - - "github.com/iden3/go-merkletree/db" -) - -/* - -AddBatch design -=============== - - -CASE A: Empty Tree --> if tree is empty (root==0) -================================================= -- Build the full tree from bottom to top (from all the leaf to the root) - - -CASE B: ALMOST CASE A, Almost empty Tree --> if Tree has numLeafs < minLeafsThreshold -============================================================================== -- Get the Leafs (key & value) (iterate the tree from the current root getting -the leafs) -- Create a new empty Tree -- Do CASE A for the new Tree, giving the already existing key&values (leafs) -from the original Tree + the new key&values to be added from the AddBatch call - - - R R - / \ / \ - A * / \ - / \ / \ - B C * * - / | / \ - / | / \ - / | / \ - L: A B G D - / \ - / \ - / \ - C * - / \ - / \ - / \ - ... ... (nLeafs < minLeafsThreshold) - - -CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold) -============================================================================== -- Use A, B, G, F as Roots of subtrees -- Do CASE B for each subtree -- Then go from L to the Root - - R - / \ - / \ - / \ - * * - / | / \ - / | / \ - / | / \ -L: A B G D - / \ - / \ - / \ - C * - / \ - / \ - / \ - ... ... (nLeafs >= minLeafsThreshold) - - - -CASE D: Already populated Tree -============================== -- Use A, B, C, D as subtree -- Sort the Keys in Buckets that share the initial part of the path -- For each subtree add there the new leafs - - R - / \ - / \ - / \ - * * - / | / \ - / | / \ - / | / \ -L: A B C D - /\ /\ / \ / \ - ... ... ... ... ... ... - - -CASE E: Already populated Tree Unbalanced -========================================= -- Need to fill M1 and M2, and then will be able to use CASE D - - Search for M1 & M2 in the inputed Keys - - Add M1 & M2 to the Tree - - From here can use CASE D - - R - / \ - / \ - / \ - * * - | \ - | \ - | \ -L: M1 * M2 * (where M1 and M2 are empty) - / | / - / | / - / | / - A * * - / \ | \ - / \ | \ - / \ | \ - B * * C - / \ |\ - ... ... | \ - | \ - D E - - - -Algorithm decision -================== -- if nLeafs==0 (root==0): CASE A -- if nLeafs=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold: CASE C -- else: CASE D & CASE E - - -- Multiple tree.Add calls: O(n log n) - - Used in: cases A, B, C -- Tree from bottom to top: O(log n) - - Used in: cases D, E - -*/ - -const ( - minLeafsThreshold = 100 // nolint:gomnd // TMP WIP this will be autocalculated -) - -// AddBatch adds a batch of key-values to the Tree. Returns an array containing -// the indexes of the keys failed to add. -func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { - t.updateAccessTime() - t.Lock() - defer t.Unlock() - - vt, err := t.loadVT() - if err != nil { - return nil, err - } - - invalids, err := vt.addBatch(keys, values) - if err != nil { - return nil, err - } - - pairs, err := vt.computeHashes() - if err != nil { - return nil, err - } - t.root = vt.root.h - - // store pairs in db - t.tx, err = t.db.NewTx() - if err != nil { - return nil, err - } - for i := 0; i < len(pairs); i++ { - if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil { - return nil, err - } - } - - return t.finalizeAddBatch(len(keys), invalids) -} - -// AddBatchOLD adds a batch of key-values to the Tree. Returns an array containing -// the indexes of the keys failed to add. -func (t *Tree) AddBatchOLD(keys, values [][]byte) ([]int, error) { - // TODO: support vaules=nil - t.updateAccessTime() - t.Lock() - defer t.Unlock() - - kvs, err := t.keysValuesToKvs(keys, values) - if err != nil { - return nil, err - } - - t.tx, err = t.db.NewTx() - if err != nil { - return nil, err - } - - // if nCPU is not a power of two, cut at the highest power of two under - // nCPU - nCPU := flp2(runtime.NumCPU()) - l := int(math.Log2(float64(nCPU))) - var invalids []int - - // CASE A: if nLeafs==0 (root==0) - if bytes.Equal(t.root, t.emptyHash) { - invalids, err = t.caseA(nCPU, kvs) - if err != nil { - return nil, err - } - - return t.finalizeAddBatch(len(keys), invalids) - } - - // CASE B: if nLeafs=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold - // available parallelization, will need to be a power of 2 (2**n) - if nLeafs >= minLeafsThreshold && - (nLeafs/nCPU) < minLeafsThreshold && - len(keysAtL) == nCPU { - invalids, err = t.caseC(nCPU, l, keysAtL, kvs) - if err != nil { - return nil, err - } - - return t.finalizeAddBatch(len(keys), invalids) - } - - // CASE E - if len(keysAtL) != nCPU { - // CASE E: add one key at each bucket, and then do CASE D - buckets := splitInBuckets(kvs, nCPU) - kvs = []kv{} - for i := 0; i < len(buckets); i++ { - // add one leaf of the bucket, if there is an error when - // adding the k-v, try to add the next one of the bucket - // (until one is added) - var inserted int - for j := 0; j < len(buckets[i]); j++ { - if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil { - inserted = j - break - } - } - - // put the buckets elements except the inserted one - kvs = append(kvs, buckets[i][:inserted]...) - kvs = append(kvs, buckets[i][inserted+1:]...) - } - keysAtL, err = t.getKeysAtLevel(l + 1) - if err != nil { - return nil, err - } - } - - // CASE D - if len(keysAtL) == nCPU { // enter in CASE D if len(keysAtL)=nCPU, if not, CASE E - invalidsCaseD, err := t.caseD(nCPU, l, keysAtL, kvs) - if err != nil { - return nil, err - } - invalids = append(invalids, invalidsCaseD...) - - return t.finalizeAddBatch(len(keys), invalids) - } - - return nil, fmt.Errorf("UNIMPLEMENTED") -} - -func (t *Tree) finalizeAddBatch(nKeys int, invalids []int) ([]int, error) { - // store root to db - if err := t.dbPut(dbKeyRoot, t.root); err != nil { - return nil, err - } - - // update nLeafs - if err := t.incNLeafs(nKeys - len(invalids)); err != nil { - return nil, err - } - - // commit db tx - if err := t.tx.Commit(); err != nil { - return nil, err - } - return invalids, nil -} - -func (t *Tree) caseA(nCPU int, kvs []kv) ([]int, error) { - invalids, err := t.buildTreeFromLeafs(nCPU, kvs) - if err != nil { - return nil, err - } - return invalids, nil -} - -func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, error) { - // get already existing keys - aKs, aVs, err := t.getLeafs(t.root) - if err != nil { - return nil, err - } - aKvs, err := t.keysValuesToKvs(aKs, aVs) - if err != nil { - return nil, err - } - // add already existing key-values to the inputted key-values - // kvs = append(kvs, aKvs...) - kvs, invalids := combineInKVSet(aKvs, kvs) - - // proceed with CASE A - sortKvs(kvs) - - var invalids2 []int - if nCPU > 1 { - invalids2, err = t.buildTreeFromLeafs(nCPU, kvs) - if err != nil { - return nil, err - } - } else { - invalids2, err = t.buildTreeFromLeafsSingleThread(l, kvs) - if err != nil { - return nil, err - } - } - invalids = append(invalids, invalids2...) - - return invalids, nil -} - -func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { - // 1. go down until level L (L=log2(nBuckets)): keysAtL - - var excedents []kv - buckets := splitInBuckets(kvs, nCPU) - - // 2. use keys at level L as roots of the subtrees under each one - subRoots := make([][]byte, nCPU) - dbgStatsPerBucket := make([]*dbgStats, nCPU) - txs := make([]db.Tx, nCPU) - var wg sync.WaitGroup - wg.Add(nCPU) - for i := 0; i < nCPU; i++ { - go func(cpu int) { - var err error - txs[cpu], err = t.db.NewTx() - if err != nil { - panic(err) // TODO WIP - } - if err := txs[cpu].Add(t.tx); err != nil { - panic(err) // TODO - } - bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, - hashFunction: t.hashFunction, root: keysAtL[cpu], - emptyHash: t.emptyHash, dbg: newDbgStats()} - - // 3. do CASE B (with 1 cpu) for each key at level L - _, err = bucketTree.caseB(1, l, buckets[cpu]) // TODO handle invalids - if err != nil { - panic(err) // TODO WIP - // return nil, err - } - subRoots[cpu] = bucketTree.root - dbgStatsPerBucket[cpu] = bucketTree.dbg - wg.Done() - }(i) - } - wg.Wait() - - // merge buckets txs into Tree.tx - for i := 0; i < len(txs); i++ { - if err := t.tx.Add(txs[i]); err != nil { - return nil, err - } - } - - // 4. go upFromKeys from the new roots of the subtrees - newRoot, err := t.upFromKeys(subRoots) - if err != nil { - return nil, err - } - t.root = newRoot - - // add the key-values that have not been used yet - var invalids []int - for i := 0; i < len(excedents); i++ { - if err = t.add(0, excedents[i].k, excedents[i].v); err != nil { - invalids = append(invalids, excedents[i].pos) - } - } - - for i := 0; i < len(dbgStatsPerBucket); i++ { - t.dbg.add(dbgStatsPerBucket[i]) - } - - return invalids, nil -} - -func (t *Tree) caseD(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { - if nCPU == 1 { // CASE D, but with 1 cpu - var invalids []int - for i := 0; i < len(kvs); i++ { - if err := t.add(0, kvs[i].k, kvs[i].v); err != nil { - invalids = append(invalids, kvs[i].pos) - } - } - return invalids, nil - } - - buckets := splitInBuckets(kvs, nCPU) - - subRoots := make([][]byte, nCPU) - invalidsInBucket := make([][]int, nCPU) - dbgStatsPerBucket := make([]*dbgStats, nCPU) - txs := make([]db.Tx, nCPU) - - var wg sync.WaitGroup - wg.Add(nCPU) - for i := 0; i < nCPU; i++ { - go func(cpu int) { - var err error - txs[cpu], err = t.db.NewTx() - if err != nil { - panic(err) // TODO WIP - } - // put already existing tx into txs[cpu], as txs[cpu] - // needs the pending key-values that are not in tree.db, - // but are in tree.tx - if err := txs[cpu].Add(t.tx); err != nil { - panic(err) // TODO WIP - } - - bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels - l, - hashFunction: t.hashFunction, root: keysAtL[cpu], - emptyHash: t.emptyHash, dbg: newDbgStats()} // TODO bucketTree.dbg should be optional - - for j := 0; j < len(buckets[cpu]); j++ { - if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { - invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos) - } - } - subRoots[cpu] = bucketTree.root - dbgStatsPerBucket[cpu] = bucketTree.dbg - wg.Done() - }(i) - } - wg.Wait() - - // merge buckets txs into Tree.tx - for i := 0; i < len(txs); i++ { - if err := t.tx.Add(txs[i]); err != nil { - return nil, err - } - } - - newRoot, err := t.upFromKeys(subRoots) - if err != nil { - return nil, err - } - t.root = newRoot - - var invalids []int - for i := 0; i < len(invalidsInBucket); i++ { - invalids = append(invalids, invalidsInBucket[i]...) - } - - for i := 0; i < len(dbgStatsPerBucket); i++ { - t.dbg.add(dbgStatsPerBucket[i]) - } - - return invalids, nil -} - -func splitInBuckets(kvs []kv, nBuckets int) [][]kv { - buckets := make([][]kv, nBuckets) - // 1. classify the keyvalues into buckets - for i := 0; i < len(kvs); i++ { - pair := kvs[i] - - // bucketnum := keyToBucket(pair.k, nBuckets) - bucketnum := keyToBucket(pair.keyPath, nBuckets) - buckets[bucketnum] = append(buckets[bucketnum], pair) - } - return buckets -} - -// TODO rename in a more 'real' name (calculate bucket from/for key) -func keyToBucket(k []byte, nBuckets int) int { - nLevels := int(math.Log2(float64(nBuckets))) - b := make([]int, nBuckets) - for i := 0; i < nBuckets; i++ { - b[i] = i - } - r := b - mid := len(r) / 2 //nolint:gomnd - for i := 0; i < nLevels; i++ { - if int(k[i/8]&(1<<(i%8))) != 0 { - r = r[mid:] - mid = len(r) / 2 //nolint:gomnd - } else { - r = r[:mid] - mid = len(r) / 2 //nolint:gomnd - } - } - return r[0] -} - -type kv struct { - pos int // original position in the array - keyPath []byte - k []byte - v []byte -} - -// compareBytes compares byte slices where the bytes are compared from left to -// right and each byte is compared by bit from right to left -func compareBytes(a, b []byte) bool { - // WIP - for i := 0; i < len(a); i++ { - for j := 0; j < 8; j++ { - aBit := a[i] & (1 << j) - bBit := b[i] & (1 << j) - if aBit > bBit { - return false - } else if aBit < bBit { - return true - } - } - } - return false -} - -// sortKvs sorts the kv by path -func sortKvs(kvs []kv) { - sort.Slice(kvs, func(i, j int) bool { - return compareBytes(kvs[i].keyPath, kvs[j].keyPath) - }) -} - -func (t *Tree) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { - if len(ks) != len(vs) { - return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", - len(ks), len(vs)) - } - kvs := make([]kv, len(ks)) - for i := 0; i < len(ks); i++ { - keyPath := make([]byte, t.hashFunction.Len()) - copy(keyPath[:], ks[i]) - kvs[i].pos = i - kvs[i].keyPath = keyPath - kvs[i].k = ks[i] - kvs[i].v = vs[i] - } - - return kvs, nil -} - -/* -func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) { - ks := make([][]byte, len(kvs)) - vs := make([][]byte, len(kvs)) - for i := 0; i < len(kvs); i++ { - ks[i] = kvs[i].k - vs[i] = kvs[i].v - } - return ks, vs -} -*/ - -// buildTreeFromLeafs splits the key-values into n Buckets (where n is the number -// of CPUs), in parallel builds a subtree for each bucket, once all the subtrees -// are built, uses the subtrees roots as keys for a new tree, which as result -// will have the complete Tree build from bottom to up, where until the -// log2(nCPU) level it has been computed in parallel. -func (t *Tree) buildTreeFromLeafs(nCPU int, kvs []kv) ([]int, error) { - l := int(math.Log2(float64(nCPU))) - buckets := splitInBuckets(kvs, nCPU) - - subRoots := make([][]byte, nCPU) - invalidsInBucket := make([][]int, nCPU) - dbgStatsPerBucket := make([]*dbgStats, nCPU) - txs := make([]db.Tx, nCPU) - - var wg sync.WaitGroup - wg.Add(nCPU) - for i := 0; i < nCPU; i++ { - go func(cpu int) { - sortKvs(buckets[cpu]) - - var err error - txs[cpu], err = t.db.NewTx() - if err != nil { - panic(err) // TODO - } - if err := txs[cpu].Add(t.tx); err != nil { - panic(err) // TODO - } - bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels, - hashFunction: t.hashFunction, root: t.emptyHash, - emptyHash: t.emptyHash, dbg: newDbgStats()} - - currInvalids, err := bucketTree.buildTreeFromLeafsSingleThread(l, buckets[cpu]) - if err != nil { - panic(err) // TODO - } - invalidsInBucket[cpu] = currInvalids - subRoots[cpu] = bucketTree.root - dbgStatsPerBucket[cpu] = bucketTree.dbg - wg.Done() - }(i) - } - wg.Wait() - - // merge buckets txs into Tree.tx - for i := 0; i < len(txs); i++ { - if err := t.tx.Add(txs[i]); err != nil { - return nil, err - } - } - - newRoot, err := t.upFromKeys(subRoots) - if err != nil { - return nil, err - } - t.root = newRoot - - var invalids []int - for i := 0; i < len(invalidsInBucket); i++ { - invalids = append(invalids, invalidsInBucket[i]...) - } - - for i := 0; i < len(dbgStatsPerBucket); i++ { - t.dbg.add(dbgStatsPerBucket[i]) - } - - return invalids, err -} - -// buildTreeFromLeafsSingleThread builds the tree with the given []kv from bottom -// to the root -func (t *Tree) buildTreeFromLeafsSingleThread(l int, kvsRaw []kv) ([]int, error) { - // TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels - // would be reached and should return error - if len(kvsRaw) == 0 { - return nil, nil - } - - vt := newVT(t.maxLevels, t.hashFunction) - if t.dbg != nil { - vt.params.dbg = newDbgStats() - } - - for i := 0; i < len(kvsRaw); i++ { - if err := vt.add(l, kvsRaw[i].k, kvsRaw[i].v); err != nil { - return nil, err - } - } - pairs, err := vt.computeHashes() - if err != nil { - return nil, err - } - - // store pairs in db - for i := 0; i < len(pairs); i++ { - if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil { - return nil, err - } - } - t.dbg.add(vt.params.dbg) - - // set tree.root from the virtual tree root - t.root = vt.root.h - - return nil, nil // TODO invalids -} - -// keys & values must be sorted by path, and the array ks must be length -// multiple of 2 -func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) { - if len(ks) == 1 { - return ks[0], nil - } - - var rKs [][]byte - for i := 0; i < len(ks); i += 2 { - if bytes.Equal(ks[i], t.emptyHash) && bytes.Equal(ks[i+1], t.emptyHash) { - // when both sub keys are empty, the key is also empty - rKs = append(rKs, t.emptyHash) - continue - } - k, v, err := newIntermediate(t.hashFunction, ks[i], ks[i+1]) - if err != nil { - return nil, err - } - // store k-v to db - if err = t.dbPut(k, v); err != nil { - return nil, err - } - rKs = append(rKs, k) - } - return t.upFromKeys(rKs) -} - -func (t *Tree) getLeafs(root []byte) ([][]byte, [][]byte, error) { - var ks, vs [][]byte - err := t.iter(root, func(k, v []byte) { - if v[0] != PrefixValueLeaf { - return - } - leafK, leafV := ReadLeafValue(v) - ks = append(ks, leafK) - vs = append(vs, leafV) - }) - return ks, vs, err -} - -func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) { - var keys [][]byte - err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool { - if currLvl == l && !bytes.Equal(k, t.emptyHash) { - keys = append(keys, k) - } - if currLvl >= l { - return true // to stop the iter from going down - } - return false - }) - - return keys, err -} - -// flp2 computes the floor power of 2, the highest power of 2 under the given -// value. -func flp2(n int) int { - res := 0 - for i := n; i >= 1; i-- { - if (i & (i - 1)) == 0 { - res = i - break - } - } - return res -} - -// combineInKVSet combines two kv array in one single array without repeated -// keys. -func combineInKVSet(base, toAdd []kv) ([]kv, []int) { - // TODO this is a naive version, this will be implemented in a more - // efficient way or through maps, or through sorted binary search - r := base - var invalids []int - for i := 0; i < len(toAdd); i++ { - e := false - // check if toAdd[i] exists in the base set - for j := 0; j < len(base); j++ { - if bytes.Equal(toAdd[i].k, base[j].k) { - e = true - } - } - if !e { - r = append(r, toAdd[i]) - } else { - invalids = append(invalids, toAdd[i].pos) - } - } - return r, invalids -} - -// loadVT loads a new virtual tree (vt) from the current Tree, which contains -// the same leafs. -func (t *Tree) loadVT() (vt, error) { - vt := newVT(t.maxLevels, t.hashFunction) - vt.params.dbg = t.dbg - err := t.Iterate(func(k, v []byte) { - switch v[0] { - case PrefixValueEmpty: - case PrefixValueLeaf: - leafK, leafV := ReadLeafValue(v) - if err := vt.add(0, leafK, leafV); err != nil { - panic(err) - } - case PrefixValueIntermediate: - default: - } - }) - - return vt, err -} - -// func computeSimpleAddCost(nLeafs int) int { -// // nLvls 2^nLvls -// nLvls := int(math.Log2(float64(nLeafs))) -// return nLvls * int(math.Pow(2, float64(nLvls))) -// } -// -// func computeFromLeafsAddCost(nLeafs int) int { -// // 2^nLvls * 2 - 1 -// nLvls := int(math.Log2(float64(nLeafs))) -// return (int(math.Pow(2, float64(nLvls))) * 2) - 1 -// } diff --git a/addbatch_test.go b/addbatch_test.go index 410ccaa..529278b 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math/big" "runtime" + "sort" "testing" "time" @@ -31,7 +32,7 @@ func printRes(name string, duration time.Duration) { func debugTime(descr string, time1, time2 time.Duration) { if debug { - fmt.Printf("%s was %f times faster than without AddBatch\n", + fmt.Printf("%s was %.02fx times faster than without AddBatch\n", descr, float64(time1)/float64(time2)) } } @@ -151,7 +152,7 @@ func randomBytes(n int) []byte { return b } -func TestBuildTreeFromLeafsSingleThread(t *testing.T) { +func TestAddBatchCaseATestVector(t *testing.T) { c := qt.New(t) tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) c.Assert(err, qt.IsNil) @@ -161,6 +162,7 @@ func TestBuildTreeFromLeafsSingleThread(t *testing.T) { c.Assert(err, qt.IsNil) defer tree2.db.Close() + // leafs in 2nd level subtrees: [ 6, 0, 1, 1] testvectorKeys := []string{ "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", @@ -181,43 +183,33 @@ func TestBuildTreeFromLeafsSingleThread(t *testing.T) { } } - kvs, err := tree2.keysValuesToKvs(keys, values) - c.Assert(err, qt.IsNil) - sortKvs(kvs) - - tree2.tx, err = tree2.db.NewTx() - c.Assert(err, qt.IsNil) - // indexes, err := tree2.buildTreeFromLeafsSingleThread(kvs) - indexes, err := tree2.buildTreeFromLeafs(4, kvs) + indexes, err := tree2.AddBatch(keys, values) c.Assert(err, qt.IsNil) - // tree1.PrintGraphviz(nil) - // tree2.PrintGraphviz(nil) - c.Check(len(indexes), qt.Equals, 0) - // check that both trees roots are equal c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) - // 15b6a23945ae6c81342b7eb14e70fff50812dc8791cb15ec791eb08f91784139 -} -func TestAddBatchCaseATestVector(t *testing.T) { - c := qt.New(t) - tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + // 2nd test vectors + tree1, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() - tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) + tree2, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() - // leafs in 2nd level subtrees: [ 6, 0, 1, 1] - testvectorKeys := []string{ + testvectorKeys = []string{ "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", + "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e", + "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d", "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", + "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c", + "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5", } - var keys, values [][]byte + keys = [][]byte{} + values = [][]byte{} for i := 0; i < len(testvectorKeys); i++ { key, err := hex.DecodeString(testvectorKeys[i]) c.Assert(err, qt.IsNil) @@ -230,69 +222,12 @@ func TestAddBatchCaseATestVector(t *testing.T) { t.Fatal(err) } } - // tree1.PrintGraphviz(nil) - indexes, err := tree2.AddBatch(keys, values) + indexes, err = tree2.AddBatch(keys, values) c.Assert(err, qt.IsNil) - // tree1.PrintGraphviz(nil) - // tree2.PrintGraphviz(nil) - c.Check(len(indexes), qt.Equals, 0) - - // tree1.PrintGraphviz(nil) - // tree2.PrintGraphviz(nil) - // check that both trees roots are equal - // fmt.Println(hex.EncodeToString(tree1.Root())) c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) - // c.Assert(tree2.Root(), qt.DeepEquals, tree1.Root()) - - // fmt.Println("\n---2nd test vector---") - // - // // 2nd test vectors - // tree1, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) - // c.Assert(err, qt.IsNil) - // defer tree1.db.Close() - // - // tree2, err = NewTree(memory.NewMemoryStorage(), 100, HashFunctionBlake2b) - // c.Assert(err, qt.IsNil) - // defer tree2.db.Close() - // - // testvectorKeys = []string{ - // "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", - // "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", - // "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e", - // "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d", - // "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", - // "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", - // "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c", - // "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5", - // } - // keys = [][]byte{} - // values = [][]byte{} - // for i := 0; i < len(testvectorKeys); i++ { - // key, err := hex.DecodeString(testvectorKeys[i]) - // c.Assert(err, qt.IsNil) - // keys = append(keys, key) - // values = append(values, []byte{0}) - // } - // - // for i := 0; i < len(keys); i++ { - // if err := tree1.Add(keys[i], values[i]); err != nil { - // t.Fatal(err) - // } - // } - // - // indexes, err = tree2.AddBatch(keys, values) - // c.Assert(err, qt.IsNil) - // // tree1.PrintGraphviz(nil) - // // tree2.PrintGraphviz(nil) - // - // c.Check(len(indexes), qt.Equals, 0) - // - // // check that both trees roots are equal - // // fmt.Println(hex.EncodeToString(tree1.Root())) - // c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) } func TestAddBatchCaseARandomKeys(t *testing.T) { @@ -417,83 +352,6 @@ func TestAddBatchCaseBRepeatedLeafs(t *testing.T) { c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) } -func TestCombineInKVSet(t *testing.T) { - c := qt.New(t) - - var a, b, expected []kv - for i := 0; i < 10; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - kv := kv{k: k} - if i < 7 { - a = append(a, kv) - } - if i >= 4 { - b = append(b, kv) - } - expected = append(expected, kv) - } - - r, invalids := combineInKVSet(a, b) - for i := 0; i < len(r); i++ { - c.Assert(r[i].k, qt.DeepEquals, expected[i].k) - } - c.Assert(len(invalids), qt.Equals, 7-4) -} - -func TestGetKeysAtLevel(t *testing.T) { - c := qt.New(t) - - tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) - c.Assert(err, qt.IsNil) - defer tree.db.Close() - - for i := 0; i < 32; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) - if err := tree.Add(k, v); err != nil { - t.Fatal(err) - } - } - - keys, err := tree.getKeysAtLevel(2) - c.Assert(err, qt.IsNil) - expected := []string{ - "a5d5f14fce7348e40751496cf25d107d91b0bd043435b9577d778a01f8aa6111", - "e9e8dd9b28a7f81d1ff34cb5cefc0146dd848b31031a427b79bdadb62e7f6910", - } - for i := 0; i < len(keys); i++ { - c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i]) - } - - keys, err = tree.getKeysAtLevel(3) - c.Assert(err, qt.IsNil) - expected = []string{ - "9f12c13e52bca96ad4882a26558e48ab67ddd63e062b839207e893d961390f01", - "16d246dd6826ec7346c7328f11c4261facf82d4689f33263ff6e207956a77f21", - "4a22cc901c6337daa17a431fa20170684b710e5f551509511492ec24e81a8f2f", - "470d61abcbd154977bffc9a9ec5a8daff0caabcf2a25e8441f604c79daa0f82d", - } - for i := 0; i < len(keys); i++ { - c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i]) - } - - keys, err = tree.getKeysAtLevel(4) - c.Assert(err, qt.IsNil) - expected = []string{ - "7a5d1c81f7b96318012de3417e53d4f13df5b1337718651cd29d0cb0a66edd20", - "3408213e4e844bdf3355eb8781c74e31626812898c2dbe141ed6d2c92256fc1c", - "dfd8a4d0b6954a3e9f3892e655b58d456eeedf9367f27dfdd9bc2dd6a5577312", - "9e99fbec06fb2a6725997c12c4995f62725eb4cce4808523a5a5e80cca64b007", - "0befa1e070231dbf4e8ff841c05878cdec823e0c09594c24910a248b3ff5a628", - "b7131b0a15c772a57005a4dc5d0d6dd4b3414f5d9ee7408ce5e86c5ab3520e04", - "6d1abe0364077846a56bab1deb1a04883eb796b74fe531a7676a9a370f83ab21", - "4270116394bede69cf9cd72069eca018238080380bef5de75be8dcbbe968e105", - } - for i := 0; i < len(keys); i++ { - c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i]) - } -} - func TestSplitInBuckets(t *testing.T) { c := qt.New(t) @@ -563,11 +421,37 @@ func TestSplitInBuckets(t *testing.T) { sortKvs(buckets[i]) c.Assert(len(buckets[i]), qt.Equals, len(expected[i])) for j := 0; j < len(buckets[i]); j++ { - c.Check(hex.EncodeToString(buckets[i][j].k[:4]), qt.Equals, expected[i][j]) + c.Check(hex.EncodeToString(buckets[i][j].k[:4]), + qt.Equals, expected[i][j]) } } } +// compareBytes compares byte slices where the bytes are compared from left to +// right and each byte is compared by bit from right to left +func compareBytes(a, b []byte) bool { + // WIP + for i := 0; i < len(a); i++ { + for j := 0; j < 8; j++ { + aBit := a[i] & (1 << j) + bBit := b[i] & (1 << j) + if aBit > bBit { + return false + } else if aBit < bBit { + return true + } + } + } + return false +} + +// sortKvs sorts the kv by path +func sortKvs(kvs []kv) { + sort.Slice(kvs, func(i, j int) bool { + return compareBytes(kvs[i].keyPath, kvs[j].keyPath) + }) +} + func TestAddBatchCaseC(t *testing.T) { c := qt.New(t) @@ -878,37 +762,5 @@ func TestLoadVT(t *testing.T) { c.Check(tree.Root(), qt.DeepEquals, vt.root.h) } -// func printLeafs(name string, t *Tree) { -// w := bytes.NewBufferString("") -// -// err := t.Iterate(func(k, v []byte) { -// if v[0] != PrefixValueLeaf { -// return -// } -// leafK, _ := readLeafValue(v) -// fmt.Fprintf(w, hex.EncodeToString(leafK[:4])+"\n") -// }) -// if err != nil { -// panic(err) -// } -// err = ioutil.WriteFile(name, w.Bytes(), 0644) -// if err != nil { -// panic(err) -// } -// -// } - -// func TestComputeCosts(t *testing.T) { -// fmt.Println(computeSimpleAddCost(10)) -// fmt.Println(computeFromLeafsAddCost(10)) -// -// fmt.Println(computeSimpleAddCost(1024)) -// fmt.Println(computeFromLeafsAddCost(1024)) -// } - -// TODO test tree with nLeafs > minLeafsThreshold, but that at level L, there is -// less keys than nBuckets (so CASE C could be applied if first few leafs are -// added to balance the tree) - // TODO test adding batch with repeated keys in the batch // TODO test adding batch with multiple invalid keys diff --git a/tree.go b/tree.go index 24106bd..f5dd91e 100644 --- a/tree.go +++ b/tree.go @@ -116,6 +116,77 @@ func (t *Tree) HashFunction() HashFunction { return t.hashFunction } +// AddBatch adds a batch of key-values to the Tree. Returns an array containing +// the indexes of the keys failed to add. +func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { + t.updateAccessTime() + t.Lock() + defer t.Unlock() + + vt, err := t.loadVT() + if err != nil { + return nil, err + } + + // TODO check that keys & values is valid for Tree.hashFunction + invalids, err := vt.addBatch(keys, values) + if err != nil { + return nil, err + } + + // once the VirtualTree is build, compute the hashes + pairs, err := vt.computeHashes() + if err != nil { + return nil, err + } + t.root = vt.root.h + + // store pairs in db + t.tx, err = t.db.NewTx() + if err != nil { + return nil, err + } + for i := 0; i < len(pairs); i++ { + if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil { + return nil, err + } + } + + // store root to db + if err := t.dbPut(dbKeyRoot, t.root); err != nil { + return nil, err + } + + // update nLeafs + if err := t.incNLeafs(len(keys) - len(invalids)); err != nil { + return nil, err + } + + // commit db tx + if err := t.tx.Commit(); err != nil { + return nil, err + } + return invalids, nil +} + +// loadVT loads a new virtual tree (vt) from the current Tree, which contains +// the same leafs. +func (t *Tree) loadVT() (vt, error) { + vt := newVT(t.maxLevels, t.hashFunction) + vt.params.dbg = t.dbg + err := t.Iterate(func(k, v []byte) { + if v[0] != PrefixValueLeaf { + return + } + leafK, leafV := ReadLeafValue(v) + if err := vt.add(0, leafK, leafV); err != nil { + panic(err) + } + }) + + return vt, err +} + // Add inserts the key-value into the Tree. If the inputs come from a *big.Int, // is expected that are represented by a Little-Endian byte array (for circom // compatibility). diff --git a/vt.go b/vt.go index 609f950..422468c 100644 --- a/vt.go +++ b/vt.go @@ -30,6 +30,13 @@ type params struct { dbg *dbgStats } +type kv struct { + pos int // original position in the inputted array + keyPath []byte + k []byte + v []byte +} + func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { if len(ks) != len(vs) { return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", @@ -68,8 +75,12 @@ func newVT(maxLevels int, hash HashFunction) vt { } } +// addBatch adds a batch of key-values to the VirtualTree. Returns an array +// containing the indexes of the keys failed to add. Does not include the +// computation of hashes of the nodes neither the storage of the key-values of +// the tree into the db. After addBatch, vt.computeHashes should be called to +// compute the hashes of all the nodes of the tree. func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { - // parallelize adding leafs in the virtual tree nCPU := flp2(runtime.NumCPU()) if nCPU == 1 || len(ks) < nCPU { var invalids []int @@ -95,7 +106,37 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { return nil, err } if len(nodesAtL) != nCPU && t.root != nil { - // CASE E: add one key at each bucket, and then do CASE D + /* + Already populated Tree but Unbalanced + - Need to fill M1 and M2, and then will be able to continue with the flow + - Search for M1 & M2 in the inputed Keys + - Add M1 & M2 to the Tree + - From here can continue with the flow + + R + / \ + / \ + / \ + * * + | \ + | \ + | \ + L: M1 * M2 * (where M1 and M2 are empty) + / | / + / | / + / | / + A * * + / \ | \ + / \ | \ + / \ | \ + B * * C + / \ |\ + ... ... | \ + | \ + D E + */ + + // add one key at each bucket, and then continue with the flow for i := 0; i < len(buckets); i++ { // add one leaf of the bucket, if there is an error when // adding the k-v, try to add the next one of the bucket @@ -120,8 +161,7 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { } } if len(nodesAtL) != nCPU { - fmt.Println("ASDF") - panic("should not happen") + panic("should not happen") // TODO TMP } subRoots := make([]*node, nCPU) @@ -131,8 +171,6 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { wg.Add(nCPU) for i := 0; i < nCPU; i++ { go func(cpu int) { - sortKvs(buckets[cpu]) - bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction) bucketVT.root = nodesAtL[cpu] for j := 0; j < len(buckets[cpu]); j++ { @@ -214,8 +252,8 @@ func upFromNodes(ns []*node) (*node, error) { var res []*node for i := 0; i < len(ns); i += 2 { - // if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty { - if ns[i] == nil && ns[i+1] == nil { + if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty { + // if ns[i] == nil && ns[i+1] == nil { // when both sub nodes are empty, the node is also empty res = append(res, ns[i]) // empty node continue @@ -229,56 +267,6 @@ func upFromNodes(ns []*node) (*node, error) { return upFromNodes(res) } -// func upFromNodesComputingHashes(p *params, ns []*node, pairs [][2][]byte) ( -// [][2][]byte, *node, error) { -// if len(ns) == 1 { -// return pairs, ns[0], nil -// } -// -// var res []*node -// for i := 0; i < len(ns); i += 2 { -// if ns[i] == nil && ns[i+1] == nil { -// // when both sub nodes are empty, the node is also empty -// res = append(res, ns[i]) // empty node -// continue -// } -// n := &node{ -// l: ns[i], -// r: ns[i+1], -// } -// -// if n.l == nil { -// n.l = &node{ -// h: p.emptyHash, -// } -// } -// if n.r == nil { -// n.r = &node{ -// h: p.emptyHash, -// } -// } -// if n.l.typ() == vtEmpty && n.r.typ() == vtLeaf { -// n = n.r -// } -// if n.r.typ() == vtEmpty && n.l.typ() == vtLeaf { -// n = n.l -// } -// -// // once the sub nodes are computed, can compute the current node -// // hash -// p.dbg.incHash() -// k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h) -// if err != nil { -// return nil, nil, err -// } -// n.h = k -// kv := [2][]byte{k, v} -// pairs = append(pairs, kv) -// res = append(res, n) -// } -// return upFromNodesComputingHashes(p, res, pairs) -// } - func (t *vt) add(fromLvl int, k, v []byte) error { leaf := newLeafNode(t.params, k, v) if t.root == nil { @@ -320,8 +308,7 @@ func (t *vt) computeHashes() ([][2][]byte, error) { t.params.maxLevels, bucketVT.params, bucketPairs[cpu]) if err != nil { // TODO WIP - fmt.Println("TODO ERR, err:", err) - panic(err) + panic("TODO" + err.Error()) } subRoots[cpu] = bucketVT.root @@ -444,7 +431,7 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { case vtEmpty: panic(fmt.Errorf("EMPTY %v", n)) // TODO TMP default: - return fmt.Errorf("ERR") + return fmt.Errorf("ERR") // TODO TMP } return nil @@ -484,6 +471,53 @@ func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *nod return nil } +func splitInBuckets(kvs []kv, nBuckets int) [][]kv { + buckets := make([][]kv, nBuckets) + // 1. classify the keyvalues into buckets + for i := 0; i < len(kvs); i++ { + pair := kvs[i] + + // bucketnum := keyToBucket(pair.k, nBuckets) + bucketnum := keyToBucket(pair.keyPath, nBuckets) + buckets[bucketnum] = append(buckets[bucketnum], pair) + } + return buckets +} + +// TODO rename in a more 'real' name (calculate bucket from/for key) +func keyToBucket(k []byte, nBuckets int) int { + nLevels := int(math.Log2(float64(nBuckets))) + b := make([]int, nBuckets) + for i := 0; i < nBuckets; i++ { + b[i] = i + } + r := b + mid := len(r) / 2 //nolint:gomnd + for i := 0; i < nLevels; i++ { + if int(k[i/8]&(1<<(i%8))) != 0 { + r = r[mid:] + mid = len(r) / 2 //nolint:gomnd + } else { + r = r[:mid] + mid = len(r) / 2 //nolint:gomnd + } + } + return r[0] +} + +// flp2 computes the floor power of 2, the highest power of 2 under the given +// value. +func flp2(n int) int { + res := 0 + for i := n; i >= 1; i-- { + if (i & (i - 1)) == 0 { + res = i + break + } + } + return res +} + // returns an array of key-values to store in the db func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) ( [][2][]byte, error) { @@ -539,8 +573,7 @@ func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) pairs = append(pairs, kv) case vtEmpty: default: - fmt.Println("n.computeHashes type no match", t) - return nil, fmt.Errorf("ERR TMP") // TODO + return nil, fmt.Errorf("ERR:n.computeHashes type (%d) no match", t) // TODO TMP } return pairs, nil