You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

270 lines
7.0 KiB

  1. package statedb
import (
	"encoding/hex"
	"fmt"
	"io/ioutil"
	"math/big"
	"os"
	"testing"

	ethCrypto "github.com/ethereum/go-ethereum/crypto"
	"github.com/hermeznetwork/hermez-node/common"
	"github.com/iden3/go-iden3-crypto/babyjub"
	"github.com/iden3/go-merkletree/db"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
  15. func newAccount(t *testing.T, i int) *common.Account {
  16. var sk babyjub.PrivateKey
  17. _, err := hex.Decode(sk[:], []byte("0001020304050607080900010203040506070809000102030405060708090001"))
  18. require.Nil(t, err)
  19. pk := sk.Public()
  20. key, err := ethCrypto.GenerateKey()
  21. require.Nil(t, err)
  22. address := ethCrypto.PubkeyToAddress(key.PublicKey)
  23. return &common.Account{
  24. TokenID: common.TokenID(i),
  25. Nonce: uint64(i),
  26. Balance: big.NewInt(1000),
  27. PublicKey: pk,
  28. EthAddr: address,
  29. }
  30. }
  31. func TestStateDBWithoutMT(t *testing.T) {
  32. dir, err := ioutil.TempDir("", "tmpdb")
  33. require.Nil(t, err)
  34. sdb, err := NewStateDB(dir, false, 0)
  35. assert.Nil(t, err)
  36. // create test accounts
  37. var accounts []*common.Account
  38. for i := 0; i < 100; i++ {
  39. accounts = append(accounts, newAccount(t, i))
  40. }
  41. // get non-existing account, expecting an error
  42. _, err = sdb.GetAccount(common.Idx(1))
  43. assert.NotNil(t, err)
  44. assert.Equal(t, db.ErrNotFound, err)
  45. // add test accounts
  46. for i := 0; i < len(accounts); i++ {
  47. err = sdb.CreateAccount(common.Idx(i), accounts[i])
  48. assert.Nil(t, err)
  49. }
  50. for i := 0; i < len(accounts); i++ {
  51. accGetted, err := sdb.GetAccount(common.Idx(i))
  52. assert.Nil(t, err)
  53. assert.Equal(t, accounts[i], accGetted)
  54. }
  55. // try already existing idx and get error
  56. _, err = sdb.GetAccount(common.Idx(1)) // check that exist
  57. assert.Nil(t, err)
  58. err = sdb.CreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
  59. assert.NotNil(t, err)
  60. assert.Equal(t, ErrAccountAlreadyExists, err)
  61. // update accounts
  62. for i := 0; i < len(accounts); i++ {
  63. accounts[i].Nonce = accounts[i].Nonce + 1
  64. err = sdb.UpdateAccount(common.Idx(i), accounts[i])
  65. assert.Nil(t, err)
  66. }
  67. // check that can not call MerkleTree methods of the StateDB
  68. _, err = sdb.MTCreateAccount(common.Idx(1), accounts[1])
  69. assert.NotNil(t, err)
  70. assert.Equal(t, ErrStateDBWithoutMT, err)
  71. _, err = sdb.MTUpdateAccount(common.Idx(1), accounts[1])
  72. assert.NotNil(t, err)
  73. assert.Equal(t, ErrStateDBWithoutMT, err)
  74. _, err = sdb.MTGetProof(common.Idx(1))
  75. assert.NotNil(t, err)
  76. assert.Equal(t, ErrStateDBWithoutMT, err)
  77. }
  78. func TestStateDBWithMT(t *testing.T) {
  79. dir, err := ioutil.TempDir("", "tmpdb")
  80. require.Nil(t, err)
  81. sdb, err := NewStateDB(dir, true, 32)
  82. assert.Nil(t, err)
  83. // create test accounts
  84. var accounts []*common.Account
  85. for i := 0; i < 20; i++ {
  86. accounts = append(accounts, newAccount(t, i))
  87. }
  88. // get non-existing account, expecting an error
  89. _, err = sdb.GetAccount(common.Idx(1))
  90. assert.NotNil(t, err)
  91. assert.Equal(t, db.ErrNotFound, err)
  92. // add test accounts
  93. for i := 0; i < len(accounts); i++ {
  94. _, err = sdb.MTCreateAccount(common.Idx(i), accounts[i])
  95. assert.Nil(t, err)
  96. }
  97. for i := 0; i < len(accounts); i++ {
  98. accGetted, err := sdb.GetAccount(common.Idx(i))
  99. assert.Nil(t, err)
  100. assert.Equal(t, accounts[i], accGetted)
  101. }
  102. // try already existing idx and get error
  103. _, err = sdb.GetAccount(common.Idx(1)) // check that exist
  104. assert.Nil(t, err)
  105. _, err = sdb.MTCreateAccount(common.Idx(1), accounts[1]) // check that can not be created twice
  106. assert.NotNil(t, err)
  107. assert.Equal(t, ErrAccountAlreadyExists, err)
  108. _, err = sdb.MTGetProof(common.Idx(1))
  109. assert.Nil(t, err)
  110. // update accounts
  111. for i := 0; i < len(accounts); i++ {
  112. accounts[i].Nonce = accounts[i].Nonce + 1
  113. _, err = sdb.MTUpdateAccount(common.Idx(i), accounts[i])
  114. assert.Nil(t, err)
  115. }
  116. a, err := sdb.GetAccount(common.Idx(1)) // check that account value has been updated
  117. assert.Nil(t, err)
  118. assert.Equal(t, accounts[1].Nonce, a.Nonce)
  119. }
  120. func TestCheckpoints(t *testing.T) {
  121. dir, err := ioutil.TempDir("", "sdb")
  122. require.Nil(t, err)
  123. sdb, err := NewStateDB(dir, true, 32)
  124. assert.Nil(t, err)
  125. // create test accounts
  126. var accounts []*common.Account
  127. for i := 0; i < 10; i++ {
  128. accounts = append(accounts, newAccount(t, i))
  129. }
  130. // add test accounts
  131. for i := 0; i < len(accounts); i++ {
  132. _, err = sdb.MTCreateAccount(common.Idx(i), accounts[i])
  133. assert.Nil(t, err)
  134. }
  135. // do checkpoints and check that currentBatch is correct
  136. err = sdb.MakeCheckpoint()
  137. assert.Nil(t, err)
  138. cb, err := sdb.GetCurrentBatch()
  139. assert.Nil(t, err)
  140. assert.Equal(t, uint64(1), cb)
  141. for i := 1; i < 10; i++ {
  142. err = sdb.MakeCheckpoint()
  143. assert.Nil(t, err)
  144. cb, err = sdb.GetCurrentBatch()
  145. assert.Nil(t, err)
  146. assert.Equal(t, uint64(i+1), cb)
  147. }
  148. // printCheckpoints(t, sdb.path)
  149. // reset checkpoint
  150. err = sdb.Reset(3)
  151. assert.Nil(t, err)
  152. // check that reset can be repeated (as there exist the 'current' and
  153. // 'BatchNum3', from where the 'current' is a copy)
  154. err = sdb.Reset(3)
  155. require.Nil(t, err)
  156. // check that currentBatch is as expected after Reset
  157. cb, err = sdb.GetCurrentBatch()
  158. assert.Nil(t, err)
  159. assert.Equal(t, uint64(3), cb)
  160. // advance one checkpoint and check that currentBatch is fine
  161. err = sdb.MakeCheckpoint()
  162. assert.Nil(t, err)
  163. cb, err = sdb.GetCurrentBatch()
  164. assert.Nil(t, err)
  165. assert.Equal(t, uint64(4), cb)
  166. err = sdb.DeleteCheckpoint(uint64(9))
  167. assert.Nil(t, err)
  168. err = sdb.DeleteCheckpoint(uint64(10))
  169. assert.Nil(t, err)
  170. err = sdb.DeleteCheckpoint(uint64(9)) // does not exist, should return err
  171. assert.NotNil(t, err)
  172. err = sdb.DeleteCheckpoint(uint64(11)) // does not exist, should return err
  173. assert.NotNil(t, err)
  174. // Create a LocalStateDB from the initial StateDB
  175. dirLocal, err := ioutil.TempDir("", "ldb")
  176. require.Nil(t, err)
  177. ldb, err := NewLocalStateDB(dirLocal, sdb, true, 32)
  178. assert.Nil(t, err)
  179. // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
  180. err = ldb.Reset(4, true)
  181. assert.Nil(t, err)
  182. // check that currentBatch is 4 after the Reset
  183. cb, err = ldb.GetCurrentBatch()
  184. assert.Nil(t, err)
  185. assert.Equal(t, uint64(4), cb)
  186. // advance one checkpoint in ldb
  187. err = ldb.MakeCheckpoint()
  188. assert.Nil(t, err)
  189. cb, err = ldb.GetCurrentBatch()
  190. assert.Nil(t, err)
  191. assert.Equal(t, uint64(5), cb)
  192. // Create a 2nd LocalStateDB from the initial StateDB
  193. dirLocal2, err := ioutil.TempDir("", "ldb2")
  194. require.Nil(t, err)
  195. ldb2, err := NewLocalStateDB(dirLocal2, sdb, true, 32)
  196. assert.Nil(t, err)
  197. // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
  198. err = ldb2.Reset(4, true)
  199. assert.Nil(t, err)
  200. // check that currentBatch is 4 after the Reset
  201. cb, err = ldb2.GetCurrentBatch()
  202. assert.Nil(t, err)
  203. assert.Equal(t, uint64(4), cb)
  204. // advance one checkpoint in ldb2
  205. err = ldb2.MakeCheckpoint()
  206. assert.Nil(t, err)
  207. cb, err = ldb2.GetCurrentBatch()
  208. assert.Nil(t, err)
  209. assert.Equal(t, uint64(5), cb)
  210. debug := false
  211. if debug {
  212. printCheckpoints(t, sdb.path)
  213. printCheckpoints(t, ldb.path)
  214. printCheckpoints(t, ldb2.path)
  215. }
  216. }
  217. func printCheckpoints(t *testing.T, path string) {
  218. files, err := ioutil.ReadDir(path)
  219. assert.Nil(t, err)
  220. fmt.Println(path)
  221. for _, f := range files {
  222. fmt.Println(" " + f.Name())
  223. }
  224. }