You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

505 lines
11 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds a tree
// without computing any hash. The idea is that once all the leaves are placed
// in their positions, the hashes can be computed, so that no node hash is
// computed more than once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. "math"
  12. "runtime"
  13. "sync"
  14. )
  15. type node struct {
  16. l *node
  17. r *node
  18. k []byte
  19. v []byte
  20. path []bool
  21. h []byte
  22. }
  23. type params struct {
  24. maxLevels int
  25. hashFunction HashFunction
  26. emptyHash []byte
  27. dbg *dbgStats
  28. }
  29. func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
  30. if len(ks) != len(vs) {
  31. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  32. len(ks), len(vs))
  33. }
  34. kvs := make([]kv, len(ks))
  35. for i := 0; i < len(ks); i++ {
  36. keyPath := make([]byte, p.hashFunction.Len())
  37. copy(keyPath[:], ks[i])
  38. kvs[i].pos = i
  39. kvs[i].keyPath = keyPath
  40. kvs[i].k = ks[i]
  41. kvs[i].v = vs[i]
  42. }
  43. return kvs, nil
  44. }
  45. // vt stands for virtual tree. It's a tree that does not have any computed hash
  46. // while placing the leafs. Once all the leafs are placed, it computes all the
  47. // hashes. In this way, each node hash is only computed one time (at the end)
  48. // and the tree is computed in memory.
  49. type vt struct {
  50. root *node
  51. params *params
  52. }
  53. func newVT(maxLevels int, hash HashFunction) vt {
  54. return vt{
  55. root: nil,
  56. params: &params{
  57. maxLevels: maxLevels,
  58. hashFunction: hash,
  59. emptyHash: make([]byte, hash.Len()), // empty
  60. },
  61. }
  62. }
  63. func (t *vt) addBatch(ks, vs [][]byte) error {
  64. // parallelize adding leafs in the virtual tree
  65. nCPU := flp2(runtime.NumCPU())
  66. if nCPU == 1 || len(ks) < nCPU {
  67. // var invalids []int
  68. for i := 0; i < len(ks); i++ {
  69. if err := t.add(0, ks[i], vs[i]); err != nil {
  70. // invalids = append(invalids, i)
  71. fmt.Println(err) // TODO WIP
  72. }
  73. }
  74. return nil // TODO invalids
  75. }
  76. l := int(math.Log2(float64(nCPU)))
  77. kvs, err := t.params.keysValuesToKvs(ks, vs)
  78. if err != nil {
  79. return err
  80. }
  81. buckets := splitInBuckets(kvs, nCPU)
  82. nodesAtL, err := t.getNodesAtLevel(l)
  83. if err != nil {
  84. return err
  85. }
  86. // fmt.Println("nodesatL pre-E", len(nodesAtL))
  87. if len(nodesAtL) != nCPU {
  88. // CASE E: add one key at each bucket, and then do CASE D
  89. for i := 0; i < len(buckets); i++ {
  90. // add one leaf of the bucket, if there is an error when
  91. // adding the k-v, try to add the next one of the bucket
  92. // (until one is added)
  93. var inserted int
  94. for j := 0; j < len(buckets[i]); j++ {
  95. if err := t.add(0, buckets[i][j].k, buckets[i][j].v); err == nil {
  96. inserted = j
  97. break
  98. }
  99. }
  100. // remove the inserted element from buckets[i]
  101. buckets[i] = append(buckets[i][:inserted], buckets[i][inserted+1:]...)
  102. }
  103. nodesAtL, err = t.getNodesAtLevel(l)
  104. if err != nil {
  105. return err
  106. }
  107. }
  108. subRoots := make([]*node, nCPU)
  109. invalidsInBucket := make([][]int, nCPU)
  110. var wg sync.WaitGroup
  111. wg.Add(nCPU)
  112. for i := 0; i < nCPU; i++ {
  113. go func(cpu int) {
  114. sortKvs(buckets[cpu])
  115. bucketVT := newVT(t.params.maxLevels-l, t.params.hashFunction)
  116. bucketVT.root = nodesAtL[cpu]
  117. for j := 0; j < len(buckets[cpu]); j++ {
  118. if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
  119. invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
  120. }
  121. }
  122. subRoots[cpu] = bucketVT.root
  123. wg.Done()
  124. }(i)
  125. }
  126. wg.Wait()
  127. newRootNode, err := upFromNodes(subRoots)
  128. if err != nil {
  129. return err
  130. }
  131. t.root = newRootNode
  132. return nil
  133. }
  134. func (t *vt) getNodesAtLevel(l int) ([]*node, error) {
  135. if t.root == nil {
  136. return nil, nil
  137. }
  138. return t.root.getNodesAtLevel(0, l)
  139. }
  140. func (n *node) getNodesAtLevel(currLvl, l int) ([]*node, error) {
  141. var nodes []*node
  142. typ := n.typ()
  143. if currLvl == l && typ != vtEmpty {
  144. nodes = append(nodes, n)
  145. return nodes, nil
  146. }
  147. if currLvl >= l {
  148. panic("should not reach this point") // TODO TMP
  149. // return nil, nil
  150. }
  151. if n.l != nil {
  152. nodesL, err := n.l.getNodesAtLevel(currLvl+1, l)
  153. if err != nil {
  154. return nil, err
  155. }
  156. nodes = append(nodes, nodesL...)
  157. }
  158. if n.r != nil {
  159. nodesR, err := n.r.getNodesAtLevel(currLvl+1, l)
  160. if err != nil {
  161. return nil, err
  162. }
  163. nodes = append(nodes, nodesR...)
  164. }
  165. return nodes, nil
  166. }
  167. func upFromNodes(ns []*node) (*node, error) {
  168. if len(ns) == 1 {
  169. return ns[0], nil
  170. }
  171. var res []*node
  172. for i := 0; i < len(ns); i += 2 {
  173. if ns[i].typ() == vtEmpty && ns[i+1].typ() == vtEmpty {
  174. // when both sub nodes are empty, the node is also empty
  175. res = append(res, ns[i]) // empty node
  176. }
  177. n := &node{
  178. l: ns[i],
  179. r: ns[i+1],
  180. }
  181. res = append(res, n)
  182. }
  183. return upFromNodes(res)
  184. }
  185. func (t *vt) add(fromLvl int, k, v []byte) error {
  186. leaf := newLeafNode(t.params, k, v)
  187. if t.root == nil {
  188. t.root = leaf
  189. return nil
  190. }
  191. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  192. return err
  193. }
  194. return nil
  195. }
  196. // computeHashes should be called after all the vt.add is used, once all the
  197. // leafs are in the tree
  198. func (t *vt) computeHashes() ([][2][]byte, error) {
  199. var pairs [][2][]byte
  200. var err error
  201. // TODO parallelize computeHashes
  202. pairs, err = t.root.computeHashes(t.params, pairs)
  203. if err != nil {
  204. return pairs, err
  205. }
  206. return pairs, nil
  207. }
  208. func newLeafNode(p *params, k, v []byte) *node {
  209. keyPath := make([]byte, p.hashFunction.Len())
  210. copy(keyPath[:], k)
  211. path := getPath(p.maxLevels, keyPath)
  212. n := &node{
  213. k: k,
  214. v: v,
  215. path: path,
  216. }
  217. return n
  218. }
  219. type virtualNodeType int
  220. const (
  221. vtEmpty = 0 // for convenience uses same value that PrefixValueEmpty
  222. vtLeaf = 1 // for convenience uses same value that PrefixValueLeaf
  223. vtMid = 2 // for convenience uses same value that PrefixValueIntermediate
  224. )
  225. func (n *node) typ() virtualNodeType {
  226. if n == nil {
  227. return vtEmpty // TODO decide if return 'vtEmpty' or an error
  228. }
  229. if n.l == nil && n.r == nil && n.k != nil {
  230. return vtLeaf
  231. }
  232. if n.l != nil || n.r != nil {
  233. return vtMid
  234. }
  235. return vtEmpty
  236. }
  237. func (n *node) add(p *params, currLvl int, leaf *node) error {
  238. if currLvl > p.maxLevels-1 {
  239. return fmt.Errorf("max virtual level %d", currLvl)
  240. }
  241. if n == nil {
  242. // n = leaf // TMP!
  243. return nil
  244. }
  245. t := n.typ()
  246. switch t {
  247. case vtMid:
  248. if leaf.path[currLvl] {
  249. //right
  250. if n.r == nil {
  251. // empty sub-node, add the leaf here
  252. n.r = leaf
  253. return nil
  254. }
  255. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  256. return err
  257. }
  258. } else {
  259. if n.l == nil {
  260. // empty sub-node, add the leaf here
  261. n.l = leaf
  262. return nil
  263. }
  264. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  265. return err
  266. }
  267. }
  268. case vtLeaf:
  269. if bytes.Equal(n.k, leaf.k) {
  270. return fmt.Errorf("key already exists. Existing node: %s, trying to add node: %s",
  271. hex.EncodeToString(n.k), hex.EncodeToString(leaf.k))
  272. }
  273. oldLeaf := &node{
  274. k: n.k,
  275. v: n.v,
  276. path: n.path,
  277. }
  278. // remove values from current node (converting it to mid node)
  279. n.k = nil
  280. n.v = nil
  281. n.h = nil
  282. n.path = nil
  283. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  284. return err
  285. }
  286. case vtEmpty:
  287. panic(fmt.Errorf("EMPTY %v", n)) // TODO TMP
  288. default:
  289. return fmt.Errorf("ERR")
  290. }
  291. return nil
  292. }
  293. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  294. if currLvl > p.maxLevels-1 {
  295. return fmt.Errorf("max virtual level %d", currLvl)
  296. }
  297. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  298. // reached divergence in next level
  299. if newLeaf.path[currLvl] {
  300. n.l = oldLeaf
  301. n.r = newLeaf
  302. } else {
  303. n.l = newLeaf
  304. n.r = oldLeaf
  305. }
  306. return nil
  307. }
  308. // no divergence yet, continue going down
  309. if newLeaf.path[currLvl] {
  310. // right
  311. n.r = &node{}
  312. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  313. return err
  314. }
  315. } else {
  316. // left
  317. n.l = &node{}
  318. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  319. return err
  320. }
  321. }
  322. return nil
  323. }
  324. // returns an array of key-values to store in the db
  325. func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) {
  326. if pairs == nil {
  327. pairs = [][2][]byte{}
  328. }
  329. var err error
  330. t := n.typ()
  331. switch t {
  332. case vtLeaf:
  333. p.dbg.incHash()
  334. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  335. if err != nil {
  336. return pairs, err
  337. }
  338. n.h = leafKey
  339. kv := [2][]byte{leafKey, leafValue}
  340. pairs = append(pairs, kv)
  341. case vtMid:
  342. if n.l != nil {
  343. pairs, err = n.l.computeHashes(p, pairs)
  344. if err != nil {
  345. return pairs, err
  346. }
  347. } else {
  348. n.l = &node{
  349. h: p.emptyHash,
  350. }
  351. }
  352. if n.r != nil {
  353. pairs, err = n.r.computeHashes(p, pairs)
  354. if err != nil {
  355. return pairs, err
  356. }
  357. } else {
  358. n.r = &node{
  359. h: p.emptyHash,
  360. }
  361. }
  362. // once the sub nodes are computed, can compute the current node
  363. // hash
  364. p.dbg.incHash()
  365. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  366. if err != nil {
  367. return nil, err
  368. }
  369. n.h = k
  370. kv := [2][]byte{k, v}
  371. pairs = append(pairs, kv)
  372. default:
  373. return nil, fmt.Errorf("ERR TMP") // TODO
  374. }
  375. return pairs, nil
  376. }
  377. //nolint:unused
  378. func (t *vt) graphviz(w io.Writer) error {
  379. fmt.Fprintf(w, `digraph hierarchy {
  380. node [fontname=Monospace,fontsize=10,shape=box]
  381. `)
  382. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  383. return err
  384. }
  385. fmt.Fprintf(w, "}\n")
  386. return nil
  387. }
  388. //nolint:unused
  389. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  390. nChars := 4 // TODO move to global constant
  391. if n == nil {
  392. return nEmpties, nil
  393. }
  394. t := n.typ()
  395. switch t {
  396. case vtLeaf:
  397. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  398. if err != nil {
  399. return nEmpties, err
  400. }
  401. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  402. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  403. hex.EncodeToString(n.k[:nChars]),
  404. hex.EncodeToString(n.v[:nChars]))
  405. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  406. hex.EncodeToString(n.k[:nChars]),
  407. hex.EncodeToString(n.v[:nChars]))
  408. case vtMid:
  409. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  410. lStr := fmt.Sprintf("%p", n.l)
  411. rStr := fmt.Sprintf("%p", n.r)
  412. eStr := ""
  413. if n.l == nil {
  414. lStr = fmt.Sprintf("empty%v", nEmpties)
  415. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  416. lStr)
  417. nEmpties++
  418. }
  419. if n.r == nil {
  420. rStr = fmt.Sprintf("empty%v", nEmpties)
  421. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  422. rStr)
  423. nEmpties++
  424. }
  425. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  426. fmt.Fprint(w, eStr)
  427. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  428. if err != nil {
  429. return nEmpties, err
  430. }
  431. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  432. if err != nil {
  433. return nEmpties, err
  434. }
  435. case vtEmpty:
  436. default:
  437. return nEmpties, fmt.Errorf("ERR")
  438. }
  439. return nEmpties, nil
  440. }
  441. //nolint:unused
  442. func (t *vt) printGraphviz() error {
  443. w := bytes.NewBufferString("")
  444. fmt.Fprintf(w,
  445. "--------\nGraphviz:\n")
  446. err := t.graphviz(w)
  447. if err != nil {
  448. fmt.Println(w)
  449. return err
  450. }
  451. fmt.Fprintf(w,
  452. "End of Graphviz --------\n")
  453. fmt.Println(w)
  454. return nil
  455. }