You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

714 lines
16 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds the tree
// structure without computing any hash. The idea is that, once all the leaves
// are placed in their positions, the hashes can be computed, so that no node
// hash is computed more than once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. "math"
  12. "runtime"
  13. "sync"
  14. )
  15. type node struct {
  16. l *node
  17. r *node
  18. k []byte
  19. v []byte
  20. path []bool
  21. h []byte
  22. }
// params groups the settings shared by the operations of a virtual tree.
type params struct {
	// maxLevels bounds the depth at which leaves can be placed; going past
	// it yields ErrMaxVirtualLevel.
	maxLevels int
	// hashFunction is used to compute the leaf and intermediate node hashes.
	hashFunction HashFunction
	// emptyHash is the hash used for empty sub-nodes; newVT sets it to
	// hashFunction.Len() zero bytes.
	emptyHash []byte
	// dbg collects debug statistics (incHash is called for each computed
	// hash). newVT leaves it unset.
	dbg *dbgStats
}
  29. type kv struct {
  30. pos int // original position in the inputted array
  31. keyPath []byte
  32. k []byte
  33. v []byte
  34. }
  35. func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, []int, error) {
  36. if len(ks) != len(vs) {
  37. return nil, nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  38. len(ks), len(vs))
  39. }
  40. var invalids []int
  41. var kvs []kv
  42. for i := 0; i < len(ks); i++ {
  43. keyPath, err := keyPathFromKey(p.maxLevels, ks[i])
  44. if err != nil {
  45. // TODO in a future iteration, invalids will contain
  46. // the reason of the error of why each index is
  47. // invalid.
  48. invalids = append(invalids, i)
  49. continue
  50. }
  51. var kvsI kv
  52. kvsI.pos = i
  53. kvsI.keyPath = keyPath
  54. kvsI.k = ks[i]
  55. kvsI.v = vs[i]
  56. kvs = append(kvs, kvsI)
  57. }
  58. return kvs, invalids, nil
  59. }
