You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

721 lines
16 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds a tree
// without computing any hash. Once all the leaves are placed in their
// positions, the hashes are computed, so that no node hash is computed more
// than once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. "math"
  12. "runtime"
  13. "sync"
  14. )
  15. type node struct {
  16. l *node
  17. r *node
  18. k []byte
  19. v []byte
  20. path []bool
  21. h []byte
  22. }
  23. type params struct {
  24. maxLevels int
  25. hashFunction HashFunction
  26. emptyHash []byte
  27. dbg *dbgStats
  28. }
  29. type kv struct {
  30. pos int // original position in the inputted array
  31. keyPath []byte
  32. k []byte
  33. v []byte
  34. }
  35. func keysValuesToKvs(maxLevels int, ks, vs [][]byte) ([]kv, []Invalid, error) {
  36. if len(ks) != len(vs) {
  37. return nil, nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  38. len(ks), len(vs))
  39. }
  40. var invalids []Invalid
  41. var kvs []kv
  42. for i := 0; i < len(ks); i++ {
  43. keyPath, err := keyPathFromKey(maxLevels, ks[i])
  44. if err != nil {
  45. invalids = append(invalids, Invalid{i, err})
  46. continue
  47. }
  48. if err := checkKeyValueLen(ks[i], vs[i]); err != nil {
  49. invalids = append(invalids, Invalid{i, err})
  50. continue
  51. }
  52. var kvsI kv
  53. kvsI.pos = i
  54. kvsI.keyPath = keyPath
  55. kvsI.k = ks[i]
  56. kvsI.v = vs[i]
  57. kvs = append(kvs, kvsI)
  58. }
  59. return kvs, invalids, nil
  60. }
  61. // vt stands for virtual tree. It's a tree that does not have any computed hash
  62. // while placing the leafs. Once all the leafs are placed, it computes all the
  63. // hashes. In this way, each node hash is only computed one time (at the end)
  64. // and the tree is computed in memory.
  65. type vt struct {
  66. root *node
  67. params *params
  68. }
  69. func newVT(maxLevels int, hash HashFunction) vt {
  70. return vt{
  71. root: nil,
  72. params: &params{
  73. maxLevels: maxLevels,
  74. hashFunction: hash,
  75. emptyHash: make([]byte, hash.Len()), // empty
  76. },
  77. }
  78. }
  79. // addBatch adds a batch of key-values to the VirtualTree. Returns an array
  80. // containing the indexes of the keys failed to add. Does not include the
  81. // computation of hashes of the nodes neither the storage of the key-values of
  82. // the tree into the db. After addBatch, vt.computeHashes should be called to
  83. // compute the hashes of all the nodes of the tree.
  84. func (t *vt) addBatch(ks, vs [][]byte) ([]Invalid, error) {
  85. nCPU := flp2(runtime.NumCPU())
  86. if nCPU == 1 || len(ks) < nCPU {
  87. var invalids []Invalid
  88. for i := 0; i < len(ks); i++ {
  89. if err := t.add(0, ks[i], vs[i]); err != nil {
  90. invalids = append(invalids, Invalid{i, err})
  91. }
  92. }
  93. return invalids, nil
  94. }
  95. l := int(math.Log2(float64(nCPU)))
  96. kvs, invalids, err := keysValuesToKvs(t.params.maxLevels, ks, vs)
  97. if err != nil {
  98. return invalids, err
  99. }
  100. buckets := splitInBuckets(kvs, nCPU)
  101. nodesAtL, err := t.getNodesAtLevel(l)
  102. if err != nil {
  103. return nil, err
  104. }
  105. if len(nodesAtL) != nCPU && t.root != nil {
  106. /*
  107. Already populated Tree but Unbalanced
  108. - Need to fill M1 and M2, and then will be able to continue with the flow
  109. - Search for M1 & M2 in the inputed Keys
  110. - Add M1 & M2 to the Tree
  111. - From here can continue with the flow
  112. R
  113. / \
  114. / \
  115. / \
  116. * *
  117. | \
  118. | \
  119. | \
  120. L: M1 * M2 * (where M1 and M2 are empty)
  121. / | /
  122. / | /
  123. / | /
  124. A * *
  125. / \ | \
  126. / \ | \
  127. / \ | \
  128. B * * C
  129. / \ |\
  130. ... ... | \
  131. | \
  132. D E
  133. */
  134. // add one key at each bucket, and then continue with the flow
  135. for i := 0; i < len(buckets); i++ {
  136. // add one leaf of the bucket, if there is an error when
  137. // adding the k-v, try to add the next one of the bucket
  138. // (until one is added)
  139. inserted := -1
  140. for j := 0; j < len(buckets[i]); j++ {
  141. if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
  142. inserted = j
  143. break
  144. }
  145. }
  146. // remove the inserted element from buckets[i]
  147. if inserted != -1 {
  148. buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
  149. }
  150. }
  151. nodesAtL, err = t.getNodesAtLevel(l)
  152. if err != nil {
  153. return nil, err
  154. }
  155. }
  156. if len(nodesAtL) != nCPU {
  157. return nil, fmt.Errorf("This error should not be reached."+
  158. " len(nodesAtL) != nCPU, len(nodesAtL)=%d, nCPU=%d."+
  159. " Please report it in a new issue:"+
  160. " https://github.com/vocdoni/arbo/issues/new", len(nodesAtL), nCPU)
  161. }
  162. subRoots := make([]*node, nCPU)
  163. invalidsInBucket := make([][]Invalid, nCPU)
  164. var wg sync.WaitGroup
  165. wg.Add(nCPU)
  166. for i := 0; i < nCPU; i++ {
  167. go func(cpu int) {
  168. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  169. bucketVT.root = nodesAtL[cpu]
  170. for j := 0; j < len(buckets[cpu]); j++ {
  171. if err := bucketVT.add(l, buckets[cpu][j].k,
  172. buckets[cpu][j].v); err != nil {
  173. invalidsInBucket[cpu] = append(invalidsInBucket[cpu],
  174. Invalid{buckets[cpu][j].pos, err})
  175. }
  176. }
  177. subRoots[cpu] = bucketVT.root
  178. wg.Done()
  179. }(i)
  180. }
  181. wg.Wait()
  182. for i := 0; i < len(invalidsInBucket); i++ {
  183. invalids = append(invalids, invalidsInBucket[i]...)
  184. }
  185. newRootNode, err := upFromNodes(subRoots)
  186. if err != nil {
  187. return nil, err
  188. }
  189. t.root = newRootNode
  190. return invalids, nil
  191. }
  192. func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
  193. if t.root == nil {
  194. var r []*node
  195. nChilds := int(math.Pow(2, float64(l))) //nolint:gomnd
  196. for i := 0; i < nChilds; i++ {
  197. r = append(r, nil)
  198. }
  199. return r, nil
  200. }
  201. return t.root.getNodesAtLevel(0, l)
  202. }
  203. func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
  204. if n == nil {
  205. var r []*node
  206. nChilds := int(math.Pow(2, float64(l-currLvl))) //nolint:gomnd
  207. for i := 0; i < nChilds; i++ {
  208. r = append(r, nil)
  209. }
  210. return r, nil
  211. }
  212. typ := n.typ()
  213. if currLvl == l && typ != vtEmpty {
  214. return []*node{n}, nil
  215. }
  216. if currLvl >= l {
  217. return nil, fmt.Errorf("This error should not be reached."+
  218. " currLvl >= l, currLvl=%d, l=%d."+
  219. " Please report it in a new issue:"+
  220. " https://github.com/vocdoni/arbo/issues/new", currLvl, l)
  221. }
  222. var nodes []*node
  223. nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
  224. if err != nil {
  225. return nil, err
  226. }
  227. nodes = append(nodes, nodesL...)
  228. nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
  229. if err != nil {
  230. return nil, err
  231. }
  232. nodes = append(nodes, nodesR...)
  233. return nodes, nil
  234. }
  235. // upFromNodes builds the tree from the bottom to up
  236. func upFromNodes(ns []*node) (*node, error) {
  237. if len(ns) == 1 {
  238. return ns[0], nil
  239. }
  240. var res []*node
  241. for i := 0; i < len(ns); i += 2 {
  242. if (ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty) ||
  243. (ns[i].typ() == vtLeaf && ns[i+1].typ() == vtEmpty) {
  244. // when both sub nodes are empty, the parent is also empty
  245. // or
  246. // when 1st sub node is a leaf but the 2nd is empty, the
  247. // leaf is used as 'parent'
  248. res = append(res, ns[i])
  249. continue
  250. }
  251. if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtLeaf {
  252. // when 2nd sub node is a leaf but the 1st is empty, the
  253. // leaf is used as 'parent'
  254. res = append(res, ns[i+1])
  255. continue
  256. }
  257. n := &node{
  258. l: ns[i],
  259. r: ns[i+1],
  260. }
  261. res = append(res, n)
  262. }
  263. return upFromNodes(res)
  264. }
  265. // add adds a key&value as a leaf in the VirtualTree
  266. func (t *vt) add(fromLvl int, k, v []byte) error {
  267. leaf, err := newLeafNode(t.params, k, v)
  268. if err != nil {
  269. return err
  270. }
  271. if t.root == nil {
  272. t.root = leaf
  273. return nil
  274. }
  275. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  276. return err
  277. }
  278. return nil
  279. }
  280. // computeHashes should be called after all the vt.add is used, once all the
  281. // leafs are in the tree. Computes the hashes of the tree, parallelizing in the
  282. // available CPUs.
  283. func (t *vt) computeHashes() ([][2][]byte, error) {
  284. var err error
  285. nCPU := flp2(runtime.NumCPU())
  286. l := int(math.Log2(float64(nCPU)))
  287. nodesAtL, err := t.getNodesAtLevel(l)
  288. if err != nil {
  289. return nil, err
  290. }
  291. subRoots := make([]*node, nCPU)
  292. bucketPairs := make([][][2][]byte, nCPU)
  293. dbgStatsPerBucket := make([]*dbgStats, nCPU)
  294. errs := make([]error, nCPU)
  295. var wg sync.WaitGroup
  296. wg.Add(nCPU)
  297. for i := 0; i < nCPU; i++ {
  298. go func(cpu int) {
  299. bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
  300. bucketVT.params.dbg = newDbgStats()
  301. bucketVT.root = nodesAtL[cpu]
  302. var err error
  303. bucketPairs[cpu], err = bucketVT.root.computeHashes(l-1,
  304. t.params.maxLevels, bucketVT.params, bucketPairs[cpu])
  305. if err != nil {
  306. errs[cpu] = err
  307. }
  308. subRoots[cpu] = bucketVT.root
  309. dbgStatsPerBucket[cpu] = bucketVT.params.dbg
  310. wg.Done()
  311. }(i)
  312. }
  313. wg.Wait()
  314. for i := 0; i < len(errs); i++ {
  315. if errs[i] != nil {
  316. return nil, errs[i]
  317. }
  318. }
  319. for i := 0; i < len(dbgStatsPerBucket); i++ {
  320. t.params.dbg.add(dbgStatsPerBucket[i])
  321. }
  322. var pairs [][2][]byte
  323. for i := 0; i < len(bucketPairs); i++ {
  324. pairs = append(pairs, bucketPairs[i]...)
  325. }
  326. nodesAtL, err = t.getNodesAtLevel(l)
  327. if err != nil {
  328. return nil, err
  329. }
  330. for i := 0; i < len(nodesAtL); i++ {
  331. nodesAtL = subRoots
  332. }
  333. pairs, err = t.root.computeHashes(0, l, t.params, pairs)
  334. if err != nil {
  335. return nil, err
  336. }
  337. return pairs, nil
  338. }
  339. func newLeafNode(p *params, k, v []byte) (*node, error) {
  340. if err := checkKeyValueLen(k, v); err != nil {
  341. return nil, err
  342. }
  343. keyPath, err := keyPathFromKey(p.maxLevels, k)
  344. if err != nil {
  345. return nil, err
  346. }
  347. path := getPath(p.maxLevels, keyPath)
  348. n := &node{
  349. k: k,
  350. v: v,
  351. path: path,
  352. }
  353. return n, nil
  354. }
  355. type virtualNodeType int
  356. const (
  357. vtEmpty = 0 // for convenience uses same value that PrefixValueEmpty
  358. vtLeaf = 1 // for convenience uses same value that PrefixValueLeaf
  359. vtMid = 2 // for convenience uses same value that PrefixValueIntermediate
  360. )
  361. func (n *node) typ() virtualNodeType {
  362. if n == nil {
  363. return vtEmpty
  364. }
  365. if n.l == nil && n.r == nil && n.k != nil {
  366. return vtLeaf
  367. }
  368. if n.l != nil || n.r != nil {
  369. return vtMid
  370. }
  371. return vtEmpty
  372. }
  373. func (n *node) add(p *params, currLvl int, leaf *node) error {
  374. if currLvl > p.maxLevels-1 {
  375. return ErrMaxVirtualLevel
  376. }
  377. if n == nil {
  378. // n = leaf // TMP!
  379. return nil
  380. }
  381. t := n.typ()
  382. switch t {
  383. case vtMid:
  384. if leaf.path[currLvl] {
  385. //right
  386. if n.r == nil {
  387. // empty sub-node, add the leaf here
  388. n.r = leaf
  389. return nil
  390. }
  391. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  392. return err
  393. }
  394. } else {
  395. if n.l == nil {
  396. // empty sub-node, add the leaf here
  397. n.l = leaf
  398. return nil
  399. }
  400. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  401. return err
  402. }
  403. }
  404. case vtLeaf:
  405. if bytes.Equal(n.k, leaf.k) {
  406. return fmt.Errorf("%s. Existing node: %s, trying to add node: %s",
  407. ErrKeyAlreadyExists, hex.EncodeToString(n.k),
  408. hex.EncodeToString(leaf.k))
  409. }
  410. oldLeaf := &node{
  411. k: n.k,
  412. v: n.v,
  413. path: n.path,
  414. }
  415. // remove values from current node (converting it to mid node)
  416. n.k = nil
  417. n.v = nil
  418. n.h = nil
  419. n.path = nil
  420. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  421. return err
  422. }
  423. case vtEmpty:
  424. return fmt.Errorf("virtual tree node.add() with empty node %v", n)
  425. default:
  426. return fmt.Errorf("virtual tree node.add() with unknown node type %v", n)
  427. }
  428. return nil
  429. }
  430. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  431. if currLvl > p.maxLevels-1 {
  432. return ErrMaxVirtualLevel
  433. }
  434. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  435. // reached divergence in next level
  436. if newLeaf.path[currLvl] {
  437. n.l = oldLeaf
  438. n.r = newLeaf
  439. } else {
  440. n.l = newLeaf
  441. n.r = oldLeaf
  442. }
  443. return nil
  444. }
  445. // no divergence yet, continue going down
  446. if newLeaf.path[currLvl] {
  447. // right
  448. n.r = &node{}
  449. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  450. return err
  451. }
  452. } else {
  453. // left
  454. n.l = &node{}
  455. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  456. return err
  457. }
  458. }
  459. return nil
  460. }
  461. func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
  462. buckets := make([][]kv, nBuckets)
  463. // 1. classify the keyvalues into buckets
  464. for i := 0; i < len(kvs); i++ {
  465. pair := kvs[i]
  466. // bucketnum := keyToBucket(pair.k, nBuckets)
  467. bucketnum := keyToBucket(pair.keyPath, nBuckets)
  468. buckets[bucketnum] = append(buckets[bucketnum], pair)
  469. }
  470. return buckets
  471. }
  472. // TODO rename in a more 'real' name (calculate bucket from/for key)
  473. func keyToBucket(k []byte, nBuckets int) int {
  474. nLevels := int(math.Log2(float64(nBuckets)))
  475. b := make([]int, nBuckets)
  476. for i := 0; i < nBuckets; i++ {
  477. b[i] = i
  478. }
  479. r := b
  480. mid := len(r) / 2 //nolint:gomnd
  481. for i := 0; i < nLevels; i++ {
  482. if int(k[i/8]&(1<<(i%8))) != 0 {
  483. r = r[mid:]
  484. mid = len(r) / 2 //nolint:gomnd
  485. } else {
  486. r = r[:mid]
  487. mid = len(r) / 2 //nolint:gomnd
  488. }
  489. }
  490. return r[0]
  491. }
  492. // flp2 computes the floor power of 2, the highest power of 2 under the given
  493. // value.
  494. func flp2(n int) int {
  495. res := 0
  496. for i := n; i >= 1; i-- {
  497. if (i & (i - 1)) == 0 {
  498. res = i
  499. break
  500. }
  501. }
  502. return res
  503. }
  504. // computeHashes computes the hashes under the node from which is called the
  505. // method. Returns an array of key-values to store in the db
  506. func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) (
  507. [][2][]byte, error) {
  508. if n == nil || currLvl >= maxLvl {
  509. // no need to compute any hash
  510. return pairs, nil
  511. }
  512. if pairs == nil {
  513. pairs = [][2][]byte{}
  514. }
  515. var err error
  516. t := n.typ()
  517. switch t {
  518. case vtLeaf:
  519. p.dbg.incHash()
  520. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  521. if err != nil {
  522. return pairs, err
  523. }
  524. n.h = leafKey
  525. kv := [2][]byte{leafKey, leafValue}
  526. pairs = append(pairs, kv)
  527. case vtMid:
  528. if n.l != nil {
  529. pairs, err = n.l.computeHashes(currLvl+1, maxLvl, p, pairs)
  530. if err != nil {
  531. return pairs, err
  532. }
  533. } else {
  534. n.l = &node{
  535. h: p.emptyHash,
  536. }
  537. }
  538. if n.r != nil {
  539. pairs, err = n.r.computeHashes(currLvl+1, maxLvl, p, pairs)
  540. if err != nil {
  541. return pairs, err
  542. }
  543. } else {
  544. n.r = &node{
  545. h: p.emptyHash,
  546. }
  547. }
  548. // once the sub nodes are computed, can compute the current node
  549. // hash
  550. p.dbg.incHash()
  551. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  552. if err != nil {
  553. return nil, err
  554. }
  555. n.h = k
  556. kv := [2][]byte{k, v}
  557. pairs = append(pairs, kv)
  558. case vtEmpty:
  559. default:
  560. return nil, fmt.Errorf("error: n.computeHashes type (%d) no match", t)
  561. }
  562. return pairs, nil
  563. }
  564. //nolint:unused
  565. func (t *vt) graphviz(w io.Writer) error {
  566. fmt.Fprintf(w, `digraph hierarchy {
  567. node [fontname=Monospace,fontsize=10,shape=box]
  568. `)
  569. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  570. return err
  571. }
  572. fmt.Fprintf(w, "}\n")
  573. return nil
  574. }
  575. //nolint:unused
  576. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  577. if n == nil {
  578. return nEmpties, nil
  579. }
  580. t := n.typ()
  581. switch t {
  582. case vtLeaf:
  583. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  584. if err != nil {
  585. return nEmpties, err
  586. }
  587. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  588. k := n.k
  589. v := n.v
  590. if len(n.k) >= nChars {
  591. k = n.k[:nChars]
  592. }
  593. if len(n.v) >= nChars {
  594. v = n.v[:nChars]
  595. }
  596. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  597. hex.EncodeToString(k),
  598. hex.EncodeToString(v))
  599. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  600. hex.EncodeToString(k),
  601. hex.EncodeToString(v))
  602. case vtMid:
  603. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  604. lStr := fmt.Sprintf("%p", n.l)
  605. rStr := fmt.Sprintf("%p", n.r)
  606. eStr := ""
  607. if n.l == nil {
  608. lStr = fmt.Sprintf("empty%v", nEmpties)
  609. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  610. lStr)
  611. nEmpties++
  612. }
  613. if n.r == nil {
  614. rStr = fmt.Sprintf("empty%v", nEmpties)
  615. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  616. rStr)
  617. nEmpties++
  618. }
  619. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  620. fmt.Fprint(w, eStr)
  621. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  622. if err != nil {
  623. return nEmpties, err
  624. }
  625. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  626. if err != nil {
  627. return nEmpties, err
  628. }
  629. case vtEmpty:
  630. default:
  631. return nEmpties, fmt.Errorf("ERR")
  632. }
  633. return nEmpties, nil
  634. }
  635. //nolint:unused
  636. func (t *vt) printGraphviz() error {
  637. w := bytes.NewBufferString("")
  638. fmt.Fprintf(w,
  639. "--------\nGraphviz:\n")
  640. err := t.graphviz(w)
  641. if err != nil {
  642. fmt.Println(w)
  643. return err
  644. }
  645. fmt.Fprintf(w,
  646. "End of Graphviz --------\n")
  647. fmt.Println(w)
  648. return nil
  649. }