You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

667 lines
15 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds a tree
// without computing any hash. The idea is that once all the leaves are placed
// in their positions, the hashes can be computed, avoiding computing a node
// hash more than once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. "math"
  12. "runtime"
  13. "sync"
  14. )
// node is a node of the virtual tree. Which fields are set determines whether
// it is treated as an empty node, a leaf or an intermediate node (see
// node.typ).
type node struct {
	l    *node  // left child
	r    *node  // right child
	k    []byte // key; non-nil on leaf nodes
	v    []byte // value stored next to the key
	path []bool // per-level direction of the leaf: true descends right, false left
	h    []byte // hash of the node, filled in by computeHashes
}
// params groups the settings shared by the nodes of a virtual tree.
type params struct {
	maxLevels    int          // level limit checked by node.add and downUntilDivergence
	hashFunction HashFunction // used for key-path length and for leaf/intermediate hashes
	emptyHash    []byte       // hash given to empty children; newVT sets it to hash.Len() zero bytes
	dbg          *dbgStats    // debug stats (incHash is called per computed hash); not set by newVT
}
// kv is a key-value pair that remembers its index in the input slices.
type kv struct {
	pos     int    // original position in the inputted array
	keyPath []byte // key copied into a buffer of hashFunction.Len() bytes
	k       []byte // key as received
	v       []byte // value as received
}
  35. func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
  36. if len(ks) != len(vs) {
  37. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  38. len(ks), len(vs))
  39. }
  40. kvs := make([]kv, len(ks))
  41. for i := 0; i < len(ks); i++ {
  42. keyPath := make([]byte, p.hashFunction.Len())
  43. copy(keyPath[:], ks[i])
  44. kvs[i].pos = i
  45. kvs[i].keyPath = keyPath
  46. kvs[i].k = ks[i]
  47. kvs[i].v = vs[i]
  48. }
  49. return kvs, nil
  50. }