// vt stands for virtual tree. It's a tree that does not have any computed hash
// while placing the leaves. Once all the leaves are placed, it computes all
// the hashes. In this way, each node hash is only computed one time (at the
// end) and the tree is computed in memory.
type vt struct {
	// root is the top node of the tree; it is nil while the tree is empty.
	root *node
	// params holds the tree settings (max levels, hash function, ...).
	params *params
}
  68. func newVT(maxLevels int, hash HashFunction) vt {
  69. return vt{
  70. root: nil,
  71. params: &params{
  72. maxLevels: maxLevels,
  73. hashFunction: hash,
  74. emptyHash: make([]byte, hash.Len()), // empty
  75. },
  76. }
  77. }
  78. // addBatch adds a batch of key-values to the VirtualTree. Returns an array
  79. // containing the indexes of the keys failed to add. Does not include the
  80. // computation of hashes of the nodes neither the storage of the key-values of
  81. // the tree into the db. After addBatch, vt.computeHashes should be called to
  82. // compute the hashes of all the nodes of the tree.
  83. func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {
  84. nCPU := flp2(runtime.NumCPU())
  85. if nCPU == 1 || len(ks) < nCPU {
  86. var invalids []int
  87. for i := 0; i < len(ks); i++ {
  88. if err := t.add(0, ks[i], vs[i]); err != nil {
  89. invalids = append(invalids, i)
  90. }
  91. }
  92. return invalids, nil
  93. }
  94. l := int(math.Log2(float64(nCPU)))
  95. kvs, invalids, err := t.params.keysValuesToKvs(ks, vs)
  96. if err != nil {
  97. return invalids, err
  98. }
  99. buckets := splitInBuckets(kvs, nCPU)
  100. nodesAtL, err := t.getNodesAtLevel(l)
  101. if err != nil {
  102. return nil, err
  103. }
  104. if len(nodesAtL) != nCPU && t.root != nil {
  105. /*
  106. Already populated Tree but Unbalanced
  107. - Need to fill M1 and M2, and then will be able to continue with the flow
  108. - Search for M1 & M2 in the inputed Keys
  109. - Add M1 & M2 to the Tree
  110. - From here can continue with the flow
  111. R
  112. / \
  113. / \
  114. / \
  115. * *
  116. | \
  117. | \
  118. | \
  119. L: M1 * M2 * (where M1 and M2 are empty)
  120. / | /
  121. / | /
  122. / | /
  123. A * *
  124. / \ | \
  125. / \ | \
  126. / \ | \
  127. B * * C
  128. / \ |\
  129. ... ... | \
  130. | \
  131. D E
  132. */
  133. // add one key at each bucket, and then continue with the flow
  134. for i := 0; i < len(buckets); i++ {
  135. // add one leaf of the bucket, if there is an error when
  136. // adding the k-v, try to add the next one of the bucket
  137. // (until one is added)
  138. inserted := -1
  139. for j := 0; j < len(buckets[i]); j++ {
  140. if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
  141. inserted = j
  142. break
  143. }
  144. }
  145. // remove the inserted element from buckets[i]
  146. if inserted != -1 {
  147. buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
  148. }
  149. }
  150. nodesAtL, err = t.getNodesAtLevel(l)
  151. if err != nil {
  152. return nil, err
  153. }
  154. }
  155. if len(nodesAtL) != nCPU {
  156. return nil, fmt.Errorf("This error should not be reached."+
  157. " len(nodesAtL) != nCPU, len(nodesAtL)=%d, nCPU=%d."+
  158. " Please report it in a new issue:"+
  159. " https://github.com/vocdoni/arbo/issues/new", len(nodesAtL), nCPU)
  160. }
  161. subRoots := make([]*node, nCPU)
  162. invalidsInBucket := make([][]int, nCPU)
  163. var wg sync.WaitGroup
  164. wg.Add(nCPU)
  165. for i := 0; i < nCPU; i++ {
  166. go func(cpu int) {
  167. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  168. bucketVT.root = nodesAtL[cpu]
  169. for j := 0; j < len(buckets[cpu]); j++ {
  170. if err := bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
  171. invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
  172. }
  173. }
  174. subRoots[cpu] = bucketVT.root
  175. wg.Done()
  176. }(i)
  177. }
  178. wg.Wait()
  179. for i := 0; i < len(invalidsInBucket); i++ {
  180. invalids = append(invalids, invalidsInBucket[i]...)
  181. }
  182. newRootNode, err := upFromNodes(subRoots)
  183. if err != nil {
  184. return nil, err
  185. }
  186. t.root = newRootNode
  187. return invalids, nil
  188. }
  189. func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
  190. if t.root == nil {
  191. var r []*node
  192. nChilds := int(math.Pow(2, float64(l))) //nolint:gomnd
  193. for i := 0; i < nChilds; i++ {
  194. r = append(r, nil)
  195. }
  196. return r, nil
  197. }
  198. return t.root.getNodesAtLevel(0, l)
  199. }
  200. func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
  201. if n == nil {
  202. var r []*node
  203. nChilds := int(math.Pow(2, float64(l-currLvl))) //nolint:gomnd
  204. for i := 0; i < nChilds; i++ {
  205. r = append(r, nil)
  206. }
  207. return r, nil
  208. }
  209. typ := n.typ()
  210. if currLvl == l && typ != vtEmpty {
  211. return []*node{n}, nil
  212. }
  213. if currLvl >= l {
  214. return nil, fmt.Errorf("This error should not be reached."+
  215. " currLvl >= l, currLvl=%d, l=%d."+
  216. " Please report it in a new issue:"+
  217. " https://github.com/vocdoni/arbo/issues/new", currLvl, l)
  218. }
  219. var nodes []*node
  220. nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
  221. if err != nil {
  222. return nil, err
  223. }
  224. nodes = append(nodes, nodesL...)
  225. nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
  226. if err != nil {
  227. return nil, err
  228. }
  229. nodes = append(nodes, nodesR...)
  230. return nodes, nil
  231. }
  232. // upFromNodes builds the tree from the bottom to up
  233. func upFromNodes(ns []*node) (*node, error) {
  234. if len(ns) == 1 {
  235. return ns[0], nil
  236. }
  237. var res []*node
  238. for i := 0; i < len(ns); i += 2 {
  239. if (ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty) ||
  240. (ns[i].typ() == vtLeaf && ns[i+1].typ() == vtEmpty) {
  241. // when both sub nodes are empty, the parent is also empty
  242. // or
  243. // when 1st sub node is a leaf but the 2nd is empty, the
  244. // leaf is used as parent
  245. res = append(res, ns[i])
  246. continue
  247. }
  248. if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtLeaf {
  249. // when 2nd sub node is a leaf but the 1st is empty, the
  250. // leaf is used as 'parent'
  251. res = append(res, ns[i+1])
  252. continue
  253. }
  254. n := &node{
  255. l: ns[i],
  256. r: ns[i+1],
  257. }
  258. res = append(res, n)
  259. }
  260. return upFromNodes(res)
  261. }
  262. // add adds a key&value as a leaf in the VirtualTree
  263. func (t *vt) add(fromLvl int, k, v []byte) error {
  264. leaf, err := newLeafNode(t.params, k, v)
  265. if err != nil {
  266. return err
  267. }
  268. if t.root == nil {
  269. t.root = leaf
  270. return nil
  271. }
  272. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  273. return err
  274. }
  275. return nil
  276. }
  277. // computeHashes should be called after all the vt.add is used, once all the
  278. // leafs are in the tree. Computes the hashes of the tree, parallelizing in the
  279. // available CPUs.
  280. func (t *vt) computeHashes() ([][2][]byte, error) {
  281. var err error
  282. nCPU := flp2(runtime.NumCPU())
  283. l := int(math.Log2(float64(nCPU)))
  284. nodesAtL, err := t.getNodesAtLevel(l)
  285. if err != nil {
  286. return nil, err
  287. }
  288. subRoots := make([]*node, nCPU)
  289. bucketPairs := make([][][2][]byte, nCPU)
  290. dbgStatsPerBucket := make([]*dbgStats, nCPU)
  291. errs := make([]error, nCPU)
  292. var wg sync.WaitGroup
  293. wg.Add(nCPU)
  294. for i := 0; i < nCPU; i++ {
  295. go func(cpu int) {
  296. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  297. bucketVT.params.dbg = newDbgStats()
  298. bucketVT.root = nodesAtL[cpu]
  299. var err error
  300. bucketPairs[cpu], err = bucketVT.root.computeHashes(l-1,
  301. t.params.maxLevels, bucketVT.params, bucketPairs[cpu])
  302. if err != nil {
  303. errs[cpu] = err
  304. }
  305. subRoots[cpu] = bucketVT.root
  306. dbgStatsPerBucket[cpu] = bucketVT.params.dbg
  307. wg.Done()
  308. }(i)
  309. }
  310. wg.Wait()
  311. for i := 0; i < len(errs); i++ {
  312. if errs[i] != nil {
  313. return nil, errs[i]
  314. }
  315. }
  316. for i := 0; i < len(dbgStatsPerBucket); i++ {
  317. t.params.dbg.add(dbgStatsPerBucket[i])
  318. }
  319. var pairs [][2][]byte
  320. for i := 0; i < len(bucketPairs); i++ {
  321. pairs = append(pairs, bucketPairs[i]...)
  322. }
  323. nodesAtL, err = t.getNodesAtLevel(l)
  324. if err != nil {
  325. return nil, err
  326. }
  327. for i := 0; i < len(nodesAtL); i++ {
  328. nodesAtL = subRoots
  329. }
  330. pairs, err = t.root.computeHashes(0, l, t.params, pairs)
  331. if err != nil {
  332. return nil, err
  333. }
  334. return pairs, nil
  335. }
  336. func newLeafNode(p *params, k, v []byte) (*node, error) {
  337. keyPath, err := keyPathFromKey(p.maxLevels, k)
  338. if err != nil {
  339. return nil, err
  340. }
  341. path := getPath(p.maxLevels, keyPath)
  342. n := &node{
  343. k: k,
  344. v: v,
  345. path: path,
  346. }
  347. return n, nil
  348. }
