You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

883 lines
24 KiB

  1. /**
  2. * @file
  3. * @copyright defined in aergo/LICENSE.txt
  4. */
  5. package trie
  6. import (
  7. "bytes"
  8. "runtime"
  9. //"io/ioutil"
  10. "os"
  11. "path"
  12. "time"
  13. //"encoding/hex"
  14. "fmt"
  15. "math/rand"
  16. "sort"
  17. "testing"
  18. "github.com/p4u/asmt/db"
  19. )
  20. func TestTrieEmpty(t *testing.T) {
  21. smt := NewTrie(nil, Hasher, nil)
  22. if len(smt.Root) != 0 {
  23. t.Fatal("empty trie root hash not correct")
  24. }
  25. }
  26. func TestTrieUpdateAndGet(t *testing.T) {
  27. smt := NewTrie(nil, Hasher, nil)
  28. smt.atomicUpdate = false
  29. // Add data to empty trie
  30. keys := getFreshData(10, 32)
  31. values := getFreshData(10, 32)
  32. ch := make(chan mresult, 1)
  33. smt.update(smt.Root, keys, values, nil, 0, smt.TrieHeight, ch)
  34. res := <-ch
  35. root := res.update
  36. // Check all keys have been stored
  37. for i, key := range keys {
  38. value, _ := smt.get(root, key, nil, 0, smt.TrieHeight)
  39. if !bytes.Equal(values[i], value) {
  40. t.Fatal("value not updated")
  41. }
  42. }
  43. // Append to the trie
  44. newKeys := getFreshData(5, 32)
  45. newValues := getFreshData(5, 32)
  46. ch = make(chan mresult, 1)
  47. smt.update(root, newKeys, newValues, nil, 0, smt.TrieHeight, ch)
  48. res = <-ch
  49. newRoot := res.update
  50. if bytes.Equal(root, newRoot) {
  51. t.Fatal("trie not updated")
  52. }
  53. for i, newKey := range newKeys {
  54. newValue, _ := smt.get(newRoot, newKey, nil, 0, smt.TrieHeight)
  55. if !bytes.Equal(newValues[i], newValue) {
  56. t.Fatal("failed to get value")
  57. }
  58. }
  59. }
  60. func TestTrieAtomicUpdate(t *testing.T) {
  61. smt := NewTrie(nil, Hasher, nil)
  62. smt.CacheHeightLimit = 0
  63. keys := getFreshData(1, 32)
  64. values := getFreshData(1, 32)
  65. root, _ := smt.AtomicUpdate(keys, values)
  66. updatedNb := len(smt.db.updatedNodes)
  67. cacheNb := len(smt.db.liveCache)
  68. newvalues := getFreshData(1, 32)
  69. smt.AtomicUpdate(keys, newvalues)
  70. if len(smt.db.updatedNodes) != 2*updatedNb {
  71. t.Fatal("Atomic update doesnt store all tries")
  72. }
  73. if len(smt.db.liveCache) != cacheNb {
  74. t.Fatal("Cache size should remain the same")
  75. }
  76. // check keys of previous atomic update are accessible in
  77. // updated nodes with root.
  78. smt.atomicUpdate = false
  79. for i, key := range keys {
  80. value, _ := smt.get(root, key, nil, 0, smt.TrieHeight)
  81. if !bytes.Equal(values[i], value) {
  82. t.Fatal("failed to get value")
  83. }
  84. }
  85. }
  86. func TestTriePublicUpdateAndGet(t *testing.T) {
  87. smt := NewTrie(nil, Hasher, nil)
  88. smt.CacheHeightLimit = 0
  89. // Add data to empty trie
  90. keys := getFreshData(20, 32)
  91. values := getFreshData(20, 32)
  92. root, _ := smt.Update(keys, values)
  93. updatedNb := len(smt.db.updatedNodes)
  94. cacheNb := len(smt.db.liveCache)
  95. // Check all keys have been stored
  96. for i, key := range keys {
  97. value, _ := smt.Get(key)
  98. if !bytes.Equal(values[i], value) {
  99. t.Fatal("trie not updated")
  100. }
  101. }
  102. if !bytes.Equal(root, smt.Root) {
  103. t.Fatal("Root not stored")
  104. }
  105. newValues := getFreshData(20, 32)
  106. smt.Update(keys, newValues)
  107. if len(smt.db.updatedNodes) != updatedNb {
  108. t.Fatal("multiple updates don't actualise updated nodes")
  109. }
  110. if len(smt.db.liveCache) != cacheNb {
  111. t.Fatal("multiple updates don't actualise liveCache")
  112. }
  113. // Check all keys have been modified
  114. for i, key := range keys {
  115. value, _ := smt.Get(key)
  116. if !bytes.Equal(newValues[i], value) {
  117. t.Fatal("trie not updated")
  118. }
  119. }
  120. }
  121. func TestGetWithRoot(t *testing.T) {
  122. dbPath := t.TempDir()
  123. st := db.NewDB(db.LevelImpl, dbPath)
  124. smt := NewTrie(nil, Hasher, st)
  125. smt.CacheHeightLimit = 0
  126. // Add data to empty trie
  127. keys := getFreshData(20, 32)
  128. values := getFreshData(20, 32)
  129. root, _ := smt.Update(keys, values)
  130. // Check all keys have been stored
  131. for i, key := range keys {
  132. value, _ := smt.Get(key)
  133. if !bytes.Equal(values[i], value) {
  134. t.Fatal("trie not updated")
  135. }
  136. }
  137. if !bytes.Equal(root, smt.Root) {
  138. t.Fatal("Root not stored")
  139. }
  140. if err := smt.Commit(); err != nil {
  141. t.Fatal(err)
  142. }
  143. // Delete two values (0 and 1)
  144. if _, err := smt.Update([][]byte{keys[0], keys[1]}, [][]byte{DefaultLeaf, DefaultLeaf}); err != nil {
  145. t.Fatal(err)
  146. }
  147. // Change one value
  148. oldValue3 := make([]byte, 32)
  149. copy(oldValue3, values[3])
  150. values[3] = getFreshData(1, 32)[0]
  151. if _, err := smt.Update([][]byte{keys[3]}, [][]byte{values[3]}); err != nil {
  152. t.Fatal(err)
  153. }
  154. // Check root has been actually updated
  155. if bytes.Equal(smt.Root, root) {
  156. t.Fatal("root not updated")
  157. }
  158. // Get the third value with the new root
  159. v3, err := smt.GetWithRoot(keys[3], smt.Root)
  160. if err != nil {
  161. t.Fatal(err)
  162. }
  163. if !bytes.Equal(v3, values[3]) {
  164. t.Fatalf("GetWithRoot did not keep the value: %x != %x", v3, values[3])
  165. }
  166. // Get the third value with the old root
  167. v3, err = smt.GetWithRoot(keys[3], root)
  168. if err != nil {
  169. t.Fatal(err)
  170. }
  171. if !bytes.Equal(v3, oldValue3) {
  172. t.Fatalf("GetWithRoot did not keep the value: %x != %x", v3, oldValue3)
  173. }
  174. st.Close()
  175. }
  176. func TestTrieWalk(t *testing.T) {
  177. dbPath := t.TempDir()
  178. st := db.NewDB(db.LevelImpl, dbPath)
  179. smt := NewTrie(nil, Hasher, st)
  180. smt.CacheHeightLimit = 0
  181. // Add data to empty trie
  182. keys := getFreshData(20, 32)
  183. values := getFreshData(20, 32)
  184. root, _ := smt.Update(keys, values)
  185. // Check all keys have been stored
  186. for i, key := range keys {
  187. value, _ := smt.Get(key)
  188. if !bytes.Equal(values[i], value) {
  189. t.Fatal("trie not updated")
  190. }
  191. }
  192. if !bytes.Equal(root, smt.Root) {
  193. t.Fatal("Root not stored")
  194. }
  195. // Walk over the whole tree and compare the values
  196. i := 0
  197. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  198. if !bytes.Equal(v.Value, values[i]) {
  199. t.Fatalf("walk value does not match %x != %x", v.Value, values[i])
  200. }
  201. if !bytes.Equal(v.Key, keys[i]) {
  202. t.Fatalf("walk key does not match %x != %x", v.Key, keys[i])
  203. }
  204. i++
  205. return 0
  206. }); err != nil {
  207. t.Fatal(err)
  208. }
  209. // Delete two values (0 and 3)
  210. if _, err := smt.Update([][]byte{keys[0], keys[3]}, [][]byte{DefaultLeaf, DefaultLeaf}); err != nil {
  211. t.Fatal(err)
  212. }
  213. // Delete two elements and walk again
  214. i = 1
  215. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  216. if i == 3 {
  217. i++
  218. }
  219. if !bytes.Equal(v.Value, values[i]) {
  220. t.Fatalf("walk value does not match %x == %x\n", v.Value, values[i])
  221. }
  222. if !bytes.Equal(v.Key, keys[i]) {
  223. t.Fatalf("walk key does not match %x == %x\n", v.Key, keys[i])
  224. }
  225. i++
  226. return 0
  227. }); err != nil {
  228. t.Fatal(err)
  229. }
  230. // Add one new value to preivous deleted key
  231. values[3] = getFreshData(1, 32)[0]
  232. if _, err := smt.Update([][]byte{keys[3]}, [][]byte{values[3]}); err != nil {
  233. t.Fatal(err)
  234. }
  235. // Walk and check again
  236. i = 1
  237. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  238. if !bytes.Equal(v.Value, values[i]) {
  239. t.Fatalf("walk value does not match %x != %x\n", v.Value, values[i])
  240. }
  241. if !bytes.Equal(v.Key, keys[i]) {
  242. t.Fatalf("walk key does not match %x != %x\n", v.Key, keys[i])
  243. }
  244. i++
  245. return 0
  246. }); err != nil {
  247. t.Fatal(err)
  248. }
  249. // Find a specific value and test stop
  250. i = 0
  251. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  252. if bytes.Equal(v.Value, values[5]) {
  253. return 1
  254. }
  255. i++
  256. return 0
  257. }); err != nil {
  258. t.Fatal(err)
  259. }
  260. if i != 4 {
  261. t.Fatalf("Needed more iterations on walk than expected: %d != 4", i)
  262. }
  263. st.Close()
  264. }
  265. func TestTrieDelete(t *testing.T) {
  266. smt := NewTrie(nil, Hasher, nil)
  267. // Add data to empty trie
  268. keys := getFreshData(20, 32)
  269. values := getFreshData(20, 32)
  270. ch := make(chan mresult, 1)
  271. smt.update(smt.Root, keys, values, nil, 0, smt.TrieHeight, ch)
  272. result := <-ch
  273. root := result.update
  274. value, _ := smt.get(root, keys[0], nil, 0, smt.TrieHeight)
  275. if !bytes.Equal(values[0], value) {
  276. t.Fatal("trie not updated")
  277. }
  278. // Delete from trie
  279. // To delete a key, just set it's value to Default leaf hash.
  280. ch = make(chan mresult, 1)
  281. smt.update(root, keys[0:1], [][]byte{DefaultLeaf}, nil, 0, smt.TrieHeight, ch)
  282. result = <-ch
  283. updatedNb := len(smt.db.updatedNodes)
  284. newRoot := result.update
  285. newValue, _ := smt.get(newRoot, keys[0], nil, 0, smt.TrieHeight)
  286. if len(newValue) != 0 {
  287. t.Fatal("Failed to delete from trie")
  288. }
  289. // Remove deleted key from keys and check root with a clean trie.
  290. smt2 := NewTrie(nil, Hasher, nil)
  291. ch = make(chan mresult, 1)
  292. smt2.update(smt.Root, keys[1:], values[1:], nil, 0, smt.TrieHeight, ch)
  293. result = <-ch
  294. cleanRoot := result.update
  295. if !bytes.Equal(newRoot, cleanRoot) {
  296. t.Fatal("roots mismatch")
  297. }
  298. if len(smt2.db.updatedNodes) != updatedNb {
  299. t.Fatal("deleting doesn't actualise updated nodes")
  300. }
  301. //Empty the trie
  302. var newValues [][]byte
  303. for i := 0; i < 20; i++ {
  304. newValues = append(newValues, DefaultLeaf)
  305. }
  306. ch = make(chan mresult, 1)
  307. smt.update(root, keys, newValues, nil, 0, smt.TrieHeight, ch)
  308. result = <-ch
  309. root = result.update
  310. //if !bytes.Equal(smt.DefaultHash(256), root) {
  311. if len(root) != 0 {
  312. t.Fatal("empty trie root hash not correct")
  313. }
  314. // Test deleting an already empty key
  315. smt = NewTrie(nil, Hasher, nil)
  316. keys = getFreshData(2, 32)
  317. values = getFreshData(2, 32)
  318. root, _ = smt.Update(keys, values)
  319. key0 := make([]byte, 32, 32)
  320. key1 := make([]byte, 32, 32)
  321. smt.Update([][]byte{key0, key1}, [][]byte{DefaultLeaf, DefaultLeaf})
  322. if !bytes.Equal(root, smt.Root) {
  323. t.Fatal("deleting a default key shouldnt' modify the tree")
  324. }
  325. }
  326. // test updating and deleting at the same time
  327. func TestTrieUpdateAndDelete(t *testing.T) {
  328. smt := NewTrie(nil, Hasher, nil)
  329. smt.CacheHeightLimit = 0
  330. key0 := make([]byte, 32, 32)
  331. values := getFreshData(1, 32)
  332. root, _ := smt.Update([][]byte{key0}, values)
  333. cacheNb := len(smt.db.liveCache)
  334. updatedNb := len(smt.db.updatedNodes)
  335. smt.atomicUpdate = false
  336. _, _, k, v, isShortcut, _ := smt.loadChildren(root, smt.TrieHeight, 0, nil)
  337. if !isShortcut || !bytes.Equal(k[:HashLength], key0) || !bytes.Equal(v[:HashLength], values[0]) {
  338. t.Fatal("leaf shortcut didn't move up to root")
  339. }
  340. key1 := make([]byte, 32, 32)
  341. // set the last bit
  342. bitSet(key1, 255)
  343. keys := [][]byte{key0, key1}
  344. values = [][]byte{DefaultLeaf, getFreshData(1, 32)[0]}
  345. root, _ = smt.Update(keys, values)
  346. if len(smt.db.liveCache) != cacheNb {
  347. t.Fatal("number of cache nodes not correct after delete")
  348. }
  349. if len(smt.db.updatedNodes) != updatedNb {
  350. t.Fatal("number of cache nodes not correct after delete")
  351. }
  352. smt.atomicUpdate = false
  353. _, _, k, v, isShortcut, _ = smt.loadChildren(root, smt.TrieHeight, 0, nil)
  354. if !isShortcut || !bytes.Equal(k[:HashLength], key1) || !bytes.Equal(v[:HashLength], values[1]) {
  355. t.Fatal("leaf shortcut didn't move up to root")
  356. }
  357. }
  358. func TestTrieMerkleProof(t *testing.T) {
  359. smt := NewTrie(nil, Hasher, nil)
  360. // Add data to empty trie
  361. keys := getFreshData(10, 32)
  362. values := getFreshData(10, 32)
  363. smt.Update(keys, values)
  364. for i, key := range keys {
  365. ap, _, k, v, _ := smt.MerkleProof(key)
  366. if !smt.VerifyInclusion(ap, key, values[i]) {
  367. t.Fatalf("failed to verify inclusion proof")
  368. }
  369. if !bytes.Equal(key, k) && !bytes.Equal(values[i], v) {
  370. t.Fatalf("merkle proof didnt return the correct key-value pair")
  371. }
  372. }
  373. emptyKey := Hasher([]byte("non-member"))
  374. ap, included, proofKey, proofValue, _ := smt.MerkleProof(emptyKey)
  375. if included {
  376. t.Fatalf("failed to verify non inclusion proof")
  377. }
  378. if !smt.VerifyNonInclusion(ap, emptyKey, proofValue, proofKey) {
  379. t.Fatalf("failed to verify non inclusion proof")
  380. }
  381. }
  382. func TestTrieMerkleProofCompressed(t *testing.T) {
  383. smt := NewTrie(nil, Hasher, nil)
  384. // Add data to empty trie
  385. keys := getFreshData(10, 32)
  386. values := getFreshData(10, 32)
  387. smt.Update(keys, values)
  388. for i, key := range keys {
  389. bitmap, ap, length, _, k, v, _ := smt.MerkleProofCompressed(key)
  390. if !smt.VerifyInclusionC(bitmap, key, values[i], ap, length) {
  391. t.Fatalf("failed to verify inclusion proof")
  392. }
  393. if !bytes.Equal(key, k) && !bytes.Equal(values[i], v) {
  394. t.Fatalf("merkle proof didnt return the correct key-value pair")
  395. }
  396. }
  397. emptyKey := Hasher([]byte("non-member"))
  398. bitmap, ap, length, included, proofKey, proofValue, _ := smt.MerkleProofCompressed(emptyKey)
  399. if included {
  400. t.Fatalf("failed to verify non inclusion proof")
  401. }
  402. if !smt.VerifyNonInclusionC(ap, length, bitmap, emptyKey, proofValue, proofKey) {
  403. t.Fatalf("failed to verify non inclusion proof")
  404. }
  405. }
  406. func TestTrieCommit(t *testing.T) {
  407. dbPath := path.Join(".aergo", "db")
  408. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  409. _ = os.MkdirAll(dbPath, 0711)
  410. }
  411. st := db.NewDB(db.LevelImpl, dbPath)
  412. smt := NewTrie(nil, Hasher, st)
  413. keys := getFreshData(10, 32)
  414. values := getFreshData(10, 32)
  415. smt.Update(keys, values)
  416. smt.Commit()
  417. // liveCache is deleted so the key is fetched in badger db
  418. smt.db.liveCache = make(map[Hash][][]byte)
  419. for i, key := range keys {
  420. value, _ := smt.Get(key)
  421. if !bytes.Equal(value, values[i]) {
  422. t.Fatal("failed to get value in committed db")
  423. }
  424. }
  425. st.Close()
  426. os.RemoveAll(".aergo")
  427. }
  428. func TestTrieStageUpdates(t *testing.T) {
  429. dbPath := path.Join(".aergo", "db")
  430. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  431. _ = os.MkdirAll(dbPath, 0711)
  432. }
  433. st := db.NewDB(db.LevelImpl, dbPath)
  434. smt := NewTrie(nil, Hasher, st)
  435. keys := getFreshData(10, 32)
  436. values := getFreshData(10, 32)
  437. smt.Update(keys, values)
  438. txn := st.NewTx()
  439. smt.StageUpdates(txn.(DbTx))
  440. txn.Commit()
  441. // liveCache is deleted so the key is fetched in badger db
  442. smt.db.liveCache = make(map[Hash][][]byte)
  443. for i, key := range keys {
  444. value, _ := smt.Get(key)
  445. if !bytes.Equal(value, values[i]) {
  446. t.Fatal("failed to get value in committed db")
  447. }
  448. }
  449. st.Close()
  450. os.RemoveAll(".aergo")
  451. }
  452. func TestTrieRevert(t *testing.T) {
  453. dbPath := path.Join(".aergo", "db")
  454. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  455. _ = os.MkdirAll(dbPath, 0711)
  456. }
  457. st := db.NewDB(db.LevelImpl, dbPath)
  458. smt := NewTrie(nil, Hasher, st)
  459. // Edge case : Test that revert doesnt delete shortcut nodes
  460. // when moved to a different position in tree
  461. key0 := make([]byte, 32, 32)
  462. key1 := make([]byte, 32, 32)
  463. // setting the bit at 251 creates 2 shortcut batches at height 252
  464. bitSet(key1, 251)
  465. values := getFreshData(2, 32)
  466. keys := [][]byte{key0, key1}
  467. root, _ := smt.Update([][]byte{key0}, [][]byte{values[0]})
  468. smt.Commit()
  469. root2, _ := smt.Update([][]byte{key1}, [][]byte{values[1]})
  470. smt.Commit()
  471. smt.Revert(root)
  472. if len(smt.db.Store.Get(root)) == 0 {
  473. t.Fatal("shortcut node shouldnt be deleted by revert")
  474. }
  475. if len(smt.db.Store.Get(root2)) != 0 {
  476. t.Fatal("reverted root should have been deleted")
  477. }
  478. key1 = make([]byte, 32, 32)
  479. // setting the bit at 255 stores the keys as the tip
  480. bitSet(key1, 255)
  481. smt.Update([][]byte{key1}, [][]byte{values[1]})
  482. smt.Commit()
  483. smt.Revert(root)
  484. if len(smt.db.Store.Get(root)) == 0 {
  485. t.Fatal("shortcut node shouldnt be deleted by revert")
  486. }
  487. // Test all nodes are reverted in the usual case
  488. // Add data to empty trie
  489. keys = getFreshData(10, 32)
  490. values = getFreshData(10, 32)
  491. root, _ = smt.Update(keys, values)
  492. smt.Commit()
  493. // Update the values
  494. newValues := getFreshData(10, 32)
  495. smt.Update(keys, newValues)
  496. updatedNodes1 := smt.db.updatedNodes
  497. smt.Commit()
  498. newKeys := getFreshData(10, 32)
  499. newValues = getFreshData(10, 32)
  500. smt.Update(newKeys, newValues)
  501. updatedNodes2 := smt.db.updatedNodes
  502. smt.Commit()
  503. smt.Revert(root)
  504. if !bytes.Equal(smt.Root, root) {
  505. t.Fatal("revert failed")
  506. }
  507. if len(smt.pastTries) != 2 { // contains empty trie + reverted trie
  508. t.Fatal("past tries not updated after revert")
  509. }
  510. // Check all keys have been reverted
  511. for i, key := range keys {
  512. value, _ := smt.Get(key)
  513. if !bytes.Equal(values[i], value) {
  514. t.Fatal("revert failed, values not updated")
  515. }
  516. }
  517. if len(smt.db.liveCache) != 0 {
  518. t.Fatal("live cache not reset after revert")
  519. }
  520. // Check all reverted nodes have been deleted
  521. for node, _ := range updatedNodes2 {
  522. if len(smt.db.Store.Get(node[:])) != 0 {
  523. t.Fatal("nodes not deleted from database", node)
  524. }
  525. }
  526. for node, _ := range updatedNodes1 {
  527. if len(smt.db.Store.Get(node[:])) != 0 {
  528. t.Fatal("nodes not deleted from database", node)
  529. }
  530. }
  531. st.Close()
  532. os.RemoveAll(".aergo")
  533. }
  534. func TestTrieRaisesError(t *testing.T) {
  535. dbPath := path.Join(".aergo", "db")
  536. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  537. _ = os.MkdirAll(dbPath, 0711)
  538. }
  539. st := db.NewDB(db.LevelImpl, dbPath)
  540. smt := NewTrie(nil, Hasher, st)
  541. // Add data to empty trie
  542. keys := getFreshData(10, 32)
  543. values := getFreshData(10, 32)
  544. smt.Update(keys, values)
  545. smt.db.liveCache = make(map[Hash][][]byte)
  546. smt.db.updatedNodes = make(map[Hash][][]byte)
  547. // Check errors are raised is a keys is not in cache nore db
  548. for _, key := range keys {
  549. _, err := smt.Get(key)
  550. if err == nil {
  551. t.Fatal("Error not created if database doesnt have a node")
  552. }
  553. }
  554. _, _, _, _, _, _, err := smt.MerkleProofCompressed(keys[0])
  555. if err == nil {
  556. t.Fatal("Error not created if database doesnt have a node")
  557. }
  558. _, err = smt.Update(keys, values)
  559. if err == nil {
  560. t.Fatal("Error not created if database doesnt have a node")
  561. }
  562. st.Close()
  563. os.RemoveAll(".aergo")
  564. smt = NewTrie(nil, Hasher, nil)
  565. err = smt.Commit()
  566. if err == nil {
  567. t.Fatal("Error not created if database not connected")
  568. }
  569. smt.db.liveCache = make(map[Hash][][]byte)
  570. smt.atomicUpdate = false
  571. _, _, _, _, _, err = smt.loadChildren(make([]byte, 32, 32), smt.TrieHeight, 0, nil)
  572. if err == nil {
  573. t.Fatal("Error not created if database not connected")
  574. }
  575. err = smt.LoadCache(make([]byte, 32))
  576. if err == nil {
  577. t.Fatal("Error not created if database not connected")
  578. }
  579. }
  580. func TestTrieLoadCache(t *testing.T) {
  581. dbPath := path.Join(".aergo", "db")
  582. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  583. _ = os.MkdirAll(dbPath, 0711)
  584. }
  585. st := db.NewDB(db.LevelImpl, dbPath)
  586. smt := NewTrie(nil, Hasher, st)
  587. // Test size of cache
  588. smt.CacheHeightLimit = 0
  589. key0 := make([]byte, 32, 32)
  590. key1 := make([]byte, 32, 32)
  591. bitSet(key1, 255)
  592. values := getFreshData(2, 32)
  593. smt.Update([][]byte{key0, key1}, values)
  594. if len(smt.db.liveCache) != 66 {
  595. // the nodes are at the tip, so 64 + 2 = 66
  596. t.Fatal("cache size incorrect")
  597. }
  598. // Add data to empty trie
  599. keys := getFreshData(10, 32)
  600. values = getFreshData(10, 32)
  601. smt.Update(keys, values)
  602. smt.Commit()
  603. // Simulate node restart by deleting and loading cache
  604. cacheSize := len(smt.db.liveCache)
  605. smt.db.liveCache = make(map[Hash][][]byte)
  606. err := smt.LoadCache(smt.Root)
  607. if err != nil {
  608. t.Fatal(err)
  609. }
  610. if cacheSize != len(smt.db.liveCache) {
  611. t.Fatal("Cache loading from db incorrect")
  612. }
  613. st.Close()
  614. os.RemoveAll(".aergo")
  615. }
  616. func TestHeight0LeafShortcut(t *testing.T) {
  617. keySize := 32
  618. smt := NewTrie(nil, Hasher, nil)
  619. // Add 2 sibling keys that will be stored at height 0
  620. key0 := make([]byte, keySize, keySize)
  621. key1 := make([]byte, keySize, keySize)
  622. bitSet(key1, keySize*8-1)
  623. keys := [][]byte{key0, key1}
  624. values := getFreshData(2, 32)
  625. smt.Update(keys, values)
  626. updatedNb := len(smt.db.updatedNodes)
  627. // Check all keys have been stored
  628. for i, key := range keys {
  629. value, _ := smt.Get(key)
  630. if !bytes.Equal(values[i], value) {
  631. t.Fatal("trie not updated")
  632. }
  633. }
  634. bitmap, ap, length, _, k, v, err := smt.MerkleProofCompressed(key1)
  635. if err != nil {
  636. t.Fatal(err)
  637. }
  638. if !bytes.Equal(key1, k) && !bytes.Equal(values[1], v) {
  639. t.Fatalf("merkle proof didnt return the correct key-value pair")
  640. }
  641. if length != smt.TrieHeight {
  642. t.Fatal("proof should have length equal to trie height for a leaf shortcut")
  643. }
  644. if !smt.VerifyInclusionC(bitmap, key1, values[1], ap, length) {
  645. t.Fatal("failed to verify inclusion proof")
  646. }
  647. // Delete one key and check that the remaining one moved up to the root of the tree
  648. newRoot, _ := smt.AtomicUpdate(keys[0:1], [][]byte{DefaultLeaf})
  649. // Nb of updated nodes remains same because the new shortcut root was already stored at height 0.
  650. if len(smt.db.updatedNodes) != updatedNb {
  651. fmt.Println(len(smt.db.updatedNodes), updatedNb)
  652. t.Fatal("number of cache nodes not correct after delete")
  653. }
  654. smt.atomicUpdate = false
  655. _, _, k, v, isShortcut, err := smt.loadChildren(newRoot, smt.TrieHeight, 0, nil)
  656. if err != nil {
  657. t.Fatal(err)
  658. }
  659. if !isShortcut || !bytes.Equal(k[:HashLength], key1) || !bytes.Equal(v[:HashLength], values[1]) {
  660. t.Fatal("leaf shortcut didn't move up to root")
  661. }
  662. _, _, length, _, k, v, _ = smt.MerkleProofCompressed(key1)
  663. if length != 0 {
  664. t.Fatal("proof should have length equal to trie height for a leaf shortcut")
  665. }
  666. if !bytes.Equal(key1, k) && !bytes.Equal(values[1], v) {
  667. t.Fatalf("merkle proof didnt return the correct key-value pair")
  668. }
  669. }
  670. func TestStash(t *testing.T) {
  671. dbPath := path.Join(".aergo", "db")
  672. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  673. _ = os.MkdirAll(dbPath, 0711)
  674. }
  675. st := db.NewDB(db.LevelImpl, dbPath)
  676. smt := NewTrie(nil, Hasher, st)
  677. // Add data to empty trie
  678. keys := getFreshData(20, 32)
  679. values := getFreshData(20, 32)
  680. root, _ := smt.Update(keys, values)
  681. cacheSize := len(smt.db.liveCache)
  682. smt.Commit()
  683. if len(smt.pastTries) != 1 {
  684. t.Fatal("Past tries not updated after commit")
  685. }
  686. values = getFreshData(20, 32)
  687. smt.Update(keys, values)
  688. smt.Stash(true)
  689. if len(smt.pastTries) != 1 {
  690. t.Fatal("Past tries not updated after commit")
  691. }
  692. if !bytes.Equal(smt.Root, root) {
  693. t.Fatal("Trie not rolled back")
  694. }
  695. if len(smt.db.updatedNodes) != 0 {
  696. t.Fatal("Trie not rolled back")
  697. }
  698. if len(smt.db.liveCache) != cacheSize {
  699. t.Fatal("Trie not rolled back")
  700. }
  701. keys = getFreshData(20, 32)
  702. values = getFreshData(20, 32)
  703. smt.AtomicUpdate(keys, values)
  704. values = getFreshData(20, 32)
  705. smt.AtomicUpdate(keys, values)
  706. if len(smt.pastTries) != 3 {
  707. t.Fatal("Past tries not updated after commit")
  708. }
  709. smt.Stash(true)
  710. if !bytes.Equal(smt.Root, root) {
  711. t.Fatal("Trie not rolled back")
  712. }
  713. if len(smt.db.updatedNodes) != 0 {
  714. t.Fatal("Trie not rolled back")
  715. }
  716. if len(smt.db.liveCache) != cacheSize {
  717. t.Fatal("Trie not rolled back")
  718. }
  719. if len(smt.pastTries) != 1 {
  720. t.Fatal("Past tries not updated after commit")
  721. }
  722. st.Close()
  723. os.RemoveAll(".aergo")
  724. }
  725. func benchmark10MAccounts10Ktps(smt *Trie, b *testing.B) {
  726. //b.ReportAllocs()
  727. keys := getFreshData(100, 32)
  728. values := getFreshData(100, 32)
  729. smt.Update(keys, values)
  730. fmt.Println("\nLoading b.N x 1000 accounts")
  731. for i := 0; i < b.N; i++ {
  732. newkeys := getFreshData(1000, 32)
  733. newvalues := getFreshData(1000, 32)
  734. start := time.Now()
  735. smt.Update(newkeys, newvalues)
  736. end := time.Now()
  737. smt.Commit()
  738. end2 := time.Now()
  739. for j, key := range newkeys {
  740. val, _ := smt.Get(key)
  741. if !bytes.Equal(val, newvalues[j]) {
  742. b.Fatal("new key not included")
  743. }
  744. }
  745. end3 := time.Now()
  746. elapsed := end.Sub(start)
  747. elapsed2 := end2.Sub(end)
  748. elapsed3 := end3.Sub(end2)
  749. var m runtime.MemStats
  750. runtime.ReadMemStats(&m)
  751. fmt.Println(i, " : update time : ", elapsed, "commit time : ", elapsed2,
  752. "\n1000 Get time : ", elapsed3,
  753. "\ndb read : ", smt.LoadDbCounter, " cache read : ", smt.LoadCacheCounter,
  754. "\ncache size : ", len(smt.db.liveCache),
  755. "\nRAM : ", m.Sys/1024/1024, " MiB")
  756. }
  757. }
  758. //go test -run=xxx -bench=. -benchmem -test.benchtime=20s
  759. func BenchmarkCacheHeightLimit233(b *testing.B) {
  760. dbPath := path.Join(".aergo", "db")
  761. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  762. _ = os.MkdirAll(dbPath, 0711)
  763. }
  764. st := db.NewDB(db.LevelImpl, dbPath)
  765. smt := NewTrie(nil, Hasher, st)
  766. smt.CacheHeightLimit = 233
  767. benchmark10MAccounts10Ktps(smt, b)
  768. st.Close()
  769. os.RemoveAll(".aergo")
  770. }
  771. func BenchmarkCacheHeightLimit238(b *testing.B) {
  772. dbPath := path.Join(".aergo", "db")
  773. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  774. _ = os.MkdirAll(dbPath, 0711)
  775. }
  776. st := db.NewDB(db.LevelImpl, dbPath)
  777. smt := NewTrie(nil, Hasher, st)
  778. smt.CacheHeightLimit = 238
  779. benchmark10MAccounts10Ktps(smt, b)
  780. st.Close()
  781. os.RemoveAll(".aergo")
  782. }
  783. func BenchmarkCacheHeightLimit245(b *testing.B) {
  784. dbPath := path.Join(".aergo", "db")
  785. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  786. _ = os.MkdirAll(dbPath, 0711)
  787. }
  788. st := db.NewDB(db.LevelImpl, dbPath)
  789. smt := NewTrie(nil, Hasher, st)
  790. smt.CacheHeightLimit = 245
  791. benchmark10MAccounts10Ktps(smt, b)
  792. st.Close()
  793. os.RemoveAll(".aergo")
  794. }
  795. func getFreshData(size, length int) [][]byte {
  796. var data [][]byte
  797. for i := 0; i < size; i++ {
  798. key := make([]byte, 32)
  799. _, err := rand.Read(key)
  800. if err != nil {
  801. panic(err)
  802. }
  803. data = append(data, Hasher(key)[:length])
  804. }
  805. sort.Sort(DataArray(data))
  806. return data
  807. }