You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

950 lines
25 KiB

  1. /**
  2. * @file
  3. * @copyright defined in aergo/LICENSE.txt
  4. */
  5. package trie
  6. import (
  7. "bytes"
  8. "runtime"
  9. //"io/ioutil"
  10. "os"
  11. "path"
  12. "time"
  13. //"encoding/hex"
  14. "fmt"
  15. "math/rand"
  16. "sort"
  17. "testing"
  18. "github.com/p4u/asmt/db"
  19. )
  20. func TestTrieEmpty(t *testing.T) {
  21. smt := NewTrie(nil, Hasher, nil)
  22. if len(smt.Root) != 0 {
  23. t.Fatal("empty trie root hash not correct")
  24. }
  25. }
  26. func TestTrieUpdateAndGet(t *testing.T) {
  27. smt := NewTrie(nil, Hasher, nil)
  28. smt.atomicUpdate = false
  29. // Add data to empty trie
  30. keys := getFreshData(10, 32)
  31. values := getFreshData(10, 32)
  32. ch := make(chan mresult, 1)
  33. smt.update(smt.Root, keys, values, nil, 0, smt.TrieHeight, ch)
  34. res := <-ch
  35. root := res.update
  36. // Check all keys have been stored
  37. for i, key := range keys {
  38. value, _ := smt.get(root, key, nil, 0, smt.TrieHeight)
  39. if !bytes.Equal(values[i], value) {
  40. t.Fatal("value not updated")
  41. }
  42. }
  43. // Append to the trie
  44. newKeys := getFreshData(5, 32)
  45. newValues := getFreshData(5, 32)
  46. ch = make(chan mresult, 1)
  47. smt.update(root, newKeys, newValues, nil, 0, smt.TrieHeight, ch)
  48. res = <-ch
  49. newRoot := res.update
  50. if bytes.Equal(root, newRoot) {
  51. t.Fatal("trie not updated")
  52. }
  53. for i, newKey := range newKeys {
  54. newValue, _ := smt.get(newRoot, newKey, nil, 0, smt.TrieHeight)
  55. if !bytes.Equal(newValues[i], newValue) {
  56. t.Fatal("failed to get value")
  57. }
  58. }
  59. }
  60. func TestTrieAtomicUpdate(t *testing.T) {
  61. smt := NewTrie(nil, Hasher, nil)
  62. smt.CacheHeightLimit = 0
  63. keys := getFreshData(1, 32)
  64. values := getFreshData(1, 32)
  65. root, _ := smt.AtomicUpdate(keys, values)
  66. updatedNb := len(smt.db.updatedNodes)
  67. cacheNb := len(smt.db.liveCache)
  68. newvalues := getFreshData(1, 32)
  69. if _, err := smt.AtomicUpdate(keys, newvalues); err != nil {
  70. t.Fatal(err)
  71. }
  72. if len(smt.db.updatedNodes) != 2*updatedNb {
  73. t.Fatal("Atomic update doesnt store all tries")
  74. }
  75. if len(smt.db.liveCache) != cacheNb {
  76. t.Fatal("Cache size should remain the same")
  77. }
  78. // check keys of previous atomic update are accessible in
  79. // updated nodes with root.
  80. smt.atomicUpdate = false
  81. for i, key := range keys {
  82. value, _ := smt.get(root, key, nil, 0, smt.TrieHeight)
  83. if !bytes.Equal(values[i], value) {
  84. t.Fatal("failed to get value")
  85. }
  86. }
  87. }
  88. func TestTriePublicUpdateAndGet(t *testing.T) {
  89. smt := NewTrie(nil, Hasher, nil)
  90. smt.CacheHeightLimit = 0
  91. // Add data to empty trie
  92. keys := getFreshData(20, 32)
  93. values := getFreshData(20, 32)
  94. root, _ := smt.Update(keys, values)
  95. updatedNb := len(smt.db.updatedNodes)
  96. cacheNb := len(smt.db.liveCache)
  97. // Check all keys have been stored
  98. for i, key := range keys {
  99. value, _ := smt.Get(key)
  100. if !bytes.Equal(values[i], value) {
  101. t.Fatal("trie not updated")
  102. }
  103. }
  104. if !bytes.Equal(root, smt.Root) {
  105. t.Fatal("Root not stored")
  106. }
  107. newValues := getFreshData(20, 32)
  108. if _, err := smt.Update(keys, newValues); err != nil {
  109. t.Fatal(err)
  110. }
  111. if len(smt.db.updatedNodes) != updatedNb {
  112. t.Fatal("multiple updates don't actualise updated nodes")
  113. }
  114. if len(smt.db.liveCache) != cacheNb {
  115. t.Fatal("multiple updates don't actualise liveCache")
  116. }
  117. // Check all keys have been modified
  118. for i, key := range keys {
  119. value, _ := smt.Get(key)
  120. if !bytes.Equal(newValues[i], value) {
  121. t.Fatal("trie not updated")
  122. }
  123. }
  124. }
  125. func TestGetWithRoot(t *testing.T) {
  126. dbPath := t.TempDir()
  127. st := db.NewDB(db.LevelImpl, dbPath)
  128. smt := NewTrie(nil, Hasher, st)
  129. smt.CacheHeightLimit = 0
  130. // Add data to empty trie
  131. keys := getFreshData(20, 32)
  132. values := getFreshData(20, 32)
  133. root, _ := smt.Update(keys, values)
  134. // Check all keys have been stored
  135. for i, key := range keys {
  136. value, _ := smt.Get(key)
  137. if !bytes.Equal(values[i], value) {
  138. t.Fatal("trie not updated")
  139. }
  140. }
  141. if !bytes.Equal(root, smt.Root) {
  142. t.Fatal("Root not stored")
  143. }
  144. if err := smt.Commit(); err != nil {
  145. t.Fatal(err)
  146. }
  147. // Delete two values (0 and 1)
  148. if _, err := smt.Update([][]byte{keys[0], keys[1]}, [][]byte{DefaultLeaf, DefaultLeaf}); err != nil {
  149. t.Fatal(err)
  150. }
  151. // Change one value
  152. oldValue3 := make([]byte, 32)
  153. copy(oldValue3, values[3])
  154. values[3] = getFreshData(1, 32)[0]
  155. if _, err := smt.Update([][]byte{keys[3]}, [][]byte{values[3]}); err != nil {
  156. t.Fatal(err)
  157. }
  158. // Check root has been actually updated
  159. if bytes.Equal(smt.Root, root) {
  160. t.Fatal("root not updated")
  161. }
  162. // Get the third value with the new root
  163. v3, err := smt.GetWithRoot(keys[3], smt.Root)
  164. if err != nil {
  165. t.Fatal(err)
  166. }
  167. if !bytes.Equal(v3, values[3]) {
  168. t.Fatalf("GetWithRoot did not keep the value: %x != %x", v3, values[3])
  169. }
  170. // Get the third value with the old root
  171. v3, err = smt.GetWithRoot(keys[3], root)
  172. if err != nil {
  173. t.Fatal(err)
  174. }
  175. if !bytes.Equal(v3, oldValue3) {
  176. t.Fatalf("GetWithRoot did not keep the value: %x != %x", v3, oldValue3)
  177. }
  178. st.Close()
  179. }
  180. func TestTrieWalk(t *testing.T) {
  181. dbPath := t.TempDir()
  182. st := db.NewDB(db.LevelImpl, dbPath)
  183. smt := NewTrie(nil, Hasher, st)
  184. smt.CacheHeightLimit = 0
  185. // Add data to empty trie
  186. keys := getFreshData(20, 32)
  187. values := getFreshData(20, 32)
  188. root, _ := smt.Update(keys, values)
  189. // Check all keys have been stored
  190. for i, key := range keys {
  191. value, _ := smt.Get(key)
  192. if !bytes.Equal(values[i], value) {
  193. t.Fatal("trie not updated")
  194. }
  195. }
  196. if !bytes.Equal(root, smt.Root) {
  197. t.Fatal("Root not stored")
  198. }
  199. // Walk over the whole tree and compare the values
  200. i := 0
  201. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  202. if !bytes.Equal(v.Value, values[i]) {
  203. t.Fatalf("walk value does not match %x != %x", v.Value, values[i])
  204. }
  205. if !bytes.Equal(v.Key, keys[i]) {
  206. t.Fatalf("walk key does not match %x != %x", v.Key, keys[i])
  207. }
  208. i++
  209. return 0
  210. }); err != nil {
  211. t.Fatal(err)
  212. }
  213. // Delete two values (0 and 3)
  214. if _, err := smt.Update([][]byte{keys[0], keys[3]}, [][]byte{DefaultLeaf, DefaultLeaf}); err != nil {
  215. t.Fatal(err)
  216. }
  217. // Delete two elements and walk again
  218. i = 1
  219. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  220. if i == 3 {
  221. i++
  222. }
  223. if !bytes.Equal(v.Value, values[i]) {
  224. t.Fatalf("walk value does not match %x == %x\n", v.Value, values[i])
  225. }
  226. if !bytes.Equal(v.Key, keys[i]) {
  227. t.Fatalf("walk key does not match %x == %x\n", v.Key, keys[i])
  228. }
  229. i++
  230. return 0
  231. }); err != nil {
  232. t.Fatal(err)
  233. }
  234. // Add one new value to preivous deleted key
  235. values[3] = getFreshData(1, 32)[0]
  236. if _, err := smt.Update([][]byte{keys[3]}, [][]byte{values[3]}); err != nil {
  237. t.Fatal(err)
  238. }
  239. // Walk and check again
  240. i = 1
  241. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  242. if !bytes.Equal(v.Value, values[i]) {
  243. t.Fatalf("walk value does not match %x != %x\n", v.Value, values[i])
  244. }
  245. if !bytes.Equal(v.Key, keys[i]) {
  246. t.Fatalf("walk key does not match %x != %x\n", v.Key, keys[i])
  247. }
  248. i++
  249. return 0
  250. }); err != nil {
  251. t.Fatal(err)
  252. }
  253. // Find a specific value and test stop
  254. i = 0
  255. if err := smt.Walk(nil, func(v *WalkResult) int32 {
  256. if bytes.Equal(v.Value, values[5]) {
  257. return 1
  258. }
  259. i++
  260. return 0
  261. }); err != nil {
  262. t.Fatal(err)
  263. }
  264. if i != 4 {
  265. t.Fatalf("Needed more iterations on walk than expected: %d != 4", i)
  266. }
  267. st.Close()
  268. }
  269. func TestTrieDelete(t *testing.T) {
  270. smt := NewTrie(nil, Hasher, nil)
  271. // Add data to empty trie
  272. keys := getFreshData(20, 32)
  273. values := getFreshData(20, 32)
  274. ch := make(chan mresult, 1)
  275. smt.update(smt.Root, keys, values, nil, 0, smt.TrieHeight, ch)
  276. result := <-ch
  277. root := result.update
  278. value, _ := smt.get(root, keys[0], nil, 0, smt.TrieHeight)
  279. if !bytes.Equal(values[0], value) {
  280. t.Fatal("trie not updated")
  281. }
  282. // Delete from trie
  283. // To delete a key, just set it's value to Default leaf hash.
  284. ch = make(chan mresult, 1)
  285. smt.update(root, keys[0:1], [][]byte{DefaultLeaf}, nil, 0, smt.TrieHeight, ch)
  286. result = <-ch
  287. updatedNb := len(smt.db.updatedNodes)
  288. newRoot := result.update
  289. newValue, _ := smt.get(newRoot, keys[0], nil, 0, smt.TrieHeight)
  290. if len(newValue) != 0 {
  291. t.Fatal("Failed to delete from trie")
  292. }
  293. // Remove deleted key from keys and check root with a clean trie.
  294. smt2 := NewTrie(nil, Hasher, nil)
  295. ch = make(chan mresult, 1)
  296. smt2.update(smt.Root, keys[1:], values[1:], nil, 0, smt.TrieHeight, ch)
  297. result = <-ch
  298. cleanRoot := result.update
  299. if !bytes.Equal(newRoot, cleanRoot) {
  300. t.Fatal("roots mismatch")
  301. }
  302. if len(smt2.db.updatedNodes) != updatedNb {
  303. t.Fatal("deleting doesn't actualise updated nodes")
  304. }
  305. //Empty the trie
  306. var newValues [][]byte
  307. for i := 0; i < 20; i++ {
  308. newValues = append(newValues, DefaultLeaf)
  309. }
  310. ch = make(chan mresult, 1)
  311. smt.update(root, keys, newValues, nil, 0, smt.TrieHeight, ch)
  312. result = <-ch
  313. root = result.update
  314. //if !bytes.Equal(smt.DefaultHash(256), root) {
  315. if len(root) != 0 {
  316. t.Fatal("empty trie root hash not correct")
  317. }
  318. // Test deleting an already empty key
  319. smt = NewTrie(nil, Hasher, nil)
  320. keys = getFreshData(2, 32)
  321. values = getFreshData(2, 32)
  322. root, _ = smt.Update(keys, values)
  323. key0 := make([]byte, 32)
  324. key1 := make([]byte, 32)
  325. if _, err := smt.Update([][]byte{key0, key1}, [][]byte{DefaultLeaf, DefaultLeaf}); err != nil {
  326. t.Fatal(err)
  327. }
  328. if !bytes.Equal(root, smt.Root) {
  329. t.Fatal("deleting a default key shouldnt' modify the tree")
  330. }
  331. }
  332. // test updating and deleting at the same time
  333. func TestTrieUpdateAndDelete(t *testing.T) {
  334. smt := NewTrie(nil, Hasher, nil)
  335. smt.CacheHeightLimit = 0
  336. key0 := make([]byte, 32)
  337. values := getFreshData(1, 32)
  338. root, _ := smt.Update([][]byte{key0}, values)
  339. cacheNb := len(smt.db.liveCache)
  340. updatedNb := len(smt.db.updatedNodes)
  341. smt.atomicUpdate = false
  342. _, _, k, v, isShortcut, _ := smt.loadChildren(root, smt.TrieHeight, 0, nil)
  343. if !isShortcut || !bytes.Equal(k[:HashLength], key0) || !bytes.Equal(v[:HashLength], values[0]) {
  344. t.Fatal("leaf shortcut didn't move up to root")
  345. }
  346. key1 := make([]byte, 32)
  347. // set the last bit
  348. bitSet(key1, 255)
  349. keys := [][]byte{key0, key1}
  350. values = [][]byte{DefaultLeaf, getFreshData(1, 32)[0]}
  351. root, _ = smt.Update(keys, values)
  352. if len(smt.db.liveCache) != cacheNb {
  353. t.Fatal("number of cache nodes not correct after delete")
  354. }
  355. if len(smt.db.updatedNodes) != updatedNb {
  356. t.Fatal("number of cache nodes not correct after delete")
  357. }
  358. smt.atomicUpdate = false
  359. _, _, k, v, isShortcut, _ = smt.loadChildren(root, smt.TrieHeight, 0, nil)
  360. if !isShortcut || !bytes.Equal(k[:HashLength], key1) || !bytes.Equal(v[:HashLength], values[1]) {
  361. t.Fatal("leaf shortcut didn't move up to root")
  362. }
  363. }
  364. func TestTrieMerkleProof(t *testing.T) {
  365. smt := NewTrie(nil, Hasher, nil)
  366. // Add data to empty trie
  367. keys := getFreshData(10, 32)
  368. values := getFreshData(10, 32)
  369. if _, err := smt.Update(keys, values); err != nil {
  370. t.Fatal(err)
  371. }
  372. for i, key := range keys {
  373. ap, _, k, v, _ := smt.MerkleProof(key)
  374. if !smt.VerifyInclusion(ap, key, values[i]) {
  375. t.Fatalf("failed to verify inclusion proof")
  376. }
  377. if !bytes.Equal(key, k) && !bytes.Equal(values[i], v) {
  378. t.Fatalf("merkle proof didnt return the correct key-value pair")
  379. }
  380. }
  381. emptyKey := Hasher([]byte("non-member"))
  382. ap, included, proofKey, proofValue, _ := smt.MerkleProof(emptyKey)
  383. if included {
  384. t.Fatalf("failed to verify non inclusion proof")
  385. }
  386. if !smt.VerifyNonInclusion(ap, emptyKey, proofValue, proofKey) {
  387. t.Fatalf("failed to verify non inclusion proof")
  388. }
  389. }
  390. func TestTrieMerkleProofCompressed(t *testing.T) {
  391. smt := NewTrie(nil, Hasher, nil)
  392. // Add data to empty trie
  393. keys := getFreshData(10, 32)
  394. values := getFreshData(10, 32)
  395. if _, err := smt.Update(keys, values); err != nil {
  396. t.Fatal(err)
  397. }
  398. for i, key := range keys {
  399. bitmap, ap, length, _, k, v, _ := smt.MerkleProofCompressed(key)
  400. if !smt.VerifyInclusionC(bitmap, key, values[i], ap, length) {
  401. t.Fatalf("failed to verify inclusion proof")
  402. }
  403. if !bytes.Equal(key, k) && !bytes.Equal(values[i], v) {
  404. t.Fatalf("merkle proof didnt return the correct key-value pair")
  405. }
  406. }
  407. emptyKey := Hasher([]byte("non-member"))
  408. bitmap, ap, length, included, proofKey, proofValue, _ := smt.MerkleProofCompressed(emptyKey)
  409. if included {
  410. t.Fatalf("failed to verify non inclusion proof")
  411. }
  412. if !smt.VerifyNonInclusionC(ap, length, bitmap, emptyKey, proofValue, proofKey) {
  413. t.Fatalf("failed to verify non inclusion proof")
  414. }
  415. }
  416. func TestTrieCommit(t *testing.T) {
  417. dbPath := path.Join(".aergo", "db")
  418. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  419. _ = os.MkdirAll(dbPath, 0711)
  420. }
  421. st := db.NewDB(db.LevelImpl, dbPath)
  422. smt := NewTrie(nil, Hasher, st)
  423. keys := getFreshData(10, 32)
  424. values := getFreshData(10, 32)
  425. if _, err := smt.Update(keys, values); err != nil {
  426. t.Fatal(err)
  427. }
  428. if err := smt.Commit(); err != nil {
  429. t.Fatal(err)
  430. }
  431. // liveCache is deleted so the key is fetched in badger db
  432. smt.db.liveCache = make(map[Hash][][]byte)
  433. for i, key := range keys {
  434. value, _ := smt.Get(key)
  435. if !bytes.Equal(value, values[i]) {
  436. t.Fatal("failed to get value in committed db")
  437. }
  438. }
  439. st.Close()
  440. os.RemoveAll(".aergo")
  441. }
  442. func TestTrieStageUpdates(t *testing.T) {
  443. dbPath := path.Join(".aergo", "db")
  444. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  445. _ = os.MkdirAll(dbPath, 0711)
  446. }
  447. st := db.NewDB(db.LevelImpl, dbPath)
  448. smt := NewTrie(nil, Hasher, st)
  449. keys := getFreshData(10, 32)
  450. values := getFreshData(10, 32)
  451. if _, err := smt.Update(keys, values); err != nil {
  452. t.Fatal(err)
  453. }
  454. txn := st.NewTx()
  455. smt.StageUpdates(txn.(DbTx))
  456. txn.Commit()
  457. // liveCache is deleted so the key is fetched in badger db
  458. smt.db.liveCache = make(map[Hash][][]byte)
  459. for i, key := range keys {
  460. value, _ := smt.Get(key)
  461. if !bytes.Equal(value, values[i]) {
  462. t.Fatal("failed to get value in committed db")
  463. }
  464. }
  465. st.Close()
  466. os.RemoveAll(".aergo")
  467. }
  468. func TestTrieRevert(t *testing.T) {
  469. dbPath := path.Join(".aergo", "db")
  470. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  471. _ = os.MkdirAll(dbPath, 0711)
  472. }
  473. st := db.NewDB(db.LevelImpl, dbPath)
  474. smt := NewTrie(nil, Hasher, st)
  475. // Edge case : Test that revert doesnt delete shortcut nodes
  476. // when moved to a different position in tree
  477. key0 := make([]byte, 32)
  478. key1 := make([]byte, 32)
  479. // setting the bit at 251 creates 2 shortcut batches at height 252
  480. bitSet(key1, 251)
  481. values := getFreshData(2, 32)
  482. root, _ := smt.Update([][]byte{key0}, [][]byte{values[0]})
  483. if err := smt.Commit(); err != nil {
  484. t.Fatal(err)
  485. }
  486. root2, _ := smt.Update([][]byte{key1}, [][]byte{values[1]})
  487. if err := smt.Commit(); err != nil {
  488. t.Fatal(err)
  489. }
  490. if err := smt.Revert(root); err != nil {
  491. t.Fatal(err)
  492. }
  493. if len(smt.db.Store.Get(root)) == 0 {
  494. t.Fatal("shortcut node shouldnt be deleted by revert")
  495. }
  496. if len(smt.db.Store.Get(root2)) != 0 {
  497. t.Fatal("reverted root should have been deleted")
  498. }
  499. key1 = make([]byte, 32)
  500. // setting the bit at 255 stores the keys as the tip
  501. bitSet(key1, 255)
  502. if _, err := smt.Update([][]byte{key1}, [][]byte{values[1]}); err != nil {
  503. t.Fatal(err)
  504. }
  505. if err := smt.Commit(); err != nil {
  506. t.Fatal(err)
  507. }
  508. if err := smt.Revert(root); err != nil {
  509. t.Fatal(err)
  510. }
  511. if len(smt.db.Store.Get(root)) == 0 {
  512. t.Fatal("shortcut node shouldnt be deleted by revert")
  513. }
  514. // Test all nodes are reverted in the usual case
  515. // Add data to empty trie
  516. keys := getFreshData(10, 32)
  517. values = getFreshData(10, 32)
  518. root, _ = smt.Update(keys, values)
  519. if err := smt.Commit(); err != nil {
  520. t.Fatal(err)
  521. }
  522. // Update the values
  523. newValues := getFreshData(10, 32)
  524. if _, err := smt.Update(keys, newValues); err != nil {
  525. t.Fatal(err)
  526. }
  527. updatedNodes1 := smt.db.updatedNodes
  528. if err := smt.Commit(); err != nil {
  529. t.Fatal(err)
  530. }
  531. newKeys := getFreshData(10, 32)
  532. newValues = getFreshData(10, 32)
  533. if _, err := smt.Update(newKeys, newValues); err != nil {
  534. t.Fatal(err)
  535. }
  536. updatedNodes2 := smt.db.updatedNodes
  537. if err := smt.Commit(); err != nil {
  538. t.Fatal(err)
  539. }
  540. if err := smt.Revert(root); err != nil {
  541. t.Fatal(err)
  542. }
  543. if !bytes.Equal(smt.Root, root) {
  544. t.Fatal("revert failed")
  545. }
  546. if len(smt.pastTries) != 2 { // contains empty trie + reverted trie
  547. t.Fatal("past tries not updated after revert")
  548. }
  549. // Check all keys have been reverted
  550. for i, key := range keys {
  551. value, _ := smt.Get(key)
  552. if !bytes.Equal(values[i], value) {
  553. t.Fatal("revert failed, values not updated")
  554. }
  555. }
  556. if len(smt.db.liveCache) != 0 {
  557. t.Fatal("live cache not reset after revert")
  558. }
  559. // Check all reverted nodes have been deleted
  560. for node := range updatedNodes2 {
  561. if len(smt.db.Store.Get(node[:])) != 0 {
  562. t.Fatal("nodes not deleted from database", node)
  563. }
  564. }
  565. for node := range updatedNodes1 {
  566. if len(smt.db.Store.Get(node[:])) != 0 {
  567. t.Fatal("nodes not deleted from database", node)
  568. }
  569. }
  570. st.Close()
  571. os.RemoveAll(".aergo")
  572. }
  573. func TestTrieRaisesError(t *testing.T) {
  574. dbPath := path.Join(".aergo", "db")
  575. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  576. _ = os.MkdirAll(dbPath, 0711)
  577. }
  578. st := db.NewDB(db.LevelImpl, dbPath)
  579. smt := NewTrie(nil, Hasher, st)
  580. // Add data to empty trie
  581. keys := getFreshData(10, 32)
  582. values := getFreshData(10, 32)
  583. if _, err := smt.Update(keys, values); err != nil {
  584. t.Fatal(err)
  585. }
  586. smt.db.liveCache = make(map[Hash][][]byte)
  587. smt.db.updatedNodes = make(map[Hash][][]byte)
  588. // Check errors are raised is a keys is not in cache nore db
  589. for _, key := range keys {
  590. _, err := smt.Get(key)
  591. if err == nil {
  592. t.Fatal("Error not created if database doesnt have a node")
  593. }
  594. }
  595. _, _, _, _, _, _, err := smt.MerkleProofCompressed(keys[0])
  596. if err == nil {
  597. t.Fatal("Error not created if database doesnt have a node")
  598. }
  599. _, err = smt.Update(keys, values)
  600. if err == nil {
  601. t.Fatal("Error not created if database doesnt have a node")
  602. }
  603. st.Close()
  604. os.RemoveAll(".aergo")
  605. smt = NewTrie(nil, Hasher, nil)
  606. err = smt.Commit()
  607. if err == nil {
  608. t.Fatal("Error not created if database not connected")
  609. }
  610. smt.db.liveCache = make(map[Hash][][]byte)
  611. smt.atomicUpdate = false
  612. _, _, _, _, _, err = smt.loadChildren(make([]byte, 32), smt.TrieHeight, 0, nil)
  613. if err == nil {
  614. t.Fatal("Error not created if database not connected")
  615. }
  616. err = smt.LoadCache(make([]byte, 32))
  617. if err == nil {
  618. t.Fatal("Error not created if database not connected")
  619. }
  620. }
  621. func TestTrieLoadCache(t *testing.T) {
  622. dbPath := path.Join(".aergo", "db")
  623. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  624. _ = os.MkdirAll(dbPath, 0711)
  625. }
  626. st := db.NewDB(db.LevelImpl, dbPath)
  627. smt := NewTrie(nil, Hasher, st)
  628. // Test size of cache
  629. smt.CacheHeightLimit = 0
  630. key0 := make([]byte, 32)
  631. key1 := make([]byte, 32)
  632. bitSet(key1, 255)
  633. values := getFreshData(2, 32)
  634. if _, err := smt.Update([][]byte{key0, key1}, values); err != nil {
  635. t.Fatal(err)
  636. }
  637. if len(smt.db.liveCache) != 66 {
  638. // the nodes are at the tip, so 64 + 2 = 66
  639. t.Fatal("cache size incorrect")
  640. }
  641. // Add data to empty trie
  642. keys := getFreshData(10, 32)
  643. values = getFreshData(10, 32)
  644. if _, err := smt.Update(keys, values); err != nil {
  645. t.Fatal(err)
  646. }
  647. if err := smt.Commit(); err != nil {
  648. t.Fatal(err)
  649. }
  650. // Simulate node restart by deleting and loading cache
  651. cacheSize := len(smt.db.liveCache)
  652. smt.db.liveCache = make(map[Hash][][]byte)
  653. err := smt.LoadCache(smt.Root)
  654. if err != nil {
  655. t.Fatal(err)
  656. }
  657. if cacheSize != len(smt.db.liveCache) {
  658. t.Fatal("Cache loading from db incorrect")
  659. }
  660. st.Close()
  661. os.RemoveAll(".aergo")
  662. }
  663. func TestHeight0LeafShortcut(t *testing.T) {
  664. keySize := 32
  665. smt := NewTrie(nil, Hasher, nil)
  666. // Add 2 sibling keys that will be stored at height 0
  667. key0 := make([]byte, keySize)
  668. key1 := make([]byte, keySize)
  669. bitSet(key1, keySize*8-1)
  670. keys := [][]byte{key0, key1}
  671. values := getFreshData(2, 32)
  672. if _, err := smt.Update(keys, values); err != nil {
  673. t.Fatal(err)
  674. }
  675. updatedNb := len(smt.db.updatedNodes)
  676. // Check all keys have been stored
  677. for i, key := range keys {
  678. value, _ := smt.Get(key)
  679. if !bytes.Equal(values[i], value) {
  680. t.Fatal("trie not updated")
  681. }
  682. }
  683. bitmap, ap, length, _, k, v, err := smt.MerkleProofCompressed(key1)
  684. if err != nil {
  685. t.Fatal(err)
  686. }
  687. if !bytes.Equal(key1, k) && !bytes.Equal(values[1], v) {
  688. t.Fatalf("merkle proof didnt return the correct key-value pair")
  689. }
  690. if length != smt.TrieHeight {
  691. t.Fatal("proof should have length equal to trie height for a leaf shortcut")
  692. }
  693. if !smt.VerifyInclusionC(bitmap, key1, values[1], ap, length) {
  694. t.Fatal("failed to verify inclusion proof")
  695. }
  696. // Delete one key and check that the remaining one moved up to the root of the tree
  697. newRoot, _ := smt.AtomicUpdate(keys[0:1], [][]byte{DefaultLeaf})
  698. // Nb of updated nodes remains same because the new shortcut root was already stored at height 0.
  699. if len(smt.db.updatedNodes) != updatedNb {
  700. fmt.Println(len(smt.db.updatedNodes), updatedNb)
  701. t.Fatal("number of cache nodes not correct after delete")
  702. }
  703. smt.atomicUpdate = false
  704. _, _, k, v, isShortcut, err := smt.loadChildren(newRoot, smt.TrieHeight, 0, nil)
  705. if err != nil {
  706. t.Fatal(err)
  707. }
  708. if !isShortcut || !bytes.Equal(k[:HashLength], key1) || !bytes.Equal(v[:HashLength], values[1]) {
  709. t.Fatal("leaf shortcut didn't move up to root")
  710. }
  711. _, _, length, _, k, v, _ = smt.MerkleProofCompressed(key1)
  712. if length != 0 {
  713. t.Fatal("proof should have length equal to trie height for a leaf shortcut")
  714. }
  715. if !bytes.Equal(key1, k) && !bytes.Equal(values[1], v) {
  716. t.Fatalf("merkle proof didnt return the correct key-value pair")
  717. }
  718. }
  719. func TestStash(t *testing.T) {
  720. dbPath := path.Join(".aergo", "db")
  721. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  722. _ = os.MkdirAll(dbPath, 0711)
  723. }
  724. st := db.NewDB(db.LevelImpl, dbPath)
  725. smt := NewTrie(nil, Hasher, st)
  726. // Add data to empty trie
  727. keys := getFreshData(20, 32)
  728. values := getFreshData(20, 32)
  729. root, _ := smt.Update(keys, values)
  730. cacheSize := len(smt.db.liveCache)
  731. if err := smt.Commit(); err != nil {
  732. t.Fatal(err)
  733. }
  734. if len(smt.pastTries) != 1 {
  735. t.Fatal("Past tries not updated after commit")
  736. }
  737. values = getFreshData(20, 32)
  738. if _, err := smt.Update(keys, values); err != nil {
  739. t.Fatal(err)
  740. }
  741. if err := smt.Stash(true); err != nil {
  742. t.Fatal(err)
  743. }
  744. if len(smt.pastTries) != 1 {
  745. t.Fatal("Past tries not updated after commit")
  746. }
  747. if !bytes.Equal(smt.Root, root) {
  748. t.Fatal("Trie not rolled back")
  749. }
  750. if len(smt.db.updatedNodes) != 0 {
  751. t.Fatal("Trie not rolled back")
  752. }
  753. if len(smt.db.liveCache) != cacheSize {
  754. t.Fatal("Trie not rolled back")
  755. }
  756. keys = getFreshData(20, 32)
  757. values = getFreshData(20, 32)
  758. if _, err := smt.AtomicUpdate(keys, values); err != nil {
  759. t.Fatal(err)
  760. }
  761. values = getFreshData(20, 32)
  762. if _, err := smt.AtomicUpdate(keys, values); err != nil {
  763. t.Fatal(err)
  764. }
  765. if len(smt.pastTries) != 3 {
  766. t.Fatal("Past tries not updated after commit")
  767. }
  768. if err := smt.Stash(true); err != nil {
  769. t.Fatal(err)
  770. }
  771. if !bytes.Equal(smt.Root, root) {
  772. t.Fatal("Trie not rolled back")
  773. }
  774. if len(smt.db.updatedNodes) != 0 {
  775. t.Fatal("Trie not rolled back")
  776. }
  777. if len(smt.db.liveCache) != cacheSize {
  778. t.Fatal("Trie not rolled back")
  779. }
  780. if len(smt.pastTries) != 1 {
  781. t.Fatal("Past tries not updated after commit")
  782. }
  783. st.Close()
  784. os.RemoveAll(".aergo")
  785. }
  786. func benchmark10MAccounts10Ktps(smt *Trie, b *testing.B) {
  787. //b.ReportAllocs()
  788. keys := getFreshData(100, 32)
  789. values := getFreshData(100, 32)
  790. if _, err := smt.Update(keys, values); err != nil {
  791. b.Fatal(err)
  792. }
  793. fmt.Println("\nLoading b.N x 1000 accounts")
  794. for i := 0; i < b.N; i++ {
  795. newkeys := getFreshData(1000, 32)
  796. newvalues := getFreshData(1000, 32)
  797. start := time.Now()
  798. if _, err := smt.Update(newkeys, newvalues); err != nil {
  799. b.Fatal(err)
  800. }
  801. end := time.Now()
  802. if err := smt.Commit(); err != nil {
  803. b.Fatal(err)
  804. }
  805. end2 := time.Now()
  806. for j, key := range newkeys {
  807. val, _ := smt.Get(key)
  808. if !bytes.Equal(val, newvalues[j]) {
  809. b.Fatal("new key not included")
  810. }
  811. }
  812. end3 := time.Now()
  813. elapsed := end.Sub(start)
  814. elapsed2 := end2.Sub(end)
  815. elapsed3 := end3.Sub(end2)
  816. var m runtime.MemStats
  817. runtime.ReadMemStats(&m)
  818. fmt.Println(i, " : update time : ", elapsed, "commit time : ", elapsed2,
  819. "\n1000 Get time : ", elapsed3,
  820. "\ndb read : ", smt.LoadDbCounter, " cache read : ", smt.LoadCacheCounter,
  821. "\ncache size : ", len(smt.db.liveCache),
  822. "\nRAM : ", m.Sys/1024/1024, " MiB")
  823. }
  824. }
  825. //go test -run=xxx -bench=. -benchmem -test.benchtime=20s
  826. func BenchmarkCacheHeightLimit233(b *testing.B) {
  827. dbPath := path.Join(".aergo", "db")
  828. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  829. _ = os.MkdirAll(dbPath, 0711)
  830. }
  831. st := db.NewDB(db.LevelImpl, dbPath)
  832. smt := NewTrie(nil, Hasher, st)
  833. smt.CacheHeightLimit = 233
  834. benchmark10MAccounts10Ktps(smt, b)
  835. st.Close()
  836. os.RemoveAll(".aergo")
  837. }
  838. func BenchmarkCacheHeightLimit238(b *testing.B) {
  839. dbPath := path.Join(".aergo", "db")
  840. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  841. _ = os.MkdirAll(dbPath, 0711)
  842. }
  843. st := db.NewDB(db.LevelImpl, dbPath)
  844. smt := NewTrie(nil, Hasher, st)
  845. smt.CacheHeightLimit = 238
  846. benchmark10MAccounts10Ktps(smt, b)
  847. st.Close()
  848. os.RemoveAll(".aergo")
  849. }
  850. func BenchmarkCacheHeightLimit245(b *testing.B) {
  851. dbPath := path.Join(".aergo", "db")
  852. if _, err := os.Stat(dbPath); os.IsNotExist(err) {
  853. _ = os.MkdirAll(dbPath, 0711)
  854. }
  855. st := db.NewDB(db.LevelImpl, dbPath)
  856. smt := NewTrie(nil, Hasher, st)
  857. smt.CacheHeightLimit = 245
  858. benchmark10MAccounts10Ktps(smt, b)
  859. st.Close()
  860. os.RemoveAll(".aergo")
  861. }
  862. func getFreshData(size, length int) [][]byte {
  863. var data [][]byte
  864. for i := 0; i < size; i++ {
  865. key := make([]byte, 32)
  866. _, err := rand.Read(key)
  867. if err != nil {
  868. panic(err)
  869. }
  870. data = append(data, Hasher(key)[:length])
  871. }
  872. sort.Sort(DataArray(data))
  873. return data
  874. }