diff --git a/addbatch.go b/addbatch.go
index 4f0987b..345d080 100644
--- a/addbatch.go
+++ b/addbatch.go
@@ -152,25 +152,23 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
 	t.Lock()
 	defer t.Unlock()
 
-	// TODO if len(keys) is not a power of 2, add padding of empty
-	// keys&values. Maybe when len(keyvalues) is not a power of 2, cut at
-	// the biggest power of 2 under the len(keys), add those 2**n key-values
-	// using the AddBatch approach, and then add the remaining key-values
-	// using tree.Add.
+	// when len(keyvalues) is not a power of 2, cut at the biggest power of
+	// 2 under the len(keys), add those 2**n key-values using the AddBatch
+	// approach, and then add the remaining key-values using tree.Add.
 	kvs, err := t.keysValuesToKvs(keys, values)
 	if err != nil {
 		return nil, err
 	}
 
-	t.tx, err = t.db.NewTx() // TODO add t.tx.Commit()
+	t.tx, err = t.db.NewTx()
 	if err != nil {
 		return nil, err
 	}
 
-	// TODO if nCPU is not a power of two, cut at the highest power of two
-	// under nCPU
-	nCPU := runtime.NumCPU()
+	// if nCPU is not a power of two, cut at the highest power of two under
+	// nCPU
+	nCPU := highestPowerOfTwo(runtime.NumCPU())
 	l := int(math.Log2(float64(nCPU)))
 
 	// CASE A: if nLeafs==0 (root==0)
@@ -205,7 +203,7 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
 		return nil, err
 	}
 	if nLeafs < minLeafsThreshold { // CASE B
-		invalids, excedents, err := t.caseB(0, kvs)
+		invalids, excedents, err := t.caseB(nCPU, 0, kvs)
 		if err != nil {
 			return nil, err
 		}
@@ -226,34 +224,61 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
 		return invalids, nil
 	}
 
+	keysAtL, err := t.getKeysAtLevel(l + 1)
+	if err != nil {
+		return nil, err
+	}
+
 	// CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold
 	// available parallelization, will need to be a power of 2 (2**n)
 	var excedents []kv
-	if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold {
+	if nLeafs >= minLeafsThreshold &&
+		(nLeafs/nCPU) < minLeafsThreshold &&
+		len(keysAtL) == nCPU {
 		// TODO move to own function
 		// 1. go down until level L (L=log2(nBuckets))
-		keysAtL, err := t.getKeysAtLevel(l + 1)
-		if err != nil {
-			return nil, err
-		}
 		buckets := splitInBuckets(kvs, nCPU)
 
 		// 2. use keys at level L as roots of the subtrees under each one
-		var subRoots [][]byte
-		// TODO parallelize
-		for i := 0; i < len(keysAtL); i++ {
-			bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels,
-				hashFunction: t.hashFunction, root: keysAtL[i]}
-
-			// 3. and do CASE B for each
-			_, bucketExcedents, err := bucketTree.caseB(l, buckets[i])
-			if err != nil {
+		excedentsInBucket := make([][]kv, nCPU)
+		subRoots := make([][]byte, nCPU)
+		txs := make([]db.Tx, nCPU)
+		var wg sync.WaitGroup
+		wg.Add(nCPU)
+		for i := 0; i < nCPU; i++ {
+			go func(cpu int) {
+				var err error
+				txs[cpu], err = t.db.NewTx()
+				if err != nil {
+					panic(err) // TODO WIP
+				}
+				bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels,
+					hashFunction: t.hashFunction, root: keysAtL[cpu]}
+
+				// 3. and do CASE B (with 1 cpu) for each
+				_, bucketExcedents, err := bucketTree.caseB(1, l, buckets[cpu])
+				if err != nil {
+					panic(err)
+					// return nil, err
+				}
+				excedentsInBucket[cpu] = bucketExcedents
+				subRoots[cpu] = bucketTree.root
+				wg.Done()
+			}(i)
+		}
+		wg.Wait()
+
+		// merge buckets txs into Tree.tx
+		for i := 0; i < len(txs); i++ {
+			if err := t.tx.Add(txs[i]); err != nil {
 				return nil, err
 			}
-			excedents = append(excedents, bucketExcedents...)
-			subRoots = append(subRoots, bucketTree.root)
 		}
+		for i := 0; i < len(excedentsInBucket); i++ {
+			excedents = append(excedents, excedentsInBucket[i]...)
+		}
+
 		// 4. go upFromKeys from the new roots of the subtrees
 		newRoot, err := t.upFromKeys(subRoots)
 		if err != nil {
@@ -281,16 +306,52 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
 		return invalids, nil
 	}
 
-	keysAtL, err := t.getKeysAtLevel(l + 1)
-	if err != nil {
-		return nil, err
+	var invalids []int
+	// CASE E
+	if len(keysAtL) != nCPU {
+		// CASE E: add one key at each bucket, and then do CASE D
+		buckets := splitInBuckets(kvs, nCPU)
+		kvs = []kv{}
+		for i := 0; i < len(buckets); i++ {
+			err = t.add(0, buckets[i][0].k, buckets[i][0].v)
+			if err != nil {
+				invalids = append(invalids, buckets[i][0].pos)
+				// TODO if err, add another key-value from the
+				// same bucket
+			}
+			kvs = append(kvs, buckets[i][1:]...)
+		}
+		keysAtL, err = t.getKeysAtLevel(l + 1)
+		if err != nil {
+			return nil, err
+		}
 	}
+
+	if nCPU == 1 { // CASE D, but with 1 cpu
+		for i := 0; i < len(keys); i++ {
+			err = t.add(0, keys[i], values[i])
+			if err != nil {
+				invalids = append(invalids, i)
+			}
+		}
+		// store root to db
+		if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
+			return nil, err
+		}
+
+		if err = t.tx.Commit(); err != nil {
+			return nil, err
+		}
+		return invalids, nil
+	}
+
 	// CASE D
 	if len(keysAtL) == nCPU { // enter in CASE D if len(keysAtL)=nCPU, if not, CASE E
-		invalids, err := t.caseD(nCPU, l, keysAtL, kvs)
+		invalidsCaseD, err := t.caseD(nCPU, l, keysAtL, kvs)
 		if err != nil {
 			return nil, err
 		}
+		invalids = append(invalids, invalidsCaseD...)
 		// store root to db
 		if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
 			return nil, err
@@ -302,15 +363,12 @@ func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
 		return invalids, nil
 	}
 
-	// CASE E: add one key of each bucket, and then do CASE D
-
-	// TODO store t.root into DB
 	// TODO update NLeafs from DB
 
 	return nil, fmt.Errorf("UNIMPLEMENTED")
 }
 
-func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
+func (t *Tree) caseB(nCPU, l int, kvs []kv) ([]int, []kv, error) {
 	// get already existing keys
 	aKs, aVs, err := t.getLeafs(t.root)
 	if err != nil {
@@ -328,9 +386,17 @@ func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
 	// cutPowerOfTwo, the excedent add it as normal Tree.Add
 	kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
 
-	invalids, err := t.buildTreeBottomUpSingleThread(kvsP2)
-	if err != nil {
-		return nil, nil, err
+	var invalids []int
+	if nCPU > 1 {
+		invalids, err = t.buildTreeBottomUp(nCPU, kvsP2)
+		if err != nil {
+			return nil, nil, err
+		}
+	} else {
+		invalids, err = t.buildTreeBottomUpSingleThread(kvsP2)
+		if err != nil {
+			return nil, nil, err
+		}
 	}
 	// return the excedents which will be added at the full tree at the end
 	return invalids, kvsNonP2, nil
@@ -352,7 +418,14 @@ func (t *Tree) caseD(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) {
 			if err != nil {
 				panic(err) // TODO WIP
 			}
-			bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels - l, // maxLevels-l
+			// put already existing tx into txs[cpu], as txs[cpu]
+			// needs the pending key-values that are not in tree.db,
+			// but are in tree.tx
+			if err := txs[cpu].Add(t.tx); err != nil {
+				panic(err) // TODO WIP
+			}
+
+			bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels - l,
 				hashFunction: t.hashFunction, root: keysAtL[cpu]}
 
 			for j := 0; j < len(buckets[cpu]); j++ {
@@ -490,6 +563,7 @@ func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) {
 // log2(nCPU) level it has
 // been computed in parallel.
 func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
 	buckets := splitInBuckets(kvs, nCPU)
+	subRoots := make([][]byte, nCPU)
 	invalidsInBucket := make([][]int, nCPU)
 	txs := make([]db.Tx, nCPU)
@@ -610,7 +684,7 @@ func (t *Tree) getLeafs(root []byte) ([][]byte, [][]byte, error) {
 func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) {
 	var keys [][]byte
 	err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
-		if currLvl == l {
+		if currLvl == l && !bytes.Equal(k, t.emptyHash) {
 			keys = append(keys, k)
 		}
 		if currLvl >= l {
diff --git a/addbatch_test.go b/addbatch_test.go
index 9e1d488..a90d042 100644
--- a/addbatch_test.go
+++ b/addbatch_test.go
@@ -415,6 +415,72 @@ func TestAddBatchCaseD(t *testing.T) {
 	c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
 }
 
+func TestAddBatchCaseE(t *testing.T) {
+	c := qt.New(t)
+
+	nLeafs := 4096
+
+	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	c.Assert(err, qt.IsNil)
+	defer tree.db.Close()
+
+	start := time.Now()
+	for i := 0; i < nLeafs; i++ {
+		k := BigIntToBytes(big.NewInt(int64(i)))
+		v := BigIntToBytes(big.NewInt(int64(i * 2)))
+		if err := tree.Add(k, v); err != nil {
+			t.Fatal(err)
+		}
+	}
+	fmt.Println("time elapsed without CASE E: ", time.Since(start))
+
+	tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	c.Assert(err, qt.IsNil)
+	defer tree2.db.Close()
+
+	var keys, values [][]byte
+	// add the initial leafs to fill a bit the tree before calling the
+	// AddBatch method
+	for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold+1 once ready
+		k := BigIntToBytes(big.NewInt(int64(i)))
+		v := BigIntToBytes(big.NewInt(int64(i * 2)))
+		// use only the keys of one bucket, store the not used ones for
+		// later
+		if i%4 != 0 {
+			keys = append(keys, k)
+			values = append(values, v)
+			continue
+		}
+		if err := tree2.Add(k, v); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	for i := 900; i < nLeafs; i++ {
+		k := BigIntToBytes(big.NewInt(int64(i)))
+		v := BigIntToBytes(big.NewInt(int64(i * 2)))
+		keys = append(keys, k)
+		values = append(values, v)
+	}
+	start = time.Now()
+	indexes, err := tree2.AddBatchOpt(keys, values)
+	c.Assert(err, qt.IsNil)
+	fmt.Println("time elapsed with CASE E: ", time.Since(start))
+	c.Check(len(indexes), qt.Equals, 0)
+
+	// check that both trees roots are equal
+	c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
+}
+
+func TestHighestPowerOfTwo(t *testing.T) {
+	c := qt.New(t)
+	c.Assert(highestPowerOfTwo(31), qt.Equals, 16)
+	c.Assert(highestPowerOfTwo(32), qt.Equals, 32)
+	c.Assert(highestPowerOfTwo(33), qt.Equals, 32)
+	c.Assert(highestPowerOfTwo(63), qt.Equals, 32)
+	c.Assert(highestPowerOfTwo(64), qt.Equals, 64)
+}
+
 // func printLeafs(name string, t *Tree) {
 // 	w := bytes.NewBufferString("")
 //
@@ -446,3 +512,8 @@ func TestAddBatchCaseD(t *testing.T) {
 // TODO test tree with nLeafs > minLeafsThreshold, but that at level L, there is
 // less keys than nBuckets (so CASE C could be applied if first few leafs are
 // added to balance the tree)
+
+// TODO for Cases tests, add initial keys, do snapshot, and then measure time of
+// adding the rest of keys with loop over normal Add, and with AddBatch
+
+// TODO test adding batch with repeated keys in the batch
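
Note (not part of the patch): highestPowerOfTwo and cutPowerOfTwo are called by the code above but their bodies are not included in this diff. As a rough sketch only, consistent with TestHighestPowerOfTwo and with the comment about cutting at the biggest power of 2 under len(keys), they could look like the following; kv is the key-value struct already defined in addbatch.go, and the actual implementations in the repository may differ.

// highestPowerOfTwo returns the biggest power of two that is <= n,
// e.g. highestPowerOfTwo(31) == 16 and highestPowerOfTwo(32) == 32.
func highestPowerOfTwo(n int) int {
	res := 1
	for res*2 <= n {
		res *= 2
	}
	return res
}

// cutPowerOfTwo splits kvs into a first slice whose length is the highest
// power of two <= len(kvs), and a second slice with the remaining elements
// (the "excedents" that are later added with the normal Tree.Add path).
func cutPowerOfTwo(kvs []kv) ([]kv, []kv) {
	x := len(kvs)
	if (x & (x - 1)) != 0 {
		p2 := highestPowerOfTwo(x)
		return kvs[:p2], kvs[p2:]
	}
	return kvs, nil
}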
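
Note (not part of the patch): in CASE C/D/E, buckets[i] returned by splitInBuckets has to contain exactly the key-values whose paths fall under keysAtL[i], otherwise upFromKeys would recombine the subtree roots in the wrong order. The sketch below illustrates that alignment with a hypothetical helper (the name bucketOfKey does not exist in the repository); it assumes the same little-endian bit order that the tree uses for key paths, and the real splitInBuckets in addbatch.go may compute the assignment differently.

// bucketOfKey maps a key to a bucket index by reading its first
// log2(nBuckets) path bits, so that all keys of bucket i live under the
// i-th key returned by getKeysAtLevel(l+1) in a left-to-right traversal.
func bucketOfKey(k []byte, nBuckets int) int {
	nBits := 0
	for b := nBuckets; b > 1; b >>= 1 {
		nBits++
	}
	bucket := 0
	for i := 0; i < nBits; i++ {
		// bit i of the path: 0 descends to the left child, 1 to the right
		bit := (k[i/8] >> uint(i%8)) & 1
		// the bit chosen at the upper level is the most significant
		// part of the bucket index
		bucket |= int(bit) << uint(nBits-1-i)
	}
	return bucket
}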