You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

700 lines
16 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds the tree
// structure without computing any hash. The idea is that once all the leaves
// are placed in their positions, the hashes can be computed in a single pass,
// so that no node hash is computed more than once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. "math"
  12. "runtime"
  13. "sync"
  14. )
  15. type node struct {
  16. l *node
  17. r *node
  18. k []byte
  19. v []byte
  20. path []bool
  21. h []byte
  22. }
  23. type params struct {
  24. maxLevels int
  25. hashFunction HashFunction
  26. emptyHash []byte
  27. dbg *dbgStats
  28. }
  29. type kv struct {
  30. pos int // original position in the inputted array
  31. keyPath []byte
  32. k []byte
  33. v []byte
  34. }
  35. func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
  36. if len(ks) != len(vs) {
  37. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  38. len(ks), len(vs))
  39. }
  40. kvs := make([]kv, len(ks))
  41. for i := 0; i < len(ks); i++ {
  42. keyPath := make([]byte, p.hashFunction.Len())
  43. copy(keyPath[:], ks[i])
  44. kvs[i].pos = i
  45. kvs[i].keyPath = keyPath
  46. kvs[i].k = ks[i]
  47. kvs[i].v = vs[i]
  48. }
  49. return kvs, nil
  50. }
  51. // vt stands for virtual tree. It's a tree that does not have any computed hash
  52. // while placing the leafs. Once all the leafs are placed, it computes all the
  53. // hashes. In this way, each node hash is only computed one time (at the end)
  54. // and the tree is computed in memory.
  55. type vt struct {
  56. root *node
  57. params *params
  58. }
  59. func newVT(maxLevels int, hash HashFunction) vt {
  60. return vt{
  61. root: nil,
  62. params: &params{
  63. maxLevels: maxLevels,
  64. hashFunction: hash,
  65. emptyHash: make([]byte, hash.Len()), // empty
  66. },
  67. }
  68. }
  69. // addBatch adds a batch of key-values to the VirtualTree. Returns an array
  70. // containing the indexes of the keys failed to add. Does not include the
  71. // computation of hashes of the nodes neither the storage of the key-values of
  72. // the tree into the db. After addBatch, vt.computeHashes should be called to
  73. // compute the hashes of all the nodes of the tree.
  74. func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {
  75. nCPU := flp2(runtime.NumCPU())
  76. if nCPU == 1 || len(ks) < nCPU {
  77. var invalids []int
  78. for i := 0; i < len(ks); i++ {
  79. if err := t.add(0, ks[i], vs[i]); err != nil {
  80. invalids = append(invalids, i)
  81. }
  82. }
  83. return invalids, nil
  84. }
  85. l := int(math.Log2(float64(nCPU)))
  86. kvs, err := t.params.keysValuesToKvs(ks, vs)
  87. if err != nil {
  88. return nil, err
  89. }
  90. buckets := splitInBuckets(kvs, nCPU)
  91. nodesAtL, err := t.getNodesAtLevel(l)
  92. if err != nil {
  93. return nil, err
  94. }
  95. if len(nodesAtL) != nCPU && t.root != nil {
  96. /*
  97. Already populated Tree but Unbalanced
  98. - Need to fill M1 and M2, and then will be able to continue with the flow
  99. - Search for M1 & M2 in the inputed Keys
  100. - Add M1 & M2 to the Tree
  101. - From here can continue with the flow
  102. R
  103. / \
  104. / \
  105. / \
  106. * *
  107. | \
  108. | \
  109. | \
  110. L: M1 * M2 * (where M1 and M2 are empty)
  111. / | /
  112. / | /
  113. / | /
  114. A * *
  115. / \ | \
  116. / \ | \
  117. / \ | \
  118. B * * C
  119. / \ |\
  120. ... ... | \
  121. | \
  122. D E
  123. */
  124. // add one key at each bucket, and then continue with the flow
  125. for i := 0; i < len(buckets); i++ {
  126. // add one leaf of the bucket, if there is an error when
  127. // adding the k-v, try to add the next one of the bucket
  128. // (until one is added)
  129. inserted := -1
  130. for j := 0; j < len(buckets[i]); j++ {
  131. if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
  132. inserted = j
  133. break
  134. }
  135. }
  136. // remove the inserted element from buckets[i]
  137. if inserted != -1 {
  138. buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
  139. }
  140. }
  141. nodesAtL, err = t.getNodesAtLevel(l)
  142. if err != nil {
  143. return nil, err
  144. }
  145. }
  146. if len(nodesAtL) != nCPU {
  147. return nil, fmt.Errorf("This error should not be reached."+
  148. " len(nodesAtL) != nCPU, len(nodesAtL)=%d, nCPU=%d."+
  149. " Please report it in a new issue:"+
  150. " https://github.com/vocdoni/arbo/issues/new", len(nodesAtL), nCPU)
  151. }
  152. subRoots := make([]*node, nCPU)
  153. invalidsInBucket := make([][]int, nCPU)
  154. var wg sync.WaitGroup
  155. wg.Add(nCPU)
  156. for i := 0; i < nCPU; i++ {
  157. go func(cpu int) {
  158. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  159. bucketVT.root = nodesAtL[cpu]
  160. for j := 0; j < len(buckets[cpu]); j++ {
  161. if err := bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
  162. invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
  163. }
  164. }
  165. subRoots[cpu] = bucketVT.root
  166. wg.Done()
  167. }(i)
  168. }
  169. wg.Wait()
  170. var invalids []int
  171. for i := 0; i < len(invalidsInBucket); i++ {
  172. invalids = append(invalids, invalidsInBucket[i]...)
  173. }
  174. newRootNode, err := upFromNodes(subRoots)
  175. if err != nil {
  176. return nil, err
  177. }
  178. t.root = newRootNode
  179. return invalids, nil
  180. }
  181. func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
  182. if t.root == nil {
  183. var r []*node
  184. nChilds := int(math.Pow(2, float64(l))) //nolint:gomnd
  185. for i := 0; i < nChilds; i++ {
  186. r = append(r, nil)
  187. }
  188. return r, nil
  189. }
  190. return t.root.getNodesAtLevel(0, l)
  191. }
  192. func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
  193. if n == nil {
  194. var r []*node
  195. nChilds := int(math.Pow(2, float64(l-currLvl))) //nolint:gomnd
  196. for i := 0; i < nChilds; i++ {
  197. r = append(r, nil)
  198. }
  199. return r, nil
  200. }
  201. typ := n.typ()
  202. if currLvl == l && typ != vtEmpty {
  203. return []*node{n}, nil
  204. }
  205. if currLvl >= l {
  206. return nil, fmt.Errorf("This error should not be reached."+
  207. " currLvl >= l, currLvl=%d, l=%d."+
  208. " Please report it in a new issue:"+
  209. " https://github.com/vocdoni/arbo/issues/new", currLvl, l)
  210. }
  211. var nodes []*node
  212. nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
  213. if err != nil {
  214. return nil, err
  215. }
  216. nodes = append(nodes, nodesL...)
  217. nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
  218. if err != nil {
  219. return nil, err
  220. }
  221. nodes = append(nodes, nodesR...)
  222. return nodes, nil
  223. }
  224. // upFromNodes builds the tree from the bottom to up
  225. func upFromNodes(ns []*node) (*node, error) {
  226. if len(ns) == 1 {
  227. return ns[0], nil
  228. }
  229. var res []*node
  230. for i := 0; i < len(ns); i += 2 {
  231. if (ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty) ||
  232. (ns[i].typ() == vtLeaf && ns[i+1].typ() == vtEmpty) {
  233. // when both sub nodes are empty, the parent is also empty
  234. // or
  235. // when 1st sub node is a leaf but the 2nd is empty, the
  236. // leaf is used as parent
  237. res = append(res, ns[i])
  238. continue
  239. }
  240. if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtLeaf {
  241. // when 2nd sub node is a leaf but the 1st is empty, the
  242. // leaf is used as 'parent'
  243. res = append(res, ns[i+1])
  244. continue
  245. }
  246. n := &node{
  247. l: ns[i],
  248. r: ns[i+1],
  249. }
  250. res = append(res, n)
  251. }
  252. return upFromNodes(res)
  253. }
  254. // add adds a key&value as a leaf in the VirtualTree
  255. func (t *vt) add(fromLvl int, k, v []byte) error {
  256. leaf := newLeafNode(t.params, k, v)
  257. if t.root == nil {
  258. t.root = leaf
  259. return nil
  260. }
  261. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  262. return err
  263. }
  264. return nil
  265. }
  266. // computeHashes should be called after all the vt.add is used, once all the
  267. // leafs are in the tree. Computes the hashes of the tree, parallelizing in the
  268. // available CPUs.
  269. func (t *vt) computeHashes() ([][2][]byte, error) {
  270. var err error
  271. nCPU := flp2(runtime.NumCPU())
  272. l := int(math.Log2(float64(nCPU)))
  273. nodesAtL, err := t.getNodesAtLevel(l)
  274. if err != nil {
  275. return nil, err
  276. }
  277. subRoots := make([]*node, nCPU)
  278. bucketPairs := make([][][2][]byte, nCPU)
  279. dbgStatsPerBucket := make([]*dbgStats, nCPU)
  280. errs := make([]error, nCPU)
  281. var wg sync.WaitGroup
  282. wg.Add(nCPU)
  283. for i := 0; i < nCPU; i++ {
  284. go func(cpu int) {
  285. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  286. bucketVT.params.dbg = newDbgStats()
  287. bucketVT.root = nodesAtL[cpu]
  288. var err error
  289. bucketPairs[cpu], err = bucketVT.root.computeHashes(l-1,
  290. t.params.maxLevels, bucketVT.params, bucketPairs[cpu])
  291. if err != nil {
  292. errs[cpu] = err
  293. }
  294. subRoots[cpu] = bucketVT.root
  295. dbgStatsPerBucket[cpu] = bucketVT.params.dbg
  296. wg.Done()
  297. }(i)
  298. }
  299. wg.Wait()
  300. for i := 0; i < len(errs); i++ {
  301. if errs[i] != nil {
  302. return nil, errs[i]
  303. }
  304. }
  305. for i := 0; i < len(dbgStatsPerBucket); i++ {
  306. t.params.dbg.add(dbgStatsPerBucket[i])
  307. }
  308. var pairs [][2][]byte
  309. for i := 0; i < len(bucketPairs); i++ {
  310. pairs = append(pairs, bucketPairs[i]...)
  311. }
  312. nodesAtL, err = t.getNodesAtLevel(l)
  313. if err != nil {
  314. return nil, err
  315. }
  316. for i := 0; i < len(nodesAtL); i++ {
  317. nodesAtL = subRoots
  318. }
  319. pairs, err = t.root.computeHashes(0, l, t.params, pairs)
  320. if err != nil {
  321. return nil, err
  322. }
  323. return pairs, nil
  324. }
  325. func newLeafNode(p *params, k, v []byte) *node {
  326. keyPath := make([]byte, p.hashFunction.Len())
  327. copy(keyPath[:], k)
  328. path := getPath(p.maxLevels, keyPath)
  329. n := &node{
  330. k: k,
  331. v: v,
  332. path: path,
  333. }
  334. return n
  335. }
  336. type virtualNodeType int
  337. const (
  338. vtEmpty = 0 // for convenience uses same value that PrefixValueEmpty
  339. vtLeaf = 1 // for convenience uses same value that PrefixValueLeaf
  340. vtMid = 2 // for convenience uses same value that PrefixValueIntermediate
  341. )
  342. func (n *node) typ() virtualNodeType {
  343. if n == nil {
  344. return vtEmpty
  345. }
  346. if n.l == nil && n.r == nil && n.k != nil {
  347. return vtLeaf
  348. }
  349. if n.l != nil || n.r != nil {
  350. return vtMid
  351. }
  352. return vtEmpty
  353. }
  354. func (n *node) add(p *params, currLvl int, leaf *node) error {
  355. if currLvl > p.maxLevels-1 {
  356. return ErrMaxVirtualLevel
  357. }
  358. if n == nil {
  359. // n = leaf // TMP!
  360. return nil
  361. }
  362. t := n.typ()
  363. switch t {
  364. case vtMid:
  365. if leaf.path[currLvl] {
  366. //right
  367. if n.r == nil {
  368. // empty sub-node, add the leaf here
  369. n.r = leaf
  370. return nil
  371. }
  372. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  373. return err
  374. }
  375. } else {
  376. if n.l == nil {
  377. // empty sub-node, add the leaf here
  378. n.l = leaf
  379. return nil
  380. }
  381. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  382. return err
  383. }
  384. }
  385. case vtLeaf:
  386. if bytes.Equal(n.k, leaf.k) {
  387. return fmt.Errorf("%s. Existing node: %s, trying to add node: %s",
  388. ErrKeyAlreadyExists, hex.EncodeToString(n.k),
  389. hex.EncodeToString(leaf.k))
  390. }
  391. oldLeaf := &node{
  392. k: n.k,
  393. v: n.v,
  394. path: n.path,
  395. }
  396. // remove values from current node (converting it to mid node)
  397. n.k = nil
  398. n.v = nil
  399. n.h = nil
  400. n.path = nil
  401. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  402. return err
  403. }
  404. case vtEmpty:
  405. return fmt.Errorf("virtual tree node.add() with empty node %v", n)
  406. default:
  407. return fmt.Errorf("virtual tree node.add() with unknown node type %v", n)
  408. }
  409. return nil
  410. }
  411. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  412. if currLvl > p.maxLevels-1 {
  413. return ErrMaxVirtualLevel
  414. }
  415. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  416. // reached divergence in next level
  417. if newLeaf.path[currLvl] {
  418. n.l = oldLeaf
  419. n.r = newLeaf
  420. } else {
  421. n.l = newLeaf
  422. n.r = oldLeaf
  423. }
  424. return nil
  425. }
  426. // no divergence yet, continue going down
  427. if newLeaf.path[currLvl] {
  428. // right
  429. n.r = &node{}
  430. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  431. return err
  432. }
  433. } else {
  434. // left
  435. n.l = &node{}
  436. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  437. return err
  438. }
  439. }
  440. return nil
  441. }
  442. func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
  443. buckets := make([][]kv, nBuckets)
  444. // 1. classify the keyvalues into buckets
  445. for i := 0; i < len(kvs); i++ {
  446. pair := kvs[i]
  447. // bucketnum := keyToBucket(pair.k, nBuckets)
  448. bucketnum := keyToBucket(pair.keyPath, nBuckets)
  449. buckets[bucketnum] = append(buckets[bucketnum], pair)
  450. }
  451. return buckets
  452. }
  453. // TODO rename in a more 'real' name (calculate bucket from/for key)
  454. func keyToBucket(k []byte, nBuckets int) int {
  455. nLevels := int(math.Log2(float64(nBuckets)))
  456. b := make([]int, nBuckets)
  457. for i := 0; i < nBuckets; i++ {
  458. b[i] = i
  459. }
  460. r := b
  461. mid := len(r) / 2 //nolint:gomnd
  462. for i := 0; i < nLevels; i++ {
  463. if int(k[i/8]&(1<<(i%8))) != 0 {
  464. r = r[mid:]
  465. mid = len(r) / 2 //nolint:gomnd
  466. } else {
  467. r = r[:mid]
  468. mid = len(r) / 2 //nolint:gomnd
  469. }
  470. }
  471. return r[0]
  472. }
  473. // flp2 computes the floor power of 2, the highest power of 2 under the given
  474. // value.
  475. func flp2(n int) int {
  476. res := 0
  477. for i := n; i >= 1; i-- {
  478. if (i & (i - 1)) == 0 {
  479. res = i
  480. break
  481. }
  482. }
  483. return res
  484. }
  485. // computeHashes computes the hashes under the node from which is called the
  486. // method. Returns an array of key-values to store in the db
  487. func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) (
  488. [][2][]byte, error) {
  489. if n == nil || currLvl >= maxLvl {
  490. // no need to compute any hash
  491. return pairs, nil
  492. }
  493. if pairs == nil {
  494. pairs = [][2][]byte{}
  495. }
  496. var err error
  497. t := n.typ()
  498. switch t {
  499. case vtLeaf:
  500. p.dbg.incHash()
  501. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  502. if err != nil {
  503. return pairs, err
  504. }
  505. n.h = leafKey
  506. kv := [2][]byte{leafKey, leafValue}
  507. pairs = append(pairs, kv)
  508. case vtMid:
  509. if n.l != nil {
  510. pairs, err = n.l.computeHashes(currLvl+1, maxLvl, p, pairs)
  511. if err != nil {
  512. return pairs, err
  513. }
  514. } else {
  515. n.l = &node{
  516. h: p.emptyHash,
  517. }
  518. }
  519. if n.r != nil {
  520. pairs, err = n.r.computeHashes(currLvl+1, maxLvl, p, pairs)
  521. if err != nil {
  522. return pairs, err
  523. }
  524. } else {
  525. n.r = &node{
  526. h: p.emptyHash,
  527. }
  528. }
  529. // once the sub nodes are computed, can compute the current node
  530. // hash
  531. p.dbg.incHash()
  532. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  533. if err != nil {
  534. return nil, err
  535. }
  536. n.h = k
  537. kv := [2][]byte{k, v}
  538. pairs = append(pairs, kv)
  539. case vtEmpty:
  540. default:
  541. return nil, fmt.Errorf("error: n.computeHashes type (%d) no match", t)
  542. }
  543. return pairs, nil
  544. }
  545. //nolint:unused
  546. func (t *vt) graphviz(w io.Writer) error {
  547. fmt.Fprintf(w, `digraph hierarchy {
  548. node [fontname=Monospace,fontsize=10,shape=box]
  549. `)
  550. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  551. return err
  552. }
  553. fmt.Fprintf(w, "}\n")
  554. return nil
  555. }
  556. //nolint:unused
  557. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  558. if n == nil {
  559. return nEmpties, nil
  560. }
  561. t := n.typ()
  562. switch t {
  563. case vtLeaf:
  564. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  565. if err != nil {
  566. return nEmpties, err
  567. }
  568. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  569. k := n.k
  570. v := n.v
  571. if len(n.k) >= nChars {
  572. k = n.k[:nChars]
  573. }
  574. if len(n.v) >= nChars {
  575. v = n.v[:nChars]
  576. }
  577. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  578. hex.EncodeToString(k),
  579. hex.EncodeToString(v))
  580. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  581. hex.EncodeToString(k),
  582. hex.EncodeToString(v))
  583. case vtMid:
  584. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  585. lStr := fmt.Sprintf("%p", n.l)
  586. rStr := fmt.Sprintf("%p", n.r)
  587. eStr := ""
  588. if n.l == nil {
  589. lStr = fmt.Sprintf("empty%v", nEmpties)
  590. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  591. lStr)
  592. nEmpties++
  593. }
  594. if n.r == nil {
  595. rStr = fmt.Sprintf("empty%v", nEmpties)
  596. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  597. rStr)
  598. nEmpties++
  599. }
  600. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  601. fmt.Fprint(w, eStr)
  602. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  603. if err != nil {
  604. return nEmpties, err
  605. }
  606. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  607. if err != nil {
  608. return nEmpties, err
  609. }
  610. case vtEmpty:
  611. default:
  612. return nEmpties, fmt.Errorf("ERR")
  613. }
  614. return nEmpties, nil
  615. }
  616. //nolint:unused
  617. func (t *vt) printGraphviz() error {
  618. w := bytes.NewBufferString("")
  619. fmt.Fprintf(w,
  620. "--------\nGraphviz:\n")
  621. err := t.graphviz(w)
  622. if err != nil {
  623. fmt.Println(w)
  624. return err
  625. }
  626. fmt.Fprintf(w,
  627. "End of Graphviz --------\n")
  628. fmt.Println(w)
  629. return nil
  630. }