// vt stands for virtual tree. It's a tree that does not have any computed hash
// while placing the leaves. Once all the leaves are placed, it computes all
// the hashes. In this way, each node hash is only computed one time (at the
// end) and the tree is computed in memory.
type vt struct {
	root   *node   // root of the tree; nil while the tree is empty
	params *params // settings shared by all the nodes of this tree
}
  59. func newVT(maxLevels int, hash HashFunction) vt {
  60. return vt{
  61. root: nil,
  62. params: &params{
  63. maxLevels: maxLevels,
  64. hashFunction: hash,
  65. emptyHash: make([]byte, hash.Len()), // empty
  66. },
  67. }
  68. }
  69. // addBatch adds a batch of key-values to the VirtualTree. Returns an array
  70. // containing the indexes of the keys failed to add. Does not include the
  71. // computation of hashes of the nodes neither the storage of the key-values of
  72. // the tree into the db. After addBatch, vt.computeHashes should be called to
  73. // compute the hashes of all the nodes of the tree.
  74. func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {
  75. nCPU := flp2(runtime.NumCPU())
  76. if nCPU == 1 || len(ks) < nCPU {
  77. var invalids []int
  78. for i := 0; i < len(ks); i++ {
  79. if err := t.add(0, ks[i], vs[i]); err != nil {
  80. invalids = append(invalids, i)
  81. }
  82. }
  83. return invalids, nil
  84. }
  85. l := int(math.Log2(float64(nCPU)))
  86. kvs, err := t.params.keysValuesToKvs(ks, vs)
  87. if err != nil {
  88. return nil, err
  89. }
  90. buckets := splitInBuckets(kvs, nCPU)
  91. nodesAtL, err := t.getNodesAtLevel(l)
  92. if err != nil {
  93. return nil, err
  94. }
  95. if len(nodesAtL) != nCPU && t.root != nil {
  96. /*
  97. Already populated Tree but Unbalanced
  98. - Need to fill M1 and M2, and then will be able to continue with the flow
  99. - Search for M1 & M2 in the inputed Keys
  100. - Add M1 & M2 to the Tree
  101. - From here can continue with the flow
  102. R
  103. / \
  104. / \
  105. / \
  106. * *
  107. | \
  108. | \
  109. | \
  110. L: M1 * M2 * (where M1 and M2 are empty)
  111. / | /
  112. / | /
  113. / | /
  114. A * *
  115. / \ | \
  116. / \ | \
  117. / \ | \
  118. B * * C
  119. / \ |\
  120. ... ... | \
  121. | \
  122. D E
  123. */
  124. // add one key at each bucket, and then continue with the flow
  125. for i := 0; i < len(buckets); i++ {
  126. // add one leaf of the bucket, if there is an error when
  127. // adding the k-v, try to add the next one of the bucket
  128. // (until one is added)
  129. inserted := -1
  130. for j := 0; j < len(buckets[i]); j++ {
  131. if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
  132. inserted = j
  133. break
  134. }
  135. }
  136. // remove the inserted element from buckets[i]
  137. // fmt.Println("rm-ins", inserted)
  138. if inserted != -1 {
  139. buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
  140. }
  141. }
  142. nodesAtL, err = t.getNodesAtLevel(l)
  143. if err != nil {
  144. return nil, err
  145. }
  146. }
  147. if len(nodesAtL) != nCPU {
  148. panic("should not happen") // TODO TMP
  149. }
  150. subRoots := make([]*node, nCPU)
  151. invalidsInBucket := make([][]int, nCPU)
  152. var wg sync.WaitGroup
  153. wg.Add(nCPU)
  154. for i := 0; i < nCPU; i++ {
  155. go func(cpu int) {
  156. bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction)
  157. bucketVT.root = nodesAtL[cpu]
  158. for j := 0; j < len(buckets[cpu]); j++ {
  159. if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
  160. invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
  161. }
  162. }
  163. subRoots[cpu] = bucketVT.root
  164. wg.Done()
  165. }(i)
  166. }
  167. wg.Wait()
  168. var invalids []int
  169. for i := 0; i < len(invalidsInBucket); i++ {
  170. invalids = append(invalids, invalidsInBucket[i]...)
  171. }
  172. newRootNode, err := upFromNodes(subRoots)
  173. if err != nil {
  174. return nil, err
  175. }
  176. t.root = newRootNode
  177. return invalids, nil
  178. }
  179. func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
  180. if t.root == nil {
  181. var r []*node
  182. nChilds := int(math.Pow(2, float64(l))) //nolint:gomnd
  183. for i := 0; i < nChilds; i++ {
  184. r = append(r, nil)
  185. }
  186. return r, nil
  187. }
  188. return t.root.getNodesAtLevel(0, l)
  189. }
  190. func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
  191. if n == nil {
  192. var r []*node
  193. nChilds := int(math.Pow(2, float64(l-currLvl))) //nolint:gomnd
  194. for i := 0; i < nChilds; i++ {
  195. r = append(r, nil)
  196. }
  197. return r, nil
  198. }
  199. typ := n.typ()
  200. if currLvl == l && typ != vtEmpty {
  201. return []*node{n}, nil
  202. }
  203. if currLvl >= l {
  204. panic("should not reach this point") // TODO TMP
  205. }
  206. var nodes []*node
  207. nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
  208. if err != nil {
  209. return nil, err
  210. }
  211. nodes = append(nodes, nodesL...)
  212. nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
  213. if err != nil {
  214. return nil, err
  215. }
  216. nodes = append(nodes, nodesR...)
  217. return nodes, nil
  218. }
  219. func upFromNodes(ns []*node) (*node, error) {
  220. if len(ns) == 1 {
  221. return ns[0], nil
  222. }
  223. var res []*node
  224. for i := 0; i < len(ns); i += 2 {
  225. if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty {
  226. // if ns[i] == nil && ns[i+1] == nil {
  227. // when both sub nodes are empty, the node is also empty
  228. res = append(res, ns[i]) // empty node
  229. continue
  230. }
  231. n := &node{
  232. l: ns[i],
  233. r: ns[i+1],
  234. }
  235. res = append(res, n)
  236. }
  237. return upFromNodes(res)
  238. }
  239. func (t *vt) add(fromLvl int, k, v []byte) error {
  240. leaf := newLeafNode(t.params, k, v)
  241. if t.root == nil {
  242. t.root = leaf
  243. return nil
  244. }
  245. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  246. return err
  247. }
  248. return nil
  249. }
  250. // computeHashes should be called after all the vt.add is used, once all the
  251. // leafs are in the tree
  252. func (t *vt) computeHashes() ([][2][]byte, error) {
  253. var err error
  254. nCPU := flp2(runtime.NumCPU())
  255. l := int(math.Log2(float64(nCPU)))
  256. nodesAtL, err := t.getNodesAtLevel(l)
  257. if err != nil {
  258. return nil, err
  259. }
  260. subRoots := make([]*node, nCPU)
  261. bucketPairs := make([][][2][]byte, nCPU)
  262. dbgStatsPerBucket := make([]*dbgStats, nCPU)
  263. var wg sync.WaitGroup
  264. wg.Add(nCPU)
  265. for i := 0; i < nCPU; i++ {
  266. go func(cpu int) {
  267. bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction)
  268. bucketVT.params.dbg = newDbgStats()
  269. bucketVT.root = nodesAtL[cpu]
  270. bucketPairs[cpu], err = bucketVT.root.computeHashes(l,
  271. t.params.maxLevels, bucketVT.params, bucketPairs[cpu])
  272. if err != nil {
  273. // TODO WIP
  274. panic("TODO" + err.Error())
  275. }
  276. subRoots[cpu] = bucketVT.root
  277. dbgStatsPerBucket[cpu] = bucketVT.params.dbg
  278. wg.Done()
  279. }(i)
  280. }
  281. wg.Wait()
  282. for i := 0; i < len(dbgStatsPerBucket); i++ {
  283. t.params.dbg.add(dbgStatsPerBucket[i])
  284. }
  285. var pairs [][2][]byte
  286. for i := 0; i < len(bucketPairs); i++ {
  287. pairs = append(pairs, bucketPairs[i]...)
  288. }
  289. nodesAtL, err = t.getNodesAtLevel(l)
  290. if err != nil {
  291. return nil, err
  292. }
  293. for i := 0; i < len(nodesAtL); i++ {
  294. nodesAtL = subRoots
  295. }
  296. pairs, err = t.root.computeHashes(0, l, t.params, pairs)
  297. if err != nil {
  298. return nil, err
  299. }
  300. return pairs, nil
  301. }
  302. func newLeafNode(p *params, k, v []byte) *node {
  303. keyPath := make([]byte, p.hashFunction.Len())
  304. copy(keyPath[:], k)
  305. path := getPath(p.maxLevels, keyPath)
  306. n := &node{
  307. k: k,
  308. v: v,
  309. path: path,
  310. }
  311. return n
  312. }
