You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

179 lines
3.4 KiB

  1. package arbo
  2. import (
  3. "bytes"
  4. "fmt"
  5. )
  // node is an element of the virtual tree. Its kind is derived from which
  // fields are set (see typ): no children and a non-nil key make a leaf, at
  // least one child makes an intermediate ("mid") node, anything else is empty.
  type node struct {
  	l, r *node  // left and right children; nil when absent
  	k, v []byte // key and value, carried only by leaves
  	// path has one entry per tree level, derived from the key:
  	// true means descend right, false means descend left.
  	path []bool
  	// h is the node hash. Nothing that places leaves sets it; presumably it
  	// is filled in by the later hash-computation pass (TODO confirm).
  	h []byte
  }
  14. type params struct {
  15. maxLevels int
  16. hashFunction HashFunction
  17. emptyHash []byte
  18. }
  19. // vt stands for virtual tree. It's a tree that does not have any computed hash
  20. // while placing the leafs. Once all the leafs are placed, it computes all the
  21. // hashes. In this way, each node hash is only computed one time.
  22. type vt struct {
  23. root *node
  24. params *params
  25. }
  26. func newVT(maxLevels int, hash HashFunction) vt {
  27. return vt{
  28. root: nil,
  29. params: &params{
  30. maxLevels: maxLevels,
  31. hashFunction: hash,
  32. emptyHash: make([]byte, hash.Len()), // empty
  33. },
  34. }
  35. }
  36. func (t *vt) add(k, v []byte) error {
  37. leaf := newLeafNode(t.params, k, v)
  38. if t.root == nil {
  39. t.root = leaf
  40. return nil
  41. }
  42. if err := t.root.add(t.params, 0, leaf); err != nil {
  43. return err
  44. }
  45. return nil
  46. }
  47. func newLeafNode(p *params, k, v []byte) *node {
  48. keyPath := make([]byte, p.hashFunction.Len())
  49. copy(keyPath[:], k)
  50. path := getPath(p.maxLevels, keyPath)
  51. n := &node{
  52. k: k,
  53. v: v,
  54. path: path,
  55. }
  56. return n
  57. }
  // virtualNodeType classifies a virtual-tree node; see (*node).typ.
  type virtualNodeType int

  // Virtual node kinds. For convenience the numeric values match
  // PrefixValueEmpty, PrefixValueLeaf and PrefixValueIntermediate.
  const (
  	vtEmpty = 0
  	vtLeaf  = 1
  	vtMid   = 2
  )
  64. func (n *node) typ() virtualNodeType {
  65. if n.l == nil && n.r == nil && n.k != nil {
  66. return vtLeaf
  67. }
  68. if n.l != nil || n.r != nil {
  69. return vtMid
  70. }
  71. return vtEmpty
  72. }
  73. func (n *node) add(p *params, currLvl int, leaf *node) error {
  74. if currLvl > p.maxLevels-1 {
  75. return fmt.Errorf("max virtual level %d", currLvl)
  76. }
  77. if n == nil {
  78. // n = leaf // TMP!
  79. return nil
  80. }
  81. t := n.typ()
  82. switch t {
  83. case vtMid:
  84. if leaf.path[currLvl] {
  85. //right
  86. if n.r == nil {
  87. // empty sub-node, add the leaf here
  88. n.r = leaf
  89. }
  90. if err := n.r.add(p, currLvl+1, leaf); err != nil {
  91. return err
  92. }
  93. } else {
  94. if n.l == nil {
  95. // empty sub-node, add the leaf here
  96. n.l = leaf
  97. }
  98. if err := n.l.add(p, currLvl+1, leaf); err != nil {
  99. return err
  100. }
  101. }
  102. case vtLeaf:
  103. if bytes.Equal(n.k, leaf.k) {
  104. return fmt.Errorf("key already exists")
  105. }
  106. oldLeaf := &node{
  107. k: n.k,
  108. v: n.v,
  109. path: n.path,
  110. }
  111. // remove values from current node (converting it to mid node)
  112. n.k = nil
  113. n.v = nil
  114. n.path = nil
  115. if err := n.downUntilDivergence(p, currLvl, oldLeaf, leaf); err != nil {
  116. return err
  117. }
  118. default:
  119. return fmt.Errorf("ERR")
  120. }
  121. return nil
  122. }
  123. func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error {
  124. if currLvl > p.maxLevels-1 {
  125. return fmt.Errorf("max virtual level %d", currLvl)
  126. }
  127. // if oldLeaf.path[currLvl+1] != newLeaf.path[currLvl+1] {
  128. if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
  129. // reached divergence in next level
  130. // if newLeaf.path[currLvl+1] {
  131. if newLeaf.path[currLvl] {
  132. n.l = oldLeaf
  133. n.r = newLeaf
  134. } else {
  135. n.l = newLeaf
  136. n.r = oldLeaf
  137. }
  138. return nil
  139. }
  140. // no divergence yet, continue going down
  141. if newLeaf.path[currLvl] {
  142. // right
  143. n.r = &node{}
  144. if err := n.r.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  145. return err
  146. }
  147. } else {
  148. // left
  149. n.l = &node{}
  150. if err := n.l.downUntilDivergence(p, currLvl+1, oldLeaf, newLeaf); err != nil {
  151. return err
  152. }
  153. }
  154. return nil
  155. }
  156. func (n *node) computeHashes() ([]kv, error) {
  157. return nil, nil
  158. }