You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

713 lines
16 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds a tree
// without computing any hash. Once all the leaves are placed in their
// positions, the hashes can be computed, so that each node hash is computed
// only once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. "math"
  12. "runtime"
  13. "sync"
  14. )
  15. type node struct {
  16. l *node
  17. r *node
  18. k []byte
  19. v []byte
  20. path []bool
  21. h []byte
  22. }
// params holds the settings shared by a virtual tree and its nodes.
type params struct {
	// maxLevels bounds the depth of the tree: node.add returns
	// ErrMaxVirtualLevel when asked to work beyond level maxLevels-1.
	maxLevels int
	// hashFunction is used to compute the leaf and intermediate node hashes.
	hashFunction HashFunction
	// emptyHash is the hash used for empty sub-trees (newVT sets it to
	// hashFunction.Len() zero bytes).
	emptyHash []byte
	// dbg collects debug statistics (computeHashes calls dbg.incHash for each
	// computed hash); newVT leaves it nil.
	dbg *dbgStats
}
  29. type kv struct {
  30. pos int // original position in the inputted array
  31. keyPath []byte
  32. k []byte
  33. v []byte
  34. }
  35. func keysValuesToKvs(maxLevels int, ks, vs [][]byte) ([]kv, []Invalid, error) {
  36. if len(ks) != len(vs) {
  37. return nil, nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  38. len(ks), len(vs))
  39. }
  40. var invalids []Invalid
  41. var kvs []kv
  42. for i := 0; i < len(ks); i++ {
  43. keyPath, err := keyPathFromKey(maxLevels, ks[i])
  44. if err != nil {
  45. invalids = append(invalids, Invalid{i, err})
  46. continue
  47. }
  48. var kvsI kv
  49. kvsI.pos = i
  50. kvsI.keyPath = keyPath
  51. kvsI.k = ks[i]
  52. kvsI.v = vs[i]
  53. kvs = append(kvs, kvsI)
  54. }
  55. return kvs, invalids, nil
  56. }
