You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

337 lines
7.1 KiB

// Package arbo > vt.go implements the Virtual Tree, which builds the tree
// structure without computing any hash. The idea is that, once all the leaves
// are placed in their positions, the hashes can be computed, so that no node
// hash is computed more than once.
  5. package arbo
  6. import (
  7. "bytes"
  8. "encoding/hex"
  9. "fmt"
  10. "io"
  11. )
  12. type node struct {
  13. l *node
  14. r *node
  15. k []byte
  16. v []byte
  17. path []bool
  18. h []byte
  19. }
  20. type params struct {
  21. maxLevels int
  22. hashFunction HashFunction
  23. emptyHash []byte
  24. }
  25. // vt stands for virtual tree. It's a tree that does not have any computed hash
  26. // while placing the leafs. Once all the leafs are placed, it computes all the
  27. // hashes. In this way, each node hash is only computed one time.
  28. type vt struct {
  29. root *node
  30. params *params
  31. }
  32. func newVT(maxLevels int, hash HashFunction) vt {
  33. return vt{
  34. root: nil,
  35. params: &params{
  36. maxLevels: maxLevels,
  37. hashFunction: hash,
  38. emptyHash: make([]byte, hash.Len()), // empty
  39. },
  40. }
  41. }
  42. func (t *vt) add(fromLvl int, k, v []byte) error {
  43. leaf := newLeafNode(t.params, k, v)
  44. if t.root == nil {
  45. t.root = leaf
  46. return nil
  47. }
  48. if err := t.root.add(t.params, fromLvl, leaf); err != nil {
  49. return err
  50. }
  51. return nil
  52. }
  53. // computeHashes should be called after all the vt.add is used, once all the
  54. // leafs are in the tree
  55. func (t *vt) computeHashes() ([][2][]byte, error) {
  56. var pairs [][2][]byte
  57. var err error
  58. pairs, err = t.root.computeHashes(t.params, pairs)
  59. if err != nil {
  60. return pairs, err
  61. }
  62. return pairs, nil
  63. }
  64. func newLeafNode(p *params, k, v []byte) *node {
  65. keyPath := make([]byte, p.hashFunction.Len())
  66. copy(keyPath[:], k)
  67. path := getPath(p.maxLevels, keyPath)
  68. n := &node{
  69. k: k,
  70. v: v,
  71. path: path,
  72. }
  73. return n
  74. }
  75. type virtualNodeType int
  76. const (
  77. vtEmpty = 0 // for convenience uses same value that PrefixValueEmpty
  78. vtLeaf = 1 // for convenience uses same value that PrefixValueLeaf
  79. vtMid = 2 // for convenience uses same value that PrefixValueIntermediate
  80. )
  81. func (n *node) typ() virtualNodeType {
  82. if n.l == nil && n.r == nil && n.k != nil {
  83. return vtLeaf
  84. }
  85. if n.l != nil || n.r != nil {
  86. return vtMid
  87. }
  88. return vtEmpty
  89. }
  90. func (n *node) add(p *params, currLvl int, leaf *node) error {
  91. if currLvl > p.maxLevels-1 {
  92. return fmt.Errorf("max virtual level %d", currLvl)
  93. }
  94. if n == nil {
  95. // n = leaf // TMP!
  96. return nil
  97. }
  98. t := n.typ()
  99. switch t {
  100. case vtMid:
  101. if leaf.path[currLvl] {
  102. //right
  103. if n.r == nil {
  104. // empty sub-node, add the leaf here
  105. n.r = leaf
  106. return nil
  107. }
  108. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  109. return err
  110. }
  111. } else {
  112. if n.l == nil {
  113. // empty sub-node, add the leaf here
  114. n.l = leaf
  115. return nil
  116. }
  117. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  118. return err
  119. }
  120. }
  121. case vtLeaf:
  122. if bytes.Equal(n.k, leaf.k) {
  123. return fmt.Errorf("key already exists. Existing node: %s, trying to add node: %s",
  124. hex.EncodeToString(n.k), hex.EncodeToString(leaf.k))
  125. }
  126. oldLeaf := &node{
  127. k: n.k,
  128. v: n.v,
  129. path: n.path,
  130. }
  131. // remove values from current node (converting it to mid node)
  132. n.k = nil
  133. n.v = nil
  134. n.h = nil
  135. n.path = nil
  136. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  137. return err
  138. }
  139. case vtEmpty:
  140. panic(fmt.Errorf("EMPTY %v", n)) // TODO TMP
  141. default:
  142. return fmt.Errorf("ERR")
  143. }
  144. return nil
  145. }
  146. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  147. if currLvl > p.maxLevels-1 {
  148. return fmt.Errorf("max virtual level %d", currLvl)
  149. }
  150. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  151. // reached divergence in next level
  152. if newLeaf.path[currLvl] {
  153. n.l = oldLeaf
  154. n.r = newLeaf
  155. } else {
  156. n.l = newLeaf
  157. n.r = oldLeaf
  158. }
  159. return nil
  160. }
  161. // no divergence yet, continue going down
  162. if newLeaf.path[currLvl] {
  163. // right
  164. n.r = &node{}
  165. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  166. return err
  167. }
  168. } else {
  169. // left
  170. n.l = &node{}
  171. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  172. return err
  173. }
  174. }
  175. return nil
  176. }
  177. // returns an array of key-values to store in the db
  178. func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) {
  179. if pairs == nil {
  180. pairs = [][2][]byte{}
  181. }
  182. var err error
  183. t := n.typ()
  184. switch t {
  185. case vtLeaf:
  186. leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
  187. if err != nil {
  188. return pairs, err
  189. }
  190. n.h = leafKey
  191. kv := [2][]byte{leafKey, leafValue}
  192. pairs = append(pairs, kv)
  193. case vtMid:
  194. if n.l != nil {
  195. pairs, err = n.l.computeHashes(p, pairs)
  196. if err != nil {
  197. return pairs, err
  198. }
  199. } else {
  200. n.l = &node{
  201. h: p.emptyHash,
  202. }
  203. }
  204. if n.r != nil {
  205. pairs, err = n.r.computeHashes(p, pairs)
  206. if err != nil {
  207. return pairs, err
  208. }
  209. } else {
  210. n.r = &node{
  211. h: p.emptyHash,
  212. }
  213. }
  214. // once the sub nodes are computed, can compute the current node
  215. // hash
  216. k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
  217. if err != nil {
  218. return nil, err
  219. }
  220. n.h = k
  221. kv := [2][]byte{k, v}
  222. pairs = append(pairs, kv)
  223. default:
  224. return nil, fmt.Errorf("ERR TMP") // TODO
  225. }
  226. return pairs, nil
  227. }
  228. //nolint:unused
  229. func (t *vt) graphviz(w io.Writer) error {
  230. fmt.Fprintf(w, `digraph hierarchy {
  231. node [fontname=Monospace,fontsize=10,shape=box]
  232. `)
  233. if _, err := t.root.graphviz(w, t.params, 0); err != nil {
  234. return err
  235. }
  236. fmt.Fprintf(w, "}\n")
  237. return nil
  238. }
  239. //nolint:unused
  240. func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
  241. nChars := 4 // TODO move to global constant
  242. if n == nil {
  243. return nEmpties, nil
  244. }
  245. t := n.typ()
  246. switch t {
  247. case vtLeaf:
  248. leafKey, _, err := newLeafValue(p.hashFunction, n.k, n.v)
  249. if err != nil {
  250. return nEmpties, err
  251. }
  252. fmt.Fprintf(w, "\"%p\" [style=filled,label=\"%v\"];\n", n, hex.EncodeToString(leafKey[:nChars]))
  253. fmt.Fprintf(w, "\"%p\" -> {\"k:%v\\nv:%v\"}\n", n,
  254. hex.EncodeToString(n.k[:nChars]),
  255. hex.EncodeToString(n.v[:nChars]))
  256. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  257. hex.EncodeToString(n.k[:nChars]),
  258. hex.EncodeToString(n.v[:nChars]))
  259. case vtMid:
  260. fmt.Fprintf(w, "\"%p\" [label=\"\"];\n", n)
  261. lStr := fmt.Sprintf("%p", n.l)
  262. rStr := fmt.Sprintf("%p", n.r)
  263. eStr := ""
  264. if n.l == nil {
  265. lStr = fmt.Sprintf("empty%v", nEmpties)
  266. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  267. lStr)
  268. nEmpties++
  269. }
  270. if n.r == nil {
  271. rStr = fmt.Sprintf("empty%v", nEmpties)
  272. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  273. rStr)
  274. nEmpties++
  275. }
  276. fmt.Fprintf(w, "\"%p\" -> {\"%v\" \"%v\"}\n", n, lStr, rStr)
  277. fmt.Fprint(w, eStr)
  278. nEmpties, err := n.l.graphviz(w, p, nEmpties)
  279. if err != nil {
  280. return nEmpties, err
  281. }
  282. nEmpties, err = n.r.graphviz(w, p, nEmpties)
  283. if err != nil {
  284. return nEmpties, err
  285. }
  286. case vtEmpty:
  287. default:
  288. return nEmpties, fmt.Errorf("ERR")
  289. }
  290. return nEmpties, nil
  291. }
  292. //nolint:unused
  293. func (t *vt) printGraphviz() error {
  294. w := bytes.NewBufferString("")
  295. fmt.Fprintf(w,
  296. "--------\nGraphviz:\n")
  297. err := t.graphviz(w)
  298. if err != nil {
  299. fmt.Println(w)
  300. return err
  301. }
  302. fmt.Fprintf(w,
  303. "End of Graphviz --------\n")
  304. fmt.Println(w)
  305. return nil
  306. }