diff --git a/addbatch.go b/addbatch.go index 73c74ba..23aa96f 100644 --- a/addbatch.go +++ b/addbatch.go @@ -169,7 +169,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { // if nCPU is not a power of two, cut at the highest power of two under // nCPU - nCPU := highestPowerOfTwo(runtime.NumCPU()) + nCPU := flp2(runtime.NumCPU()) l := int(math.Log2(float64(nCPU))) var invalids []int @@ -189,18 +189,10 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { return nil, err } if nLeafs < minLeafsThreshold { // CASE B - var excedents []kv - invalids, excedents, err = t.caseB(nCPU, 0, kvs) + invalids, err = t.caseB(nCPU, 0, kvs) if err != nil { return nil, err } - // add the excedents - for i := 0; i < len(excedents); i++ { - err = t.add(0, excedents[i].k, excedents[i].v) - if err != nil { - invalids = append(invalids, excedents[i].pos) - } - } return t.finalizeAddBatch(len(keys), invalids) } @@ -283,31 +275,22 @@ func (t *Tree) finalizeAddBatch(nKeys int, invalids []int) ([]int, error) { } func (t *Tree) caseA(nCPU int, kvs []kv) ([]int, error) { - // if len(kvs) is not a power of 2, cut at the bigger power - // of two under len(kvs), build the tree with that, and add - // later the excedents - kvsP2, kvsNonP2 := cutPowerOfTwo(kvs) - invalids, err := t.buildTreeBottomUp(nCPU, kvsP2) + invalids, err := t.buildTreeBottomUp(nCPU, kvs) if err != nil { return nil, err } - for i := 0; i < len(kvsNonP2); i++ { - if err = t.add(0, kvsNonP2[i].k, kvsNonP2[i].v); err != nil { - invalids = append(invalids, kvsNonP2[i].pos) - } - } return invalids, nil } -func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) { +func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, error) { // get already existing keys aKs, aVs, err := t.getLeafs(t.root) if err != nil { - return nil, nil, err + return nil, err } aKvs, err := t.keysValuesToKvs(aKs, aVs) if err != nil { - return nil, nil, err + return nil, err } // add already existing key-values to the inputted key-values // kvs = append(kvs, aKvs...) @@ -316,23 +299,20 @@ func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) { // proceed with CASE A sortKvs(kvs) - // cutPowerOfTwo, the excedent add it as normal Tree.Add - kvsP2, kvsNonP2 := cutPowerOfTwo(kvs) var invalids2 []int if nCPU > 1 { - invalids2, err = t.buildTreeBottomUp(nCPU, kvsP2) + invalids2, err = t.buildTreeBottomUp(nCPU, kvs) if err != nil { - return nil, nil, err + return nil, err } } else { - invalids2, err = t.buildTreeBottomUpSingleThread(l, kvsP2) + invalids2, err = t.buildTreeBottomUpSingleThread(l, kvs) if err != nil { - return nil, nil, err + return nil, err } } invalids = append(invalids, invalids2...) - // return the excedents which will be added at the full tree at the end - return invalids, kvsNonP2, nil + return invalids, nil } func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { @@ -342,7 +322,6 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { buckets := splitInBuckets(kvs, nCPU) // 2. use keys at level L as roots of the subtrees under each one - excedentsInBucket := make([][]kv, nCPU) subRoots := make([][]byte, nCPU) txs := make([]db.Tx, nCPU) var wg sync.WaitGroup @@ -361,12 +340,11 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { hashFunction: t.hashFunction, root: keysAtL[cpu]} // 3. do CASE B (with 1 cpu) for each key at level L - _, bucketExcedents, err := bucketTree.caseB(1, l, buckets[cpu]) + _, err = bucketTree.caseB(1, l, buckets[cpu]) // TODO handle invalids if err != nil { panic(err) // TODO WIP // return nil, err } - excedentsInBucket[cpu] = bucketExcedents subRoots[cpu] = bucketTree.root wg.Done() }(i) @@ -379,9 +357,6 @@ func (t *Tree) caseC(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) { return nil, err } } - for i := 0; i < len(excedentsInBucket); i++ { - excedents = append(excedents, excedentsInBucket[i]...) - } // 4. go upFromKeys from the new roots of the subtrees newRoot, err := t.upFromKeys(subRoots) @@ -544,7 +519,7 @@ func (t *Tree) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], ks[i]) kvs[i].pos = i - kvs[i].keyPath = ks[i] + kvs[i].keyPath = keyPath kvs[i].k = ks[i] kvs[i].v = vs[i] } @@ -715,18 +690,9 @@ func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) { return keys, err } -// cutPowerOfTwo returns []kv of length that is a power of 2, and a second []kv -// with the extra elements that don't fit in a power of 2 length -func cutPowerOfTwo(kvs []kv) ([]kv, []kv) { - x := len(kvs) - if (x & (x - 1)) != 0 { - p2 := highestPowerOfTwo(x) - return kvs[:p2], kvs[p2:] - } - return kvs, nil -} - -func highestPowerOfTwo(n int) int { +// flp2 computes the floor power of 2, the highest power of 2 under the given +// value. +func flp2(n int) int { res := 0 for i := n; i >= 1; i-- { if (i & (i - 1)) == 0 { diff --git a/addbatch_test.go b/addbatch_test.go index 7e1e3ee..d35aa7d 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -668,13 +668,14 @@ func TestAddBatchCaseE(t *testing.T) { c.Check(tree2.Root(), qt.DeepEquals, tree1.Root()) } -func TestHighestPowerOfTwo(t *testing.T) { +func TestFlp2(t *testing.T) { c := qt.New(t) - c.Assert(highestPowerOfTwo(31), qt.Equals, 16) - c.Assert(highestPowerOfTwo(32), qt.Equals, 32) - c.Assert(highestPowerOfTwo(33), qt.Equals, 32) - c.Assert(highestPowerOfTwo(63), qt.Equals, 32) - c.Assert(highestPowerOfTwo(64), qt.Equals, 64) + c.Assert(flp2(31), qt.Equals, 16) + c.Assert(flp2(32), qt.Equals, 32) + c.Assert(flp2(33), qt.Equals, 32) + c.Assert(flp2(63), qt.Equals, 32) + c.Assert(flp2(64), qt.Equals, 64) + c.Assert(flp2(9000), qt.Equals, 8192) } // func printLeafs(name string, t *Tree) { diff --git a/tree_test.go b/tree_test.go index 9e49da5..9cf88bd 100644 --- a/tree_test.go +++ b/tree_test.go @@ -267,11 +267,11 @@ func TestGenProofAndVerify(t *testing.T) { } k := BigIntToBytes(big.NewInt(int64(7))) - _, siblings, err := tree.GenProof(k) + v := BigIntToBytes(big.NewInt(int64(14))) + proofV, siblings, err := tree.GenProof(k) c.Assert(err, qt.IsNil) + c.Assert(proofV, qt.DeepEquals, v) - k = BigIntToBytes(big.NewInt(int64(7))) - v := BigIntToBytes(big.NewInt(int64(14))) verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings) c.Assert(err, qt.IsNil) c.Check(verif, qt.IsTrue) diff --git a/vt_test.go b/vt_test.go index 15fb27b..0ac8055 100644 --- a/vt_test.go +++ b/vt_test.go @@ -61,8 +61,6 @@ func TestVirtualTreeRandomKeys(t *testing.T) { values[i] = []byte{0} } - // check the root for different batches of leafs - testVirtualTree(c, 100, keys[:1], values[:1]) testVirtualTree(c, 100, keys, values) }