// virtualNodeType identifies the kind of a virtual tree node, as reported by
// node.typ.
type virtualNodeType int

// Kinds of virtual tree nodes. The constants are declared untyped.
const (
	vtEmpty = 0 // for convenience uses the same value as PrefixValueEmpty
	vtLeaf  = 1 // for convenience uses the same value as PrefixValueLeaf
	vtMid   = 2 // for convenience uses the same value as PrefixValueIntermediate
)
  355. func (n *node) typ() virtualNodeType {
  356. if n == nil {
  357. return vtEmpty
  358. }
  359. if n.l == nil && n.r == nil && n.k != nil {
  360. return vtLeaf
  361. }
  362. if n.l != nil || n.r != nil {
  363. return vtMid
  364. }
  365. return vtEmpty
  366. }
  367. func (n *node) add(p *params, currLvl int, leaf *node) error {
  368. if currLvl > p.maxLevels-1 {
  369. return ErrMaxVirtualLevel
  370. }
  371. if n == nil {
  372. // n = leaf // TMP!
  373. return nil
  374. }
  375. t := n.typ()
  376. switch t {
  377. case vtMid:
  378. if leaf.path[currLvl] {
  379. //right
  380. if n.r == nil {
  381. // empty sub-node, add the leaf here
  382. n.r = leaf
  383. return nil
  384. }
  385. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  386. return err
  387. }
  388. } else {
  389. if n.l == nil {
  390. // empty sub-node, add the leaf here
  391. n.l = leaf
  392. return nil
  393. }
  394. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  395. return err
  396. }
  397. }
  398. case vtLeaf:
  399. if bytes.Equal(n.k, leaf.k) {
  400. return fmt.Errorf("%s. Existing node: %s, trying to add node: %s",
  401. ErrKeyAlreadyExists, hex.EncodeToString(n.k),
  402. hex.EncodeToString(leaf.k))
  403. }
  404. oldLeaf := &node{
  405. k: n.k,
  406. v: n.v,
  407. path: n.path,
  408. }
  409. // remove values from current node (converting it to mid node)
  410. n.k = nil
  411. n.v = nil
  412. n.h = nil
  413. n.path = nil
  414. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  415. return err
  416. }
  417. case vtEmpty:
  418. return fmt.Errorf("virtual tree node.add() with empty node %v", n)
  419. default:
  420. return fmt.Errorf("virtual tree node.add() with unknown node type %v", n)
  421. }
  422. return nil
  423. }
  424. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  425. if currLvl > p.maxLevels-1 {
  426. return ErrMaxVirtualLevel
  427. }
  428. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  429. // reached divergence in next level
  430. if newLeaf.path[currLvl] {
  431. n.l = oldLeaf
  432. n.r = newLeaf
  433. } else {
  434. n.l = newLeaf
  435. n.r = oldLeaf
  436. }
  437. return nil
  438. }
  439. // no divergence yet, continue going down
  440. if newLeaf.path[currLvl] {
  441. // right
  442. n.r = &node{}
  443. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  444. return err
  445. }
  446. } else {
  447. // left
  448. n.l = &node{}
  449. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  450. return err
  451. }
  452. }
  453. return nil
  454. }
  455. func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
  456. buckets := make([][]kv, nBuckets)
  457. // 1. classify the keyvalues into buckets
  458. for i := 0; i < len(kvs); i++ {
  459. pair := kvs[i]
  460. // bucketnum := keyToBucket(pair.k, nBuckets)
  461. bucketnum := keyToBucket(pair.keyPath, nBuckets)
  462. buckets[bucketnum] = append(buckets[bucketnum], pair)
  463. }
  464. return buckets
  465. }
  466. // TODO rename in a more 'real' name (calculate bucket from/for key)
  467. func keyToBucket(k []byte, nBuckets int) int {
  468. nLevels := int(math.Log2(float64(nBuckets)))
  469. b := make([]int, nBuckets)
  470. for i := 0; i < nBuckets; i++ {
  471. b[i] = i
  472. }
  473. r := b
  474. mid := len(r) / 2 //nolint:gomnd
  475. for i := 0; i < nLevels; i++ {
  476. if int(k[i/8]&(1<<(i%8))) != 0 {
  477. r = r[mid:]
  478. mid = len(r) / 2 //nolint:gomnd
  479. } else {
  480. r = r[:mid]
  481. mid = len(r) / 2 //nolint:gomnd
  482. }
  483. }
  484. return r[0]
  485. }
  486. // flp2 computes the floor power of 2, the highest power of 2 under the given
  487. // value.
  488. func flp2(n int) int {
  489. res := 0
  490. for i := n; i >= 1; i-- {
  491. if (i & (i - 1)) == 0 {
  492. res = i
  493. break
  494. }
  495. }
  496. return res
  497. }
  498. // computeHashes computes the hashes under the node from which is called the
  499. // method. Returns an array of key-values to store in the db
  500. func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) (
  501. [][2][]byte, error) {
  502. if n == nil || currLvl >= maxLvl {
  503. // no need to compute any hash
  504. return pairs, nil
  505. }
  506. if pairs == nil {
  507. pairs = [][2][]byte{}
  508. }
  509. var err error
  510. t := n.typ()
  511. switch t {
  512. case vtLeaf:
  513. p.dbg.incHash()
  514. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  515. if err != nil {
  516. return pairs, err
  517. }
  518. n.h = leafKey
  519. kv := [2][]byte{leafKey, leafValue}
  520. pairs = append(pairs, kv)
  521. case vtMid:
  522. if n.l != nil {
  523. pairs, err = n.l.computeHashes(currLvl+1, maxLvl, p, pairs)
  524. if err != nil {
  525. return pairs, err
  526. }
  527. } else {
  528. n.l = &node{
  529. h: p.emptyHash,
  530. }
  531. }
  532. if n.r != nil {
  533. pairs, err = n.r.computeHashes(currLvl+1, maxLvl, p, pairs)
  534. if err != nil {
  535. return pairs, err
  536. }
  537. } else {
  538. n.r = &node{
  539. h: p.emptyHash,
  540. }
  541. }
  542. // once the sub nodes are computed, can compute the current node
  543. // hash
  544. p.dbg.incHash()
  545. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  546. if err != nil {
  547. return nil, err
  548. }
  549. n.h = k
  550. kv := [2][]byte{k, v}
  551. pairs = append(pairs, kv)
  552. case vtEmpty:
  553. default:
  554. return nil, fmt.Errorf("error: n.computeHashes type (%d) no match", t)
  555. }
  556. return pairs, nil
  557. }
  558. //nolint:unused
  559. func (t *vt) graphviz(w io.Writer) error {
  560. fmt.Fprintf(w, `digraph hierarchy {
  561. node [fontname=Monospace,fontsize=10,shape=box]
  562. `)
  563. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  564. return err
  565. }
  566. fmt.Fprintf(w, "}\n")
  567. return nil
  568. }
  569. //nolint:unused
  570. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  571. if n == nil {
  572. return nEmpties, nil
  573. }
  574. t := n.typ()
  575. switch t {
  576. case vtLeaf:
  577. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  578. if err != nil {
  579. return nEmpties, err
  580. }
  581. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  582. k := n.k
  583. v := n.v
  584. if len(n.k) >= nChars {
  585. k = n.k[:nChars]
  586. }
  587. if len(n.v) >= nChars {
  588. v = n.v[:nChars]
  589. }
  590. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  591. hex.EncodeToString(k),
  592. hex.EncodeToString(v))
  593. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  594. hex.EncodeToString(k),
  595. hex.EncodeToString(v))
  596. case vtMid:
  597. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  598. lStr := fmt.Sprintf("%p", n.l)
  599. rStr := fmt.Sprintf("%p", n.r)
  600. eStr := ""
  601. if n.l == nil {
  602. lStr = fmt.Sprintf("empty%v", nEmpties)
  603. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  604. lStr)
  605. nEmpties++
  606. }
  607. if n.r == nil {
  608. rStr = fmt.Sprintf("empty%v", nEmpties)
  609. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  610. rStr)
  611. nEmpties++
  612. }
  613. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  614. fmt.Fprint(w, eStr)
  615. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  616. if err != nil {
  617. return nEmpties, err
  618. }
  619. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  620. if err != nil {
  621. return nEmpties, err
  622. }
  623. case vtEmpty:
  624. default:
  625. return nEmpties, fmt.Errorf("ERR")
  626. }
  627. return nEmpties, nil
  628. }
  629. //nolint:unused
  630. func (t *vt) printGraphviz() error {
  631. w := bytes.NewBufferString("")
  632. fmt.Fprintf(w,
  633. "--------\nGraphviz:\n")
  634. err := t.graphviz(w)
  635. if err != nil {
  636. fmt.Println(w)
  637. return err
  638. }
  639. fmt.Fprintf(w,
  640. "End of Graphviz --------\n")
  641. fmt.Println(w)
  642. return nil
  643. }