// virtualNodeType identifies the kind of a virtual tree node, as reported by
// node.typ.
type virtualNodeType int

// Kinds of virtual tree node. The constants are untyped, so they convert
// implicitly to virtualNodeType where node.typ returns them.
const (
	vtEmpty = 0 // for convenience uses same value that PrefixValueEmpty
	vtLeaf  = 1 // for convenience uses same value that PrefixValueLeaf
	vtMid   = 2 // for convenience uses same value that PrefixValueIntermediate
)
  319. func (n *node) typ() virtualNodeType {
  320. if n == nil {
  321. return vtEmpty // TODO decide if return 'vtEmpty' or an error
  322. }
  323. if n.l == nil && n.r == nil && n.k != nil {
  324. return vtLeaf
  325. }
  326. if n.l != nil || n.r != nil {
  327. return vtMid
  328. }
  329. return vtEmpty
  330. }
  331. func (n *node) add(p *params, currLvl int, leaf *node) error {
  332. if currLvl > p.maxLevels-1 {
  333. return fmt.Errorf("max virtual level %d", currLvl)
  334. }
  335. if n == nil {
  336. // n = leaf // TMP!
  337. return nil
  338. }
  339. t := n.typ()
  340. switch t {
  341. case vtMid:
  342. if leaf.path[currLvl] {
  343. //right
  344. if n.r == nil {
  345. // empty sub-node, add the leaf here
  346. n.r = leaf
  347. return nil
  348. }
  349. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  350. return err
  351. }
  352. } else {
  353. if n.l == nil {
  354. // empty sub-node, add the leaf here
  355. n.l = leaf
  356. return nil
  357. }
  358. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  359. return err
  360. }
  361. }
  362. case vtLeaf:
  363. if bytes.Equal(n.k, leaf.k) {
  364. return fmt.Errorf("key already exists. Existing node: %s, trying to add node: %s",
  365. hex.EncodeToString(n.k), hex.EncodeToString(leaf.k))
  366. }
  367. oldLeaf := &node{
  368. k: n.k,
  369. v: n.v,
  370. path: n.path,
  371. }
  372. // remove values from current node (converting it to mid node)
  373. n.k = nil
  374. n.v = nil
  375. n.h = nil
  376. n.path = nil
  377. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  378. return err
  379. }
  380. case vtEmpty:
  381. panic(fmt.Errorf("EMPTY %v", n)) // TODO TMP
  382. default:
  383. return fmt.Errorf("ERR") // TODO TMP
  384. }
  385. return nil
  386. }
  387. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  388. if currLvl > p.maxLevels-1 {
  389. return fmt.Errorf("max virtual level %d", currLvl)
  390. }
  391. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  392. // reached divergence in next level
  393. if newLeaf.path[currLvl] {
  394. n.l = oldLeaf
  395. n.r = newLeaf
  396. } else {
  397. n.l = newLeaf
  398. n.r = oldLeaf
  399. }
  400. return nil
  401. }
  402. // no divergence yet, continue going down
  403. if newLeaf.path[currLvl] {
  404. // right
  405. n.r = &node{}
  406. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  407. return err
  408. }
  409. } else {
  410. // left
  411. n.l = &node{}
  412. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  413. return err
  414. }
  415. }
  416. return nil
  417. }
  418. func splitInBuckets(kvs []kv, nBuckets int) [][]kv {
  419. buckets := make([][]kv, nBuckets)
  420. // 1. classify the keyvalues into buckets
  421. for i := 0; i < len(kvs); i++ {
  422. pair := kvs[i]
  423. // bucketnum := keyToBucket(pair.k, nBuckets)
  424. bucketnum := keyToBucket(pair.keyPath, nBuckets)
  425. buckets[bucketnum] = append(buckets[bucketnum], pair)
  426. }
  427. return buckets
  428. }
  429. // TODO rename in a more 'real' name (calculate bucket from/for key)
  430. func keyToBucket(k []byte, nBuckets int) int {
  431. nLevels := int(math.Log2(float64(nBuckets)))
  432. b := make([]int, nBuckets)
  433. for i := 0; i < nBuckets; i++ {
  434. b[i] = i
  435. }
  436. r := b
  437. mid := len(r) / 2 //nolint:gomnd
  438. for i := 0; i < nLevels; i++ {
  439. if int(k[i/8]&(1<<(i%8))) != 0 {
  440. r = r[mid:]
  441. mid = len(r) / 2 //nolint:gomnd
  442. } else {
  443. r = r[:mid]
  444. mid = len(r) / 2 //nolint:gomnd
  445. }
  446. }
  447. return r[0]
  448. }
  449. // flp2 computes the floor power of 2, the highest power of 2 under the given
  450. // value.
  451. func flp2(n int) int {
  452. res := 0
  453. for i := n; i >= 1; i-- {
  454. if (i & (i - 1)) == 0 {
  455. res = i
  456. break
  457. }
  458. }
  459. return res
  460. }
  461. // returns an array of key-values to store in the db
  462. func (n *node) computeHashes(currLvl, maxLvl int, p *params, pairs [][2][]byte) (
  463. [][2][]byte, error) {
  464. if n == nil || currLvl >= maxLvl {
  465. // no need to compute any hash
  466. return pairs, nil
  467. }
  468. if pairs == nil {
  469. pairs = [][2][]byte{}
  470. }
  471. var err error
  472. t := n.typ()
  473. switch t {
  474. case vtLeaf:
  475. p.dbg.incHash()
  476. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  477. if err != nil {
  478. return pairs, err
  479. }
  480. n.h = leafKey
  481. kv := [2][]byte{leafKey, leafValue}
  482. pairs = append(pairs, kv)
  483. case vtMid:
  484. if n.l != nil {
  485. pairs, err = n.l.computeHashes(currLvl+1, maxLvl, p, pairs)
  486. if err != nil {
  487. return pairs, err
  488. }
  489. } else {
  490. n.l = &node{
  491. h: p.emptyHash,
  492. }
  493. }
  494. if n.r != nil {
  495. pairs, err = n.r.computeHashes(currLvl+1, maxLvl, p, pairs)
  496. if err != nil {
  497. return pairs, err
  498. }
  499. } else {
  500. n.r = &node{
  501. h: p.emptyHash,
  502. }
  503. }
  504. // once the sub nodes are computed, can compute the current node
  505. // hash
  506. p.dbg.incHash()
  507. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  508. if err != nil {
  509. return nil, err
  510. }
  511. n.h = k
  512. kv := [2][]byte{k, v}
  513. pairs = append(pairs, kv)
  514. case vtEmpty:
  515. default:
  516. return nil, fmt.Errorf("ERR:n.computeHashes type (%d) no match", t) // TODO TMP
  517. }
  518. return pairs, nil
  519. }
  520. //nolint:unused
  521. func (t *vt) graphviz(w io.Writer) error {
  522. fmt.Fprintf(w, `digraph hierarchy {
  523. node [fontname=Monospace,fontsize=10,shape=box]
  524. `)
  525. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  526. return err
  527. }
  528. fmt.Fprintf(w, "}\n")
  529. return nil
  530. }
  531. //nolint:unused
  532. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  533. nChars := 4 // TODO move to global constant
  534. if n == nil {
  535. return nEmpties, nil
  536. }
  537. t := n.typ()
  538. switch t {
  539. case vtLeaf:
  540. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  541. if err != nil {
  542. return nEmpties, err
  543. }
  544. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  545. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  546. hex.EncodeToString(n.k[:nChars]),
  547. hex.EncodeToString(n.v[:nChars]))
  548. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  549. hex.EncodeToString(n.k[:nChars]),
  550. hex.EncodeToString(n.v[:nChars]))
  551. case vtMid:
  552. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  553. lStr := fmt.Sprintf("%p", n.l)
  554. rStr := fmt.Sprintf("%p", n.r)
  555. eStr := ""
  556. if n.l == nil {
  557. lStr = fmt.Sprintf("empty%v", nEmpties)
  558. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  559. lStr)
  560. nEmpties++
  561. }
  562. if n.r == nil {
  563. rStr = fmt.Sprintf("empty%v", nEmpties)
  564. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  565. rStr)
  566. nEmpties++
  567. }
  568. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  569. fmt.Fprint(w, eStr)
  570. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  571. if err != nil {
  572. return nEmpties, err
  573. }
  574. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  575. if err != nil {
  576. return nEmpties, err
  577. }
  578. case vtEmpty:
  579. default:
  580. return nEmpties, fmt.Errorf("ERR")
  581. }
  582. return nEmpties, nil
  583. }
  584. //nolint:unused
  585. func (t *vt) printGraphviz() error {
  586. w := bytes.NewBufferString("")
  587. fmt.Fprintf(w,
  588. "--------\nGraphviz:\n")
  589. err := t.graphviz(w)
  590. if err != nil {
  591. fmt.Println(w)
  592. return err
  593. }
  594. fmt.Fprintf(w,
  595. "End of Graphviz --------\n")
  596. fmt.Println(w)
  597. return nil
  598. }