// vt stands for virtual tree. It's a tree that does not have any computed hash
// while the leaves are being placed. Once all the leaves are placed, it
// computes all the hashes. In this way, each node hash is computed only once
// (at the end), and the whole tree is built in memory.
type vt struct {
	root   *node   // root node; nil while the tree is empty (see newVT)
	params *params // settings shared with the nodes (levels, hash function, ...)
}
  65. func newVT(maxLevels int, hash HashFunction) vt {
  66. return vt{
  67. root: nil,
  68. params: &params{
  69. maxLevels: maxLevels,
  70. hashFunction: hash,
  71. emptyHash: make([]byte, hash.Len()), // empty
  72. },
  73. }
  74. }
  75. // addBatch adds a batch of key-values to the VirtualTree. Returns an array
  76. // containing the indexes of the keys failed to add. Does not include the
  77. // computation of hashes of the nodes neither the storage of the key-values of
  78. // the tree into the db. After addBatch, vt.computeHashes should be called to
  79. // compute the hashes of all the nodes of the tree.
  80. func (t *vt) addBatch(ks, vs [][]byte) ([]Invalid, error) {
  81. nCPU := flp2(runtime.NumCPU())
  82. if nCPU == 1 || len(ks) < nCPU {
  83. var invalids []Invalid
  84. for i := 0; i < len(ks); i++ {
  85. if err := t.add(0, ks[i], vs[i]); err != nil {
  86. invalids = append(invalids, Invalid{i, err})
  87. }
  88. }
  89. return invalids, nil
  90. }
  91. l := int(math.Log2(float64(nCPU)))
  92. kvs, invalids, err := keysValuesToKvs(t.params.maxLevels, ks, vs)
  93. if err != nil {
  94. return invalids, err
  95. }
  96. buckets := splitInBuckets(kvs, nCPU)
  97. nodesAtL, err := t.getNodesAtLevel(l)
  98. if err != nil {
  99. return nil, err
  100. }
  101. if len(nodesAtL) != nCPU && t.root != nil {
  102. /*
  103. Already populated Tree but Unbalanced
  104. - Need to fill M1 and M2, and then will be able to continue with the flow
  105. - Search for M1 & M2 in the inputed Keys
  106. - Add M1 & M2 to the Tree
  107. - From here can continue with the flow
  108. R
  109. / \
  110. / \
  111. / \
  112. * *
  113. | \
  114. | \
  115. | \
  116. L: M1 * M2 * (where M1 and M2 are empty)
  117. / | /
  118. / | /
  119. / | /
  120. A * *
  121. / \ | \
  122. / \ | \
  123. / \ | \
  124. B * * C
  125. / \ |\
  126. ... ... | \
  127. | \
  128. D E
  129. */
  130. // add one key at each bucket, and then continue with the flow
  131. for i := 0; i < len(buckets); i++ {
  132. // add one leaf of the bucket, if there is an error when
  133. // adding the k-v, try to add the next one of the bucket
  134. // (until one is added)
  135. inserted := -1
  136. for j := 0; j < len(buckets[i]); j++ {
  137. if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
  138. inserted = j
  139. break
  140. }
  141. }
  142. // remove the inserted element from buckets[i]
  143. if inserted != -1 {
  144. buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
  145. }
  146. }
  147. nodesAtL, err = t.getNodesAtLevel(l)
  148. if err != nil {
  149. return nil, err
  150. }
  151. }
  152. if len(nodesAtL) != nCPU {
  153. return nil, fmt.Errorf("This error should not be reached."+
  154. " len(nodesAtL) != nCPU, len(nodesAtL)=%d, nCPU=%d."+
  155. " Please report it in a new issue:"+
  156. " https://github.com/vocdoni/arbo/issues/new", len(nodesAtL), nCPU)
  157. }
  158. subRoots := make([]*node, nCPU)
  159. invalidsInBucket := make([][]Invalid, nCPU)
  160. var wg sync.WaitGroup
  161. wg.Add(nCPU)
  162. for i := 0; i < nCPU; i++ {
  163. go func(cpu int) {
  164. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  165. bucketVT.root = nodesAtL[cpu]
  166. for j := 0; j < len(buckets[cpu]); j++ {
  167. if err := bucketVT.add(l, buckets[cpu][j].k,
  168. buckets[cpu][j].v); err != nil {
  169. invalidsInBucket[cpu] = append(invalidsInBucket[cpu],
  170. Invalid{buckets[cpu][j].pos, err})
  171. }
  172. }
  173. subRoots[cpu] = bucketVT.root
  174. wg.Done()
  175. }(i)
  176. }
  177. wg.Wait()
  178. for i := 0; i < len(invalidsInBucket); i++ {
  179. invalids = append(invalids, invalidsInBucket[i]...)
  180. }
  181. newRootNode, err := upFromNodes(subRoots)
  182. if err != nil {
  183. return nil, err
  184. }
  185. t.root = newRootNode
  186. return invalids, nil
  187. }
  188. func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
  189. if t.root == nil {
  190. var r []*node
  191. nChilds := int(math.Pow(2, float64(l))) //nolint:gomnd
  192. for i := 0; i < nChilds; i++ {
  193. r = append(r, nil)
  194. }
  195. return r, nil
  196. }
  197. return t.root.getNodesAtLevel(0, l)
  198. }
  199. func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
  200. if n == nil {
  201. var r []*node
  202. nChilds := int(math.Pow(2, float64(l-currLvl))) //nolint:gomnd
  203. for i := 0; i < nChilds; i++ {
  204. r = append(r, nil)
  205. }
  206. return r, nil
  207. }
  208. typ := n.typ()
  209. if currLvl == l && typ != vtEmpty {
  210. return []*node{n}, nil
  211. }
  212. if currLvl >= l {
  213. return nil, fmt.Errorf("This error should not be reached."+
  214. " currLvl >= l, currLvl=%d, l=%d."+
  215. " Please report it in a new issue:"+
  216. " https://github.com/vocdoni/arbo/issues/new", currLvl, l)
  217. }
  218. var nodes []*node
  219. nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
  220. if err != nil {
  221. return nil, err
  222. }
  223. nodes = append(nodes, nodesL...)
  224. nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
  225. if err != nil {
  226. return nil, err
  227. }
  228. nodes = append(nodes, nodesR...)
  229. return nodes, nil
  230. }
  231. // upFromNodes builds the tree from the bottom to up
  232. func upFromNodes(ns []*node) (*node, error) {
  233. if len(ns) == 1 {
  234. return ns[0], nil
  235. }
  236. var res []*node
  237. for i := 0; i < len(ns); i += 2 {
  238. if (ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty) ||
  239. (ns[i].typ() == vtLeaf && ns[i+1].typ() == vtEmpty) {
  240. // when both sub nodes are empty, the parent is also empty
  241. // or
  242. // when 1st sub node is a leaf but the 2nd is empty, the
  243. // leaf is used as parent
  244. res = append(res, ns[i])
  245. continue
  246. }
  247. if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtLeaf {
  248. // when 2nd sub node is a leaf but the 1st is empty, the
  249. // leaf is used as 'parent'
  250. res = append(res, ns[i+1])
  251. continue
  252. }
  253. n := &node{
  254. l: ns[i],
  255. r: ns[i+1],
  256. }
  257. res = append(res, n)
  258. }
  259. return upFromNodes(res)
  260. }
  261. // add adds a key&value as a leaf in the VirtualTree
  262. func (t *vt) add(fromLvl int, k, v []byte) error {
  263. leaf, err := newLeafNode(t.params, k, v)
  264. if err != nil {
  265. return err
  266. }
  267. if t.root == nil {
  268. t.root = leaf
  269. return nil
  270. }
  271. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  272. return err
  273. }
  274. return nil
  275. }
  276. // computeHashes should be called after all the vt.add is used, once all the
  277. // leafs are in the tree. Computes the hashes of the tree, parallelizing in the
  278. // available CPUs.
  279. func (t *vt) computeHashes() ([][2][]byte, error) {
  280. var err error
  281. nCPU := flp2(runtime.NumCPU())
  282. l := int(math.Log2(float64(nCPU)))
  283. nodesAtL, err := t.getNodesAtLevel(l)
  284. if err != nil {
  285. return nil, err
  286. }
  287. subRoots := make([]*node, nCPU)
  288. bucketPairs := make([][][2][]byte, nCPU)
  289. dbgStatsPerBucket := make([]*dbgStats, nCPU)
  290. errs := make([]error, nCPU)
  291. var wg sync.WaitGroup
  292. wg.Add(nCPU)
  293. for i := 0; i < nCPU; i++ {
  294. go func(cpu int) {
  295. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  296. bucketVT.params.dbg = newDbgStats()
  297. bucketVT.root = nodesAtL[cpu]
  298. var err error
  299. bucketPairs[cpu], err = bucketVT.root.computeHashes(l-1,
  300. t.params.maxLevels, bucketVT.params, bucketPairs[cpu])
  301. if err != nil {
  302. errs[cpu] = err
  303. }
  304. subRoots[cpu] = bucketVT.root
  305. dbgStatsPerBucket[cpu] = bucketVT.params.dbg
  306. wg.Done()
  307. }(i)
  308. }
  309. wg.Wait()
  310. for i := 0; i < len(errs); i++ {
  311. if errs[i] != nil {
  312. return nil, errs[i]
  313. }
  314. }
  315. for i := 0; i < len(dbgStatsPerBucket); i++ {
  316. t.params.dbg.add(dbgStatsPerBucket[i])
  317. }
  318. var pairs [][2][]byte
  319. for i := 0; i < len(bucketPairs); i++ {
  320. pairs = append(pairs, bucketPairs[i]...)
  321. }
  322. nodesAtL, err = t.getNodesAtLevel(l)
  323. if err != nil {
  324. return nil, err
  325. }
  326. for i := 0; i < len(nodesAtL); i++ {
  327. nodesAtL = subRoots
  328. }
  329. pairs, err = t.root.computeHashes(0, l, t.params, pairs)
  330. if err != nil {
  331. return nil, err
  332. }
  333. return pairs, nil
  334. }
  335. func newLeafNode(p *params, k, v []byte) (*node, error) {
  336. keyPath, err := keyPathFromKey(p.maxLevels, k)
  337. if err != nil {
  338. return nil, err
  339. }
  340. path := getPath(p.maxLevels, keyPath)
  341. n := &node{
  342. k: k,
  343. v: v,
  344. path: path,
  345. }
  346. return n, nil
  347. }
  348. type virtualNodeType int
  349. const (
  350. vtEmpty = 0 // for convenience uses same value that PrefixValueEmpty
  351. vtLeaf = 1 // for convenience uses same value that PrefixValueLeaf
  352. vtMid = 2 // for convenience uses same value that PrefixValueIntermediate
  353. )
  354. func (n *node) typ() virtualNodeType {
  355. if n == nil {
  356. return vtEmpty
  357. }
  358. if n.l == nil && n.r == nil && n.k != nil {
  359. return vtLeaf
  360. }
  361. if n.l != nil || n.r != nil {
  362. return vtMid
  363. }
  364. return vtEmpty
  365. }
  366. func (n *node) add(p *params, currLvl int, leaf *node) error {
  367. if currLvl > p.maxLevels-1 {
  368. return ErrMaxVirtualLevel
  369. }
  370. if n == nil {
  371. // n = leaf // TMP!
  372. return nil
  373. }
  374. t := n.typ()
  375. switch t {
  376. case vtMid:
  377. if leaf.path[currLvl] {
  378. //right
  379. if n.r == nil {
  380. // empty sub-node, add the leaf here
  381. n.r = leaf
  382. return nil
  383. }
  384. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  385. return err
  386. }
  387. } else {
  388. if n.l == nil {
  389. // empty sub-node, add the leaf here
  390. n.l = leaf
  391. return nil
  392. }
  393. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  394. return err
  395. }
  396. }
  397. case vtLeaf:
  398. if bytes.Equal(n.k, leaf.k) {
  399. return fmt.Errorf("%s. Existing node: %s, trying to add node: %s",
  400. ErrKeyAlreadyExists, hex.EncodeToString(n.k),
  401. hex.EncodeToString(leaf.k))
  402. }
  403. oldLeaf := &node{
  404. k: n.k,
  405. v: n.v,
  406. path: n.path,
  407. }
  408. // remove values from current node (converting it to mid node)
  409. n.k = nil
  410. n.v = nil
  411. n.h = nil
  412. n.path = nil
  413. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  414. return err
  415. }
  416. case vtEmpty:
  417. return fmt.Errorf("virtual tree node.add() with empty node %v", n)
  418. default:
  419. return fmt.Errorf("virtual tree node.add() with unknown node type %v", n)
  420. }
  421. return nil
  422. }
  423. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  424. if currLvl > p.maxLevels-1 {
  425. return ErrMaxVirtualLevel
  426. }
  427. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  428. // reached divergence in next level
  429. if newLeaf.path[currLvl] {
  430. n.l = oldLeaf
  431. n.r = newLeaf
  432. } else {
  433. n.l = newLeaf
  434. n.r = oldLeaf
  435. }
  436. return nil
  437. }
  438. // no divergence yet, continue going down
  439. if newLeaf.path[currLvl] {
  440. // right
  441. n.r = &node{}
  442. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  443. return err
  444. }
  445. } else {
  446. // left
  447. n.l = &node{}
  448. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  449. return err
  450. }
  451. }
  452. return nil
  453. }
  454. func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
  455. buckets := make([][]kv, nBuckets)
  456. // 1. classify the keyvalues into buckets
  457. for i := 0; i < len(kvs); i++ {
  458. pair := kvs[i]
  459. // bucketnum := keyToBucket(pair.k, nBuckets)
  460. bucketnum := keyToBucket(pair.keyPath, nBuckets)
  461. buckets[bucketnum] = append(buckets[bucketnum], pair)
  462. }
  463. return buckets
  464. }
  465. // TODO rename in a more 'real' name (calculate bucket from/for key)
  466. func keyToBucket(k []byte, nBuckets int) int {
  467. nLevels := int(math.Log2(float64(nBuckets)))
  468. b := make([]int, nBuckets)
  469. for i := 0; i < nBuckets; i++ {
  470. b[i] = i
  471. }
  472. r := b
  473. mid := len(r) / 2 //nolint:gomnd
  474. for i := 0; i < nLevels; i++ {
  475. if int(k[i/8]&(1<<(i%8))) != 0 {
  476. r = r[mid:]
  477. mid = len(r) / 2 //nolint:gomnd
  478. } else {
  479. r = r[:mid]
  480. mid = len(r) / 2 //nolint:gomnd
  481. }
  482. }
  483. return r[0]
  484. }
  485. // flp2 computes the floor power of 2, the highest power of 2 under the given
  486. // value.
  487. func flp2(n int) int {
  488. res := 0
  489. for i := n; i >= 1; i-- {
  490. if (i & (i - 1)) == 0 {
  491. res = i
  492. break
  493. }
  494. }
  495. return res
  496. }
  497. // computeHashes computes the hashes under the node from which is called the
  498. // method. Returns an array of key-values to store in the db
  499. func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) (
  500. [][2][]byte, error) {
  501. if n == nil || currLvl >= maxLvl {
  502. // no need to compute any hash
  503. return pairs, nil
  504. }
  505. if pairs == nil {
  506. pairs = [][2][]byte{}
  507. }
  508. var err error
  509. t := n.typ()
  510. switch t {
  511. case vtLeaf:
  512. p.dbg.incHash()
  513. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  514. if err != nil {
  515. return pairs, err
  516. }
  517. n.h = leafKey
  518. kv := [2][]byte{leafKey, leafValue}
  519. pairs = append(pairs, kv)
  520. case vtMid:
  521. if n.l != nil {
  522. pairs, err = n.l.computeHashes(currLvl+1, maxLvl, p, pairs)
  523. if err != nil {
  524. return pairs, err
  525. }
  526. } else {
  527. n.l = &node{
  528. h: p.emptyHash,
  529. }
  530. }
  531. if n.r != nil {
  532. pairs, err = n.r.computeHashes(currLvl+1, maxLvl, p, pairs)
  533. if err != nil {
  534. return pairs, err
  535. }
  536. } else {
  537. n.r = &node{
  538. h: p.emptyHash,
  539. }
  540. }
  541. // once the sub nodes are computed, can compute the current node
  542. // hash
  543. p.dbg.incHash()
  544. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  545. if err != nil {
  546. return nil, err
  547. }
  548. n.h = k
  549. kv := [2][]byte{k, v}
  550. pairs = append(pairs, kv)
  551. case vtEmpty:
  552. default:
  553. return nil, fmt.Errorf("error: n.computeHashes type (%d) no match", t)
  554. }
  555. return pairs, nil
  556. }
  557. //nolint:unused
  558. func (t *vt) graphviz(w io.Writer) error {
  559. fmt.Fprintf(w, `digraph hierarchy {
  560. node [fontname=Monospace,fontsize=10,shape=box]
  561. `)
  562. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  563. return err
  564. }
  565. fmt.Fprintf(w, "}\n")
  566. return nil
  567. }
  568. //nolint:unused
  569. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  570. if n == nil {
  571. return nEmpties, nil
  572. }
  573. t := n.typ()
  574. switch t {
  575. case vtLeaf:
  576. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  577. if err != nil {
  578. return nEmpties, err
  579. }
  580. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  581. k := n.k
  582. v := n.v
  583. if len(n.k) >= nChars {
  584. k = n.k[:nChars]
  585. }
  586. if len(n.v) >= nChars {
  587. v = n.v[:nChars]
  588. }
  589. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  590. hex.EncodeToString(k),
  591. hex.EncodeToString(v))
  592. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  593. hex.EncodeToString(k),
  594. hex.EncodeToString(v))
  595. case vtMid:
  596. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  597. lStr := fmt.Sprintf("%p", n.l)
  598. rStr := fmt.Sprintf("%p", n.r)
  599. eStr := ""
  600. if n.l == nil {
  601. lStr = fmt.Sprintf("empty%v", nEmpties)
  602. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  603. lStr)
  604. nEmpties++
  605. }
  606. if n.r == nil {
  607. rStr = fmt.Sprintf("empty%v", nEmpties)
  608. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  609. rStr)
  610. nEmpties++
  611. }
  612. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  613. fmt.Fprint(w, eStr)
  614. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  615. if err != nil {
  616. return nEmpties, err
  617. }
  618. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  619. if err != nil {
  620. return nEmpties, err
  621. }
  622. case vtEmpty:
  623. default:
  624. return nEmpties, fmt.Errorf("ERR")
  625. }
  626. return nEmpties, nil
  627. }
  628. //nolint:unused
  629. func (t *vt) printGraphviz() error {
  630. w := bytes.NewBufferString("")
  631. fmt.Fprintf(w,
  632. "--------\nGraphviz:\n")
  633. err := t.graphviz(w)
  634. if err != nil {
  635. fmt.Println(w)
  636. return err
  637. }
  638. fmt.Fprintf(w,
  639. "End of Graphviz --------\n")
  640. fmt.Println(w)
  641. return nil
  642. }