You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

270 lines
7.5 KiB

  1. /**
  2. * @file
  3. * @copyright defined in aergo/LICENSE.txt
  4. */
  5. package trie
  6. import (
  7. "bytes"
  8. "fmt"
  9. "sync"
  10. "sync/atomic"
  11. "github.com/p4u/asmt/db"
  12. )
  13. // LoadCache loads the first layers of the merkle tree given a root
  14. // This is called after a node restarts so that it doesnt become slow with db reads
  15. // LoadCache also updates the Root with the given root.
  16. func (s *Trie) LoadCache(root []byte) error {
  17. if s.db.Store == nil {
  18. return fmt.Errorf("DB not connected to trie")
  19. }
  20. s.db.liveCache = make(map[Hash][][]byte)
  21. ch := make(chan error, 1)
  22. s.loadCache(root, nil, 0, s.TrieHeight, ch)
  23. s.Root = root
  24. return <-ch
  25. }
  26. // loadCache loads the first layers of the merkle tree given a root
  27. func (s *Trie) loadCache(root []byte, batch [][]byte, iBatch, height int, ch chan<- (error)) {
  28. if height < s.CacheHeightLimit || len(root) == 0 {
  29. ch <- nil
  30. return
  31. }
  32. if height%4 == 0 {
  33. // Load the node from db
  34. s.db.lock.Lock()
  35. dbval := s.db.Store.Get(root[:HashLength])
  36. s.db.lock.Unlock()
  37. if len(dbval) == 0 {
  38. ch <- fmt.Errorf("the trie node %x is unavailable in the disk db, db may be corrupted", root)
  39. return
  40. }
  41. //Store node in cache.
  42. var node Hash
  43. copy(node[:], root)
  44. batch = s.parseBatch(dbval)
  45. s.db.liveMux.Lock()
  46. s.db.liveCache[node] = batch
  47. s.db.liveMux.Unlock()
  48. iBatch = 0
  49. if batch[0][0] == 1 {
  50. // if height == 0 this will also return
  51. ch <- nil
  52. return
  53. }
  54. }
  55. if iBatch != 0 && batch[iBatch][HashLength] == 1 {
  56. // Check if node is a leaf node
  57. ch <- nil
  58. } else {
  59. // Load subtree
  60. lnode, rnode := batch[2*iBatch+1], batch[2*iBatch+2]
  61. lch := make(chan error, 1)
  62. rch := make(chan error, 1)
  63. go s.loadCache(lnode, batch, 2*iBatch+1, height-1, lch)
  64. go s.loadCache(rnode, batch, 2*iBatch+2, height-1, rch)
  65. if err := <-lch; err != nil {
  66. ch <- err
  67. return
  68. }
  69. if err := <-rch; err != nil {
  70. ch <- err
  71. return
  72. }
  73. ch <- nil
  74. }
  75. }
  76. // Get fetches the value of a key by going down the current trie root.
  77. func (s *Trie) Get(key []byte) ([]byte, error) {
  78. s.lock.RLock()
  79. defer s.lock.RUnlock()
  80. s.atomicUpdate = false
  81. return s.get(s.Root, key, nil, 0, s.TrieHeight)
  82. }
  83. // GetWithRoot fetches the value of a key by going down for the specified root.
  84. func (s *Trie) GetWithRoot(key []byte, root []byte) ([]byte, error) {
  85. s.lock.RLock()
  86. defer s.lock.RUnlock()
  87. s.atomicUpdate = false
  88. if root == nil {
  89. root = s.Root
  90. }
  91. return s.get(root, key, nil, 0, s.TrieHeight)
  92. }
  93. // get fetches the value of a key given a trie root
  94. func (s *Trie) get(root, key []byte, batch [][]byte, iBatch, height int) ([]byte, error) {
  95. if len(root) == 0 {
  96. // the trie does not contain the key
  97. return nil, nil
  98. }
  99. // Fetch the children of the node
  100. batch, iBatch, lnode, rnode, isShortcut, err := s.loadChildren(root, height, iBatch, batch)
  101. if err != nil {
  102. return nil, err
  103. }
  104. if isShortcut {
  105. if bytes.Equal(lnode[:HashLength], key) {
  106. return rnode[:HashLength], nil
  107. }
  108. // also returns nil if height 0 is not a shortcut
  109. return nil, nil
  110. }
  111. if bitIsSet(key, s.TrieHeight-height) {
  112. return s.get(rnode, key, batch, 2*iBatch+2, height-1)
  113. }
  114. return s.get(lnode, key, batch, 2*iBatch+1, height-1)
  115. }
  116. // WalkResult contains the key and value obtained with a Walk() operation
  117. type WalkResult struct {
  118. Value []byte
  119. Key []byte
  120. }
  121. // Walk finds all the trie stored values from left to right and calls callback.
  122. // If callback returns a number diferent from 0, the walk will stop, else it will continue.
  123. func (s *Trie) Walk(root []byte, callback func(*WalkResult) int32) error {
  124. walkc := make(chan *WalkResult)
  125. s.lock.RLock()
  126. defer s.lock.RUnlock()
  127. if root == nil {
  128. root = s.Root
  129. }
  130. s.atomicUpdate = false
  131. finishedWalk := make(chan (bool), 1)
  132. stop := int32(0)
  133. wg := sync.WaitGroup{} // WaitGroup to avoid Walk() return before all callback executions are finished.
  134. go func() {
  135. for {
  136. select {
  137. case <-finishedWalk:
  138. return
  139. case value := <-walkc:
  140. stopCallback := callback(value)
  141. wg.Done()
  142. // In order to avoid data races we need to check the current value of stop, while at the
  143. // same time we store our callback value. If our callback value is 0 means that we have
  144. // override the previous non-zero value, so we need to restore it.
  145. if cv := atomic.SwapInt32(&stop, stopCallback); cv != 0 || stopCallback != 0 {
  146. if stopCallback == 0 {
  147. atomic.StoreInt32(&stop, cv)
  148. }
  149. // We need to return (instead of break) in order to stop iterating if some callback returns non zero
  150. return
  151. }
  152. }
  153. }
  154. }()
  155. err := s.walk(walkc, &stop, root, nil, 0, s.TrieHeight, &wg)
  156. finishedWalk <- true
  157. wg.Wait()
  158. return err
  159. }
  160. // walk fetches the value of a key given a trie root
  161. func (s *Trie) walk(walkc chan (*WalkResult), stop *int32, root []byte, batch [][]byte, ibatch, height int, wg *sync.WaitGroup) error {
  162. if len(root) == 0 || atomic.LoadInt32(stop) != 0 {
  163. // The sub tree is empty or stop walking
  164. return nil
  165. }
  166. // Fetch the children of the node
  167. batch, ibatch, lnode, rnode, isShortcut, err := s.loadChildren(root, height, ibatch, batch)
  168. if err != nil {
  169. return err
  170. }
  171. if isShortcut {
  172. wg.Add(1)
  173. walkc <- &WalkResult{Value: rnode[:HashLength], Key: lnode[:HashLength]}
  174. return nil
  175. }
  176. // Go left
  177. if err := s.walk(walkc, stop, lnode, batch, 2*ibatch+1, height-1, wg); err != nil {
  178. return err
  179. }
  180. // Go Right
  181. if err := s.walk(walkc, stop, rnode, batch, 2*ibatch+2, height-1, wg); err != nil {
  182. return err
  183. }
  184. return nil
  185. }
  186. // TrieRootExists returns true if the root exists in Database.
  187. func (s *Trie) TrieRootExists(root []byte) bool {
  188. s.db.lock.RLock()
  189. dbval := s.db.Store.Get(root)
  190. s.db.lock.RUnlock()
  191. return len(dbval) != 0
  192. }
  193. // Commit stores the updated nodes to disk.
  194. // Commit should be called for every block otherwise past tries
  195. // are not recorded and it is not possible to revert to them
  196. // (except if AtomicUpdate is used, which records every state).
  197. func (s *Trie) Commit() error {
  198. if s.db.Store == nil {
  199. return fmt.Errorf("DB not connected to trie")
  200. }
  201. // NOTE The tx interface doesnt handle ErrTxnTooBig
  202. txn := s.db.Store.NewTx().(DbTx)
  203. s.StageUpdates(txn)
  204. txn.(db.Transaction).Commit()
  205. return nil
  206. }
  207. // StageUpdates requires a database transaction as input
  208. // Unlike Commit(), it doesnt commit the transaction
  209. // the database transaction MUST be commited otherwise the
  210. // state ROOT will not exist.
  211. func (s *Trie) StageUpdates(txn DbTx) {
  212. s.lock.Lock()
  213. defer s.lock.Unlock()
  214. // Commit the new nodes to database, clear updatedNodes and store the Root in pastTries for reverts.
  215. if !s.atomicUpdate {
  216. // if previously AtomicUpdate was called, then past tries is already updated
  217. s.updatePastTries()
  218. }
  219. s.db.commit(&txn)
  220. s.db.updatedNodes = make(map[Hash][][]byte)
  221. s.prevRoot = s.Root
  222. }
  223. // Stash rolls back the changes made by previous updates
  224. // and loads the cache from before the rollback.
  225. func (s *Trie) Stash(rollbackCache bool) error {
  226. s.lock.Lock()
  227. defer s.lock.Unlock()
  228. s.Root = s.prevRoot
  229. if rollbackCache {
  230. // Making a temporary liveCache requires it to be copied, so it's quicker
  231. // to just load the cache from DB if a block state root was incorrect.
  232. s.db.liveCache = make(map[Hash][][]byte)
  233. ch := make(chan error, 1)
  234. s.loadCache(s.Root, nil, 0, s.TrieHeight, ch)
  235. err := <-ch
  236. if err != nil {
  237. return err
  238. }
  239. } else {
  240. s.db.liveCache = make(map[Hash][][]byte)
  241. }
  242. s.db.updatedNodes = make(map[Hash][][]byte)
  243. // also stash past tries created by Atomic update
  244. for i := len(s.pastTries) - 1; i >= 0; i-- {
  245. if bytes.Equal(s.pastTries[i], s.Root) {
  246. break
  247. } else {
  248. // remove from past tries
  249. s.pastTries = s.pastTries[:len(s.pastTries)-1]
  250. }
  251. }
  252. return nil
  253. }