You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

657 lines
16 KiB

  1. package arbo
  2. import (
  3. "bytes"
  4. "fmt"
  5. "math"
  6. "runtime"
  7. "sort"
  8. "sync"
  9. "github.com/iden3/go-merkletree/db"
  10. )
  11. /*
  12. AddBatch design
  13. ===============
  14. CASE A: Empty Tree --> if tree is empty (root==0)
  15. =================================================
  16. - Build the full tree from bottom to top (from all the leafs up to the root)
  17. CASE B: ALMOST CASE A, Almost empty Tree --> if Tree has numLeafs < minLeafsThreshold
  18. ==============================================================================
  19. - Get the Leafs (key & value) (iterate the tree from the current root getting
  20. the leafs)
  21. - Create a new empty Tree
  22. - Do CASE A for the new Tree, giving the already existing key&values (leafs)
  23. from the original Tree + the new key&values to be added from the AddBatch call
  24. R R
  25. / \ / \
  26. A * / \
  27. / \ / \
  28. B C * *
  29. / | / \
  30. / | / \
  31. / | / \
  32. L: A B G D
  33. / \
  34. / \
  35. / \
  36. C *
  37. / \
  38. / \
  39. / \
  40. ... ... (nLeafs < minLeafsThreshold)
  41. CASE C: ALMOST CASE B --> if Tree has few Leafs (but numLeafs>=minLeafsThreshold)
  42. ==============================================================================
  43. - Use A, B, G, D as roots of the subtrees
  44. - Do CASE B for each subtree
  45. - Then go from L to the Root
  46. R
  47. / \
  48. / \
  49. / \
  50. * *
  51. / | / \
  52. / | / \
  53. / | / \
  54. L: A B G D
  55. / \
  56. / \
  57. / \
  58. C *
  59. / \
  60. / \
  61. / \
  62. ... ... (nLeafs >= minLeafsThreshold)
  63. CASE D: Already populated Tree
  64. ==============================
  65. - Use A, B, C, D as subtrees
  66. - Sort the Keys in Buckets that share the initial part of the path
  67. - For each subtree add there the new leafs
  68. R
  69. / \
  70. / \
  71. / \
  72. * *
  73. / | / \
  74. / | / \
  75. / | / \
  76. L: A B C D
  77. /\ /\ / \ / \
  78. ... ... ... ... ... ...
  79. CASE E: Already populated Tree Unbalanced
  80. =========================================
  81. - Need to fill M1 and M2, and then will be able to use CASE D
  82. - Search for M1 & M2 in the inputted keys
  83. - Add M1 & M2 to the Tree
  84. - From here can use CASE D
  85. R
  86. / \
  87. / \
  88. / \
  89. * *
  90. | \
  91. | \
  92. | \
  93. L: M1 * M2 * (where M1 and M2 are empty)
  94. / | /
  95. / | /
  96. / | /
  97. A * *
  98. / \ | \
  99. / \ | \
  100. / \ | \
  101. B * * C
  102. / \ |\
  103. ... ... | \
  104. | \
  105. D E
  106. Algorithm decision
  107. ==================
  108. - if nLeafs==0 (root==0): CASE A
  109. - if nLeafs<minLeafsThreshold: CASE B
  110. - if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold: CASE C
  111. - else: CASE D & CASE E
  112. - Multiple tree.Add calls: O(n log n)
  113. - Used in: cases D, E
  114. - Tree from bottom to top: O(n)
  115. - Used in: cases A, B, C
  116. */
  // minLeafsThreshold is the leaf count below which AddBatchOpt rebuilds the
  // whole tree (CASE B); it is also compared against nLeafs/nCPU to decide
  // whether CASE C applies.
  // TMP WIP: this will be autocalculated.
  const minLeafsThreshold = 100 // nolint:gomnd
  120. // AddBatchOpt is the WIP implementation of the AddBatch method in a more
  121. // optimized approach.
  122. func (t *Tree) AddBatchOpt(keys, values [][]byte) ([]int, error) {
  123. t.updateAccessTime()
  124. t.Lock()
  125. defer t.Unlock()
  126. // TODO if len(keys) is not a power of 2, add padding of empty
  127. // keys&values. Maybe when len(keyvalues) is not a power of 2, cut at
  128. // the biggest power of 2 under the len(keys), add those 2**n key-values
  129. // using the AddBatch approach, and then add the remaining key-values
  130. // using tree.Add.
  131. kvs, err := t.keysValuesToKvs(keys, values)
  132. if err != nil {
  133. return nil, err
  134. }
  135. t.tx, err = t.db.NewTx() // TODO add t.tx.Commit()
  136. if err != nil {
  137. return nil, err
  138. }
  139. // TODO if nCPU is not a power of two, cut at the highest power of two
  140. // under nCPU
  141. nCPU := runtime.NumCPU()
  142. l := int(math.Log2(float64(nCPU)))
  143. // CASE A: if nLeafs==0 (root==0)
  144. if bytes.Equal(t.root, t.emptyHash) {
  145. // if len(kvs) is not a power of 2, cut at the bigger power
  146. // of two under len(kvs), build the tree with that, and add
  147. // later the excedents
  148. kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
  149. invalids, err := t.buildTreeBottomUp(nCPU, kvsP2)
  150. if err != nil {
  151. return nil, err
  152. }
  153. for i := 0; i < len(kvsNonP2); i++ {
  154. err = t.add(0, kvsNonP2[i].k, kvsNonP2[i].v)
  155. if err != nil {
  156. invalids = append(invalids, kvsNonP2[i].pos)
  157. }
  158. }
  159. // store root to db
  160. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  161. return nil, err
  162. }
  163. if err = t.tx.Commit(); err != nil {
  164. return nil, err
  165. }
  166. return invalids, nil
  167. }
  168. // CASE B: if nLeafs<nBuckets
  169. nLeafs, err := t.GetNLeafs()
  170. if err != nil {
  171. return nil, err
  172. }
  173. if nLeafs < minLeafsThreshold { // CASE B
  174. invalids, excedents, err := t.caseB(0, kvs)
  175. if err != nil {
  176. return nil, err
  177. }
  178. // add the excedents
  179. for i := 0; i < len(excedents); i++ {
  180. err = t.add(0, excedents[i].k, excedents[i].v)
  181. if err != nil {
  182. invalids = append(invalids, excedents[i].pos)
  183. }
  184. }
  185. // store root to db
  186. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  187. return nil, err
  188. }
  189. if err = t.tx.Commit(); err != nil {
  190. return nil, err
  191. }
  192. return invalids, nil
  193. }
  194. // CASE C: if nLeafs>=minLeafsThreshold && (nLeafs/nBuckets) < minLeafsThreshold
  195. // available parallelization, will need to be a power of 2 (2**n)
  196. var excedents []kv
  197. if nLeafs >= minLeafsThreshold && (nLeafs/nCPU) < minLeafsThreshold {
  198. // TODO move to own function
  199. // 1. go down until level L (L=log2(nBuckets))
  200. keysAtL, err := t.getKeysAtLevel(l + 1)
  201. if err != nil {
  202. return nil, err
  203. }
  204. buckets := splitInBuckets(kvs, nCPU)
  205. // 2. use keys at level L as roots of the subtrees under each one
  206. var subRoots [][]byte
  207. // TODO parallelize
  208. for i := 0; i < len(keysAtL); i++ {
  209. bucketTree := Tree{tx: t.tx, db: t.db, maxLevels: t.maxLevels,
  210. hashFunction: t.hashFunction, root: keysAtL[i]}
  211. // 3. and do CASE B for each
  212. _, bucketExcedents, err := bucketTree.caseB(l, buckets[i])
  213. if err != nil {
  214. return nil, err
  215. }
  216. excedents = append(excedents, bucketExcedents...)
  217. subRoots = append(subRoots, bucketTree.root)
  218. }
  219. // 4. go upFromKeys from the new roots of the subtrees
  220. newRoot, err := t.upFromKeys(subRoots)
  221. if err != nil {
  222. return nil, err
  223. }
  224. t.root = newRoot
  225. // add the key-values that have not been used yet
  226. var invalids []int
  227. for i := 0; i < len(excedents); i++ {
  228. // Add until the level L
  229. err = t.add(0, excedents[i].k, excedents[i].v)
  230. if err != nil {
  231. invalids = append(invalids, excedents[i].pos) // TODO WIP
  232. }
  233. }
  234. // store root to db
  235. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  236. return nil, err
  237. }
  238. if err = t.tx.Commit(); err != nil {
  239. return nil, err
  240. }
  241. return invalids, nil
  242. }
  243. keysAtL, err := t.getKeysAtLevel(l + 1)
  244. if err != nil {
  245. return nil, err
  246. }
  247. // CASE D
  248. if len(keysAtL) == nCPU { // enter in CASE D if len(keysAtL)=nCPU, if not, CASE E
  249. invalids, err := t.caseD(nCPU, l, keysAtL, kvs)
  250. if err != nil {
  251. return nil, err
  252. }
  253. // store root to db
  254. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  255. return nil, err
  256. }
  257. if err = t.tx.Commit(); err != nil {
  258. return nil, err
  259. }
  260. return invalids, nil
  261. }
  262. // CASE E: add one key of each bucket, and then do CASE D
  263. // TODO store t.root into DB
  264. // TODO update NLeafs from DB
  265. return nil, fmt.Errorf("UNIMPLEMENTED")
  266. }
  267. func (t *Tree) caseB(l int, kvs []kv) ([]int, []kv, error) {
  268. // get already existing keys
  269. aKs, aVs, err := t.getLeafs(t.root)
  270. if err != nil {
  271. return nil, nil, err
  272. }
  273. aKvs, err := t.keysValuesToKvs(aKs, aVs)
  274. if err != nil {
  275. return nil, nil, err
  276. }
  277. // add already existing key-values to the inputted key-values
  278. kvs = append(kvs, aKvs...)
  279. // proceed with CASE A
  280. sortKvs(kvs)
  281. // cutPowerOfTwo, the excedent add it as normal Tree.Add
  282. kvsP2, kvsNonP2 := cutPowerOfTwo(kvs)
  283. invalids, err := t.buildTreeBottomUpSingleThread(kvsP2)
  284. if err != nil {
  285. return nil, nil, err
  286. }
  287. // return the excedents which will be added at the full tree at the end
  288. return invalids, kvsNonP2, nil
  289. }
  290. func (t *Tree) caseD(nCPU, l int, keysAtL [][]byte, kvs []kv) ([]int, error) {
  291. buckets := splitInBuckets(kvs, nCPU)
  292. subRoots := make([][]byte, nCPU)
  293. invalidsInBucket := make([][]int, nCPU)
  294. txs := make([]db.Tx, nCPU)
  295. var wg sync.WaitGroup
  296. wg.Add(nCPU)
  297. for i := 0; i < nCPU; i++ {
  298. go func(cpu int) {
  299. var err error
  300. txs[cpu], err = t.db.NewTx()
  301. if err != nil {
  302. panic(err) // TODO WIP
  303. }
  304. bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels - l, // maxLevels-l
  305. hashFunction: t.hashFunction, root: keysAtL[cpu]}
  306. for j := 0; j < len(buckets[cpu]); j++ {
  307. if err = bucketTree.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
  308. invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
  309. }
  310. }
  311. subRoots[cpu] = bucketTree.root
  312. wg.Done()
  313. }(i)
  314. }
  315. wg.Wait()
  316. // merge buckets txs into Tree.tx
  317. for i := 0; i < len(txs); i++ {
  318. if err := t.tx.Add(txs[i]); err != nil {
  319. return nil, err
  320. }
  321. }
  322. newRoot, err := t.upFromKeys(subRoots)
  323. if err != nil {
  324. return nil, err
  325. }
  326. t.root = newRoot
  327. var invalids []int
  328. for i := 0; i < len(invalidsInBucket); i++ {
  329. invalids = append(invalids, invalidsInBucket[i]...)
  330. }
  331. return invalids, nil
  332. }
  333. func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
  334. buckets := make([][]kv, nBuckets)
  335. // 1. classify the keyvalues into buckets
  336. for i := 0; i < len(kvs); i++ {
  337. pair := kvs[i]
  338. // bucketnum := keyToBucket(pair.k, nBuckets)
  339. bucketnum := keyToBucket(pair.keyPath, nBuckets)
  340. buckets[bucketnum] = append(buckets[bucketnum], pair)
  341. }
  342. return buckets
  343. }
  // keyToBucket returns the bucket index in [0, nBuckets) for the key path k.
  //
  // The first floor(log2(nBuckets)) bits of k are read as bit i =
  // k[i/8]>>(i%8)&1, i.e. least-significant bit first within each byte. Each
  // bit selects the upper half (bit set) or the lower half (bit clear) of the
  // remaining bucket range. The result is the first index of the final range.
  // For nBuckets a power of two, this is those bits read as a binary number
  // whose first bit is the most significant.
  //
  // k must hold at least floor(log2(nBuckets)) bits. nBuckets <= 1 yields 0.
  //
  // TODO rename in a more 'real' name (calculate bucket from/for key)
  func keyToBucket(k []byte, nBuckets int) int {
  	if nBuckets <= 1 {
  		return 0
  	}
  	nLevels := int(math.Log2(float64(nBuckets)))
  	// [first, first+size) is the remaining range of candidate buckets. It
  	// is narrowed in place instead of materializing every bucket index for
  	// each key.
  	first, size := 0, nBuckets
  	for i := 0; i < nLevels; i++ {
  		half := size / 2 //nolint:gomnd
  		if k[i/8]&(1<<(i%8)) != 0 {
  			first += half
  			size -= half
  		} else {
  			size = half
  		}
  	}
  	return first
  }
  // kv is one key-value pair of a batch plus the bookkeeping used to place it
  // in the tree and to report it back to the caller.
  type kv struct {
  	pos     int    // original position in the input keys/values arrays
  	keyPath []byte // bytes whose bits select the bucket and the sort order
  	k, v    []byte // key and value as given by the caller
  }
  // compareBytes reports whether a sorts before b.
  //
  // Bytes are compared from left to right. Within a byte, bits are compared
  // from the least significant (bit 0) to the most significant, and at the
  // first differing bit the slice with that bit clear sorts first. When one
  // slice is a prefix of the other, the shorter one sorts first, so slices of
  // different lengths never cause an out-of-range access.
  func compareBytes(a, b []byte) bool {
  	n := len(a)
  	if len(b) < n {
  		n = len(b)
  	}
  	for i := 0; i < n; i++ {
  		diff := a[i] ^ b[i]
  		if diff == 0 {
  			continue
  		}
  		// isolate the lowest differing bit (two's complement trick)
  		lowest := diff & -diff
  		return a[i]&lowest == 0
  	}
  	return len(a) < len(b)
  }
  387. // sortKvs sorts the kv by path
  388. func sortKvs(kvs []kv) {
  389. sort.Slice(kvs, func(i, j int) bool {
  390. return compareBytes(kvs[i].keyPath, kvs[j].keyPath)
  391. })
  392. }
  393. func (t *Tree) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
  394. if len(ks) != len(vs) {
  395. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  396. len(ks), len(vs))
  397. }
  398. kvs := make([]kv, len(ks))
  399. for i := 0; i < len(ks); i++ {
  400. keyPath := make([]byte, t.hashFunction.Len())
  401. copy(keyPath[:], ks[i])
  402. kvs[i].pos = i
  403. kvs[i].keyPath = ks[i]
  404. kvs[i].k = ks[i]
  405. kvs[i].v = vs[i]
  406. }
  407. return kvs, nil
  408. }
  409. /*
  410. func (t *Tree) kvsToKeysValues(kvs []kv) ([][]byte, [][]byte) {
  411. ks := make([][]byte, len(kvs))
  412. vs := make([][]byte, len(kvs))
  413. for i := 0; i < len(kvs); i++ {
  414. ks[i] = kvs[i].k
  415. vs[i] = kvs[i].v
  416. }
  417. return ks, vs
  418. }
  419. */
  420. // buildTreeBottomUp splits the key-values into n Buckets (where n is the number
  421. // of CPUs), in parallel builds a subtree for each bucket, once all the subtrees
  422. // are built, uses the subtrees roots as keys for a new tree, which as result
  423. // will have the complete Tree build from bottom to up, where until the
  424. // log2(nCPU) level it has been computed in parallel.
  425. func (t *Tree) buildTreeBottomUp(nCPU int, kvs []kv) ([]int, error) {
  426. buckets := splitInBuckets(kvs, nCPU)
  427. subRoots := make([][]byte, nCPU)
  428. invalidsInBucket := make([][]int, nCPU)
  429. txs := make([]db.Tx, nCPU)
  430. var wg sync.WaitGroup
  431. wg.Add(nCPU)
  432. for i := 0; i < nCPU; i++ {
  433. go func(cpu int) {
  434. sortKvs(buckets[cpu])
  435. var err error
  436. txs[cpu], err = t.db.NewTx()
  437. if err != nil {
  438. panic(err) // TODO
  439. }
  440. bucketTree := Tree{tx: txs[cpu], db: t.db, maxLevels: t.maxLevels,
  441. hashFunction: t.hashFunction, root: t.emptyHash}
  442. currInvalids, err := bucketTree.buildTreeBottomUpSingleThread(buckets[cpu])
  443. if err != nil {
  444. panic(err) // TODO
  445. }
  446. invalidsInBucket[cpu] = currInvalids
  447. subRoots[cpu] = bucketTree.root
  448. wg.Done()
  449. }(i)
  450. }
  451. wg.Wait()
  452. // merge buckets txs into Tree.tx
  453. for i := 0; i < len(txs); i++ {
  454. if err := t.tx.Add(txs[i]); err != nil {
  455. return nil, err
  456. }
  457. }
  458. newRoot, err := t.upFromKeys(subRoots)
  459. if err != nil {
  460. return nil, err
  461. }
  462. t.root = newRoot
  463. var invalids []int
  464. for i := 0; i < len(invalidsInBucket); i++ {
  465. invalids = append(invalids, invalidsInBucket[i]...)
  466. }
  467. return invalids, err
  468. }
  469. // buildTreeBottomUpSingleThread builds the tree with the given []kv from bottom
  470. // to the root. keys & values must be sorted by path, and the array ks must be
  471. // length multiple of 2
  472. func (t *Tree) buildTreeBottomUpSingleThread(kvs []kv) ([]int, error) {
  473. // TODO check that log2(len(leafs)) < t.maxLevels, if not, maxLevels
  474. // would be reached and should return error
  475. var invalids []int
  476. // build the leafs
  477. leafKeys := make([][]byte, len(kvs))
  478. for i := 0; i < len(kvs); i++ {
  479. // TODO handle the case where Key&Value == 0
  480. leafKey, leafValue, err := newLeafValue(t.hashFunction, kvs[i].k, kvs[i].v)
  481. if err != nil {
  482. // return nil, err
  483. invalids = append(invalids, kvs[i].pos)
  484. }
  485. // store leafKey & leafValue to db
  486. if err := t.tx.Put(leafKey, leafValue); err != nil {
  487. // return nil, err
  488. invalids = append(invalids, kvs[i].pos)
  489. }
  490. leafKeys[i] = leafKey
  491. }
  492. r, err := t.upFromKeys(leafKeys)
  493. if err != nil {
  494. return invalids, err
  495. }
  496. t.root = r
  497. return invalids, nil
  498. }
  499. // keys & values must be sorted by path, and the array ks must be length
  500. // multiple of 2
  501. func (t *Tree) upFromKeys(ks [][]byte) ([]byte, error) {
  502. if len(ks) == 1 {
  503. return ks[0], nil
  504. }
  505. var rKs [][]byte
  506. for i := 0; i < len(ks); i += 2 {
  507. // TODO handle the case where Key&Value == 0
  508. k, v, err := newIntermediate(t.hashFunction, ks[i], ks[i+1])
  509. if err != nil {
  510. return nil, err
  511. }
  512. // store k-v to db
  513. if err = t.tx.Put(k, v); err != nil {
  514. return nil, err
  515. }
  516. rKs = append(rKs, k)
  517. }
  518. return t.upFromKeys(rKs)
  519. }
  520. func (t *Tree) getLeafs(root []byte) ([][]byte, [][]byte, error) {
  521. var ks, vs [][]byte
  522. err := t.iter(root, func(k, v []byte) {
  523. if v[0] != PrefixValueLeaf {
  524. return
  525. }
  526. leafK, leafV := readLeafValue(v)
  527. ks = append(ks, leafK)
  528. vs = append(vs, leafV)
  529. })
  530. return ks, vs, err
  531. }
  532. func (t *Tree) getKeysAtLevel(l int) ([][]byte, error) {
  533. var keys [][]byte
  534. err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
  535. if currLvl == l {
  536. keys = append(keys, k)
  537. }
  538. if currLvl >= l {
  539. return true // to stop the iter from going down
  540. }
  541. return false
  542. })
  543. return keys, err
  544. }
  545. // cutPowerOfTwo returns []kv of length that is a power of 2, and a second []kv
  546. // with the extra elements that don't fit in a power of 2 length
  547. func cutPowerOfTwo(kvs []kv) ([]kv, []kv) {
  548. x := len(kvs)
  549. if (x & (x - 1)) != 0 {
  550. p2 := highestPowerOfTwo(x)
  551. return kvs[:p2], kvs[p2:]
  552. }
  553. return kvs, nil
  554. }
  // highestPowerOfTwo returns the largest power of two that is <= n, or 0
  // when n < 1. It takes O(log n) steps; the previous version counted down
  // from n, which took O(n) steps.
  func highestPowerOfTwo(n int) int {
  	if n < 1 {
  		return 0
  	}
  	p := 1
  	// comparing against n/2 (instead of p*2 <= n) cannot overflow
  	for p <= n/2 {
  		p <<= 1
  	}
  	return p
  }
  565. // func computeSimpleAddCost(nLeafs int) int {
  566. // // nLvls 2^nLvls
  567. // nLvls := int(math.Log2(float64(nLeafs)))
  568. // return nLvls * int(math.Pow(2, float64(nLvls)))
  569. // }
  570. //
  571. // func computeBottomUpAddCost(nLeafs int) int {
  572. // // 2^nLvls * 2 - 1
  573. // nLvls := int(math.Log2(float64(nLeafs)))
  574. // return (int(math.Pow(2, float64(nLvls))) * 2) - 1
  575. // }