You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

261 lines
6.8 KiB

  1. package statedb
  import (
  	"encoding/hex"
  	"fmt"
  	"io/ioutil"
  	"math/big"
  	"os"
  	"testing"

  	ethCrypto "github.com/ethereum/go-ethereum/crypto"
  	"github.com/hermeznetwork/hermez-node/common"
  	"github.com/iden3/go-iden3-crypto/babyjub"
  	"github.com/iden3/go-merkletree/db"
  	"github.com/stretchr/testify/assert"
  	"github.com/stretchr/testify/require"
  )
  15. func newAccount(t *testing.T, i int) *common.Account {
  16. var sk babyjub.PrivateKey
  17. _, err := hex.Decode(sk[:], []byte("0001020304050607080900010203040506070809000102030405060708090001"))
  18. require.Nil(t, err)
  19. pk := sk.Public()
  20. key, err := ethCrypto.GenerateKey()
  21. require.Nil(t, err)
  22. address := ethCrypto.PubkeyToAddress(key.PublicKey)
  23. return &common.Account{
  24. TokenID: common.TokenID(i),
  25. Nonce: common.Nonce(i),
  26. Balance: big.NewInt(1000),
  27. PublicKey: pk,
  28. EthAddr: address,
  29. }
  30. }
  31. func TestStateDBWithoutMT(t *testing.T) {
  32. dir, err := ioutil.TempDir("", "tmpdb")
  33. require.Nil(t, err)
  34. sdb, err := NewStateDB(dir, false, 0)
  35. assert.Nil(t, err)
  36. // create test accounts
  37. var accounts []*common.Account
  38. for i := 0; i < 100; i++ {
  39. accounts = append(accounts, newAccount(t, i))
  40. }
  41. // get non-existing account, expecting an error
  42. _, err = sdb.GetAccount(common.Idx(1))
  43. assert.NotNil(t, err)
  44. assert.Equal(t, db.ErrNotFound, err)
  45. // add test accounts
  46. for i := 0; i < len(accounts); i++ {
  47. _, err = sdb.CreateAccount(common.Idx(i), accounts[i])
  48. assert.Nil(t, err)
  49. }
  50. for i := 0; i < len(accounts); i++ {
  51. accGetted, err := sdb.GetAccount(common.Idx(i))
  52. assert.Nil(t, err)
  53. assert.Equal(t, accounts[i], accGetted)
  54. }
  55. // try already existing idx and get error
  56. _, err = sdb.GetAccount(common.Idx(1)) // check that exist
  57. assert.Nil(t, err)
  58. _, err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
  59. assert.NotNil(t, err)
  60. assert.Equal(t, ErrAccountAlreadyExists, err)
  61. // update accounts
  62. for i := 0; i < len(accounts); i++ {
  63. accounts[i].Nonce = accounts[i].Nonce + 1
  64. _, err = sdb.UpdateAccount(common.Idx(i), accounts[i])
  65. assert.Nil(t, err)
  66. }
  67. _, err = sdb.MTGetProof(common.Idx(1))
  68. assert.NotNil(t, err)
  69. assert.Equal(t, ErrStateDBWithoutMT, err)
  70. }
  71. func TestStateDBWithMT(t *testing.T) {
  72. dir, err := ioutil.TempDir("", "tmpdb")
  73. require.Nil(t, err)
  74. sdb, err := NewStateDB(dir, true, 32)
  75. assert.Nil(t, err)
  76. // create test accounts
  77. var accounts []*common.Account
  78. for i := 0; i < 20; i++ {
  79. accounts = append(accounts, newAccount(t, i))
  80. }
  81. // get non-existing account, expecting an error
  82. _, err = sdb.GetAccount(common.Idx(1))
  83. assert.NotNil(t, err)
  84. assert.Equal(t, db.ErrNotFound, err)
  85. // add test accounts
  86. for i := 0; i < len(accounts); i++ {
  87. _, err = sdb.CreateAccount(common.Idx(i), accounts[i])
  88. assert.Nil(t, err)
  89. }
  90. for i := 0; i < len(accounts); i++ {
  91. accGetted, err := sdb.GetAccount(common.Idx(i))
  92. assert.Nil(t, err)
  93. assert.Equal(t, accounts[i], accGetted)
  94. }
  95. // try already existing idx and get error
  96. _, err = sdb.GetAccount(common.Idx(1)) // check that exist
  97. assert.Nil(t, err)
  98. _, err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
  99. assert.NotNil(t, err)
  100. assert.Equal(t, ErrAccountAlreadyExists, err)
  101. _, err = sdb.MTGetProof(common.Idx(1))
  102. assert.Nil(t, err)
  103. // update accounts
  104. for i := 0; i < len(accounts); i++ {
  105. accounts[i].Nonce = accounts[i].Nonce + 1
  106. _, err = sdb.UpdateAccount(common.Idx(i), accounts[i])
  107. assert.Nil(t, err)
  108. }
  109. a, err := sdb.GetAccount(common.Idx(1)) // check that account value has been updated
  110. assert.Nil(t, err)
  111. assert.Equal(t, accounts[1].Nonce, a.Nonce)
  112. }
  113. func TestCheckpoints(t *testing.T) {
  114. dir, err := ioutil.TempDir("", "sdb")
  115. require.Nil(t, err)
  116. sdb, err := NewStateDB(dir, true, 32)
  117. assert.Nil(t, err)
  118. // create test accounts
  119. var accounts []*common.Account
  120. for i := 0; i < 10; i++ {
  121. accounts = append(accounts, newAccount(t, i))
  122. }
  123. // add test accounts
  124. for i := 0; i < len(accounts); i++ {
  125. _, err = sdb.CreateAccount(common.Idx(i), accounts[i])
  126. assert.Nil(t, err)
  127. }
  128. // do checkpoints and check that currentBatch is correct
  129. err = sdb.MakeCheckpoint()
  130. assert.Nil(t, err)
  131. cb, err := sdb.GetCurrentBatch()
  132. assert.Nil(t, err)
  133. assert.Equal(t, common.BatchNum(1), cb)
  134. for i := 1; i < 10; i++ {
  135. err = sdb.MakeCheckpoint()
  136. assert.Nil(t, err)
  137. cb, err = sdb.GetCurrentBatch()
  138. assert.Nil(t, err)
  139. assert.Equal(t, common.BatchNum(i+1), cb)
  140. }
  141. // printCheckpoints(t, sdb.path)
  142. // reset checkpoint
  143. err = sdb.Reset(3)
  144. assert.Nil(t, err)
  145. // check that reset can be repeated (as there exist the 'current' and
  146. // 'BatchNum3', from where the 'current' is a copy)
  147. err = sdb.Reset(3)
  148. require.Nil(t, err)
  149. // check that currentBatch is as expected after Reset
  150. cb, err = sdb.GetCurrentBatch()
  151. assert.Nil(t, err)
  152. assert.Equal(t, common.BatchNum(3), cb)
  153. // advance one checkpoint and check that currentBatch is fine
  154. err = sdb.MakeCheckpoint()
  155. assert.Nil(t, err)
  156. cb, err = sdb.GetCurrentBatch()
  157. assert.Nil(t, err)
  158. assert.Equal(t, common.BatchNum(4), cb)
  159. err = sdb.DeleteCheckpoint(common.BatchNum(9))
  160. assert.Nil(t, err)
  161. err = sdb.DeleteCheckpoint(common.BatchNum(10))
  162. assert.Nil(t, err)
  163. err = sdb.DeleteCheckpoint(common.BatchNum(9)) // does not exist, should return err
  164. assert.NotNil(t, err)
  165. err = sdb.DeleteCheckpoint(common.BatchNum(11)) // does not exist, should return err
  166. assert.NotNil(t, err)
  167. // Create a LocalStateDB from the initial StateDB
  168. dirLocal, err := ioutil.TempDir("", "ldb")
  169. require.Nil(t, err)
  170. ldb, err := NewLocalStateDB(dirLocal, sdb, true, 32)
  171. assert.Nil(t, err)
  172. // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
  173. err = ldb.Reset(4, true)
  174. assert.Nil(t, err)
  175. // check that currentBatch is 4 after the Reset
  176. cb, err = ldb.GetCurrentBatch()
  177. assert.Nil(t, err)
  178. assert.Equal(t, common.BatchNum(4), cb)
  179. // advance one checkpoint in ldb
  180. err = ldb.MakeCheckpoint()
  181. assert.Nil(t, err)
  182. cb, err = ldb.GetCurrentBatch()
  183. assert.Nil(t, err)
  184. assert.Equal(t, common.BatchNum(5), cb)
  185. // Create a 2nd LocalStateDB from the initial StateDB
  186. dirLocal2, err := ioutil.TempDir("", "ldb2")
  187. require.Nil(t, err)
  188. ldb2, err := NewLocalStateDB(dirLocal2, sdb, true, 32)
  189. assert.Nil(t, err)
  190. // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
  191. err = ldb2.Reset(4, true)
  192. assert.Nil(t, err)
  193. // check that currentBatch is 4 after the Reset
  194. cb, err = ldb2.GetCurrentBatch()
  195. assert.Nil(t, err)
  196. assert.Equal(t, common.BatchNum(4), cb)
  197. // advance one checkpoint in ldb2
  198. err = ldb2.MakeCheckpoint()
  199. assert.Nil(t, err)
  200. cb, err = ldb2.GetCurrentBatch()
  201. assert.Nil(t, err)
  202. assert.Equal(t, common.BatchNum(5), cb)
  203. debug := false
  204. if debug {
  205. printCheckpoints(t, sdb.path)
  206. printCheckpoints(t, ldb.path)
  207. printCheckpoints(t, ldb2.path)
  208. }
  209. }
  210. func printCheckpoints(t *testing.T, path string) {
  211. files, err := ioutil.ReadDir(path)
  212. assert.Nil(t, err)
  213. fmt.Println(path)
  214. for _, f := range files {
  215. fmt.Println(" " + f.Name())
  216. }
  217. }