You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

471 lines
13 KiB

  1. package statedb
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "io/ioutil"
  6. "math/big"
  7. "os"
  8. "strings"
  9. "testing"
  10. ethCommon "github.com/ethereum/go-ethereum/common"
  11. ethCrypto "github.com/ethereum/go-ethereum/crypto"
  12. "github.com/hermeznetwork/hermez-node/common"
  13. "github.com/hermeznetwork/hermez-node/log"
  14. "github.com/hermeznetwork/tracerr"
  15. "github.com/iden3/go-iden3-crypto/babyjub"
  16. "github.com/iden3/go-merkletree/db"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/stretchr/testify/require"
  19. )
  20. func newAccount(t *testing.T, i int) *common.Account {
  21. var sk babyjub.PrivateKey
  22. _, err := hex.Decode(sk[:], []byte("0001020304050607080900010203040506070809000102030405060708090001"))
  23. require.NoError(t, err)
  24. pk := sk.Public()
  25. key, err := ethCrypto.GenerateKey()
  26. require.NoError(t, err)
  27. address := ethCrypto.PubkeyToAddress(key.PublicKey)
  28. return &common.Account{
  29. Idx: common.Idx(256 + i),
  30. TokenID: common.TokenID(i),
  31. Nonce: common.Nonce(i),
  32. Balance: big.NewInt(1000),
  33. PublicKey: pk.Compress(),
  34. EthAddr: address,
  35. }
  36. }
  37. func TestNewStateDBIntermediateState(t *testing.T) {
  38. dir, err := ioutil.TempDir("", "tmpdb")
  39. require.NoError(t, err)
  40. defer assert.NoError(t, os.RemoveAll(dir))
  41. chainID := uint16(0)
  42. sdb, err := NewStateDB(dir, TypeTxSelector, 0, chainID)
  43. assert.NoError(t, err)
  44. // test values
  45. k0 := []byte("testkey0")
  46. k1 := []byte("testkey1")
  47. v0 := []byte("testvalue0")
  48. v1 := []byte("testvalue1")
  49. // store some data
  50. tx, err := sdb.db.NewTx()
  51. assert.NoError(t, err)
  52. err = tx.Put(k0, v0)
  53. assert.NoError(t, err)
  54. err = tx.Commit()
  55. assert.NoError(t, err)
  56. v, err := sdb.db.Get(k0)
  57. assert.NoError(t, err)
  58. assert.Equal(t, v0, v)
  59. // call NewStateDB which should get the db at the last checkpoint state
  60. // executing a Reset (discarding the last 'testkey0'&'testvalue0' data)
  61. sdb, err = NewStateDB(dir, TypeTxSelector, 0, chainID)
  62. assert.NoError(t, err)
  63. v, err = sdb.db.Get(k0)
  64. assert.NotNil(t, err)
  65. assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err))
  66. assert.Nil(t, v)
  67. // store the same data from the beginning that has ben lost since last NewStateDB
  68. tx, err = sdb.db.NewTx()
  69. assert.NoError(t, err)
  70. err = tx.Put(k0, v0)
  71. assert.NoError(t, err)
  72. err = tx.Commit()
  73. assert.NoError(t, err)
  74. v, err = sdb.db.Get(k0)
  75. assert.NoError(t, err)
  76. assert.Equal(t, v0, v)
  77. // make checkpoints with the current state
  78. bn, err := sdb.GetCurrentBatch()
  79. assert.NoError(t, err)
  80. assert.Equal(t, common.BatchNum(0), bn)
  81. err = sdb.MakeCheckpoint()
  82. assert.NoError(t, err)
  83. bn, err = sdb.GetCurrentBatch()
  84. assert.NoError(t, err)
  85. assert.Equal(t, common.BatchNum(1), bn)
  86. // write more data
  87. tx, err = sdb.db.NewTx()
  88. assert.NoError(t, err)
  89. err = tx.Put(k1, v1)
  90. assert.NoError(t, err)
  91. err = tx.Commit()
  92. assert.NoError(t, err)
  93. v, err = sdb.db.Get(k1)
  94. assert.NoError(t, err)
  95. assert.Equal(t, v1, v)
  96. // call NewStateDB which should get the db at the last checkpoint state
  97. // executing a Reset (discarding the last 'testkey1'&'testvalue1' data)
  98. sdb, err = NewStateDB(dir, TypeTxSelector, 0, chainID)
  99. assert.NoError(t, err)
  100. v, err = sdb.db.Get(k0)
  101. assert.NoError(t, err)
  102. assert.Equal(t, v0, v)
  103. v, err = sdb.db.Get(k1)
  104. assert.NotNil(t, err)
  105. assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err))
  106. assert.Nil(t, v)
  107. }
  108. func TestStateDBWithoutMT(t *testing.T) {
  109. dir, err := ioutil.TempDir("", "tmpdb")
  110. require.NoError(t, err)
  111. defer assert.NoError(t, os.RemoveAll(dir))
  112. chainID := uint16(0)
  113. sdb, err := NewStateDB(dir, TypeTxSelector, 0, chainID)
  114. assert.NoError(t, err)
  115. // create test accounts
  116. var accounts []*common.Account
  117. for i := 0; i < 4; i++ {
  118. accounts = append(accounts, newAccount(t, i))
  119. }
  120. // get non-existing account, expecting an error
  121. unexistingAccount := common.Idx(1)
  122. _, err = sdb.GetAccount(unexistingAccount)
  123. assert.NotNil(t, err)
  124. assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err))
  125. // add test accounts
  126. for i := 0; i < len(accounts); i++ {
  127. _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i])
  128. assert.NoError(t, err)
  129. }
  130. for i := 0; i < len(accounts); i++ {
  131. existingAccount := accounts[i].Idx
  132. accGetted, err := sdb.GetAccount(existingAccount)
  133. assert.NoError(t, err)
  134. assert.Equal(t, accounts[i], accGetted)
  135. }
  136. // try already existing idx and get error
  137. existingAccount := common.Idx(256)
  138. _, err = sdb.GetAccount(existingAccount) // check that exist
  139. assert.NoError(t, err)
  140. _, err = sdb.CreateAccount(common.Idx(256), accounts[1]) // check that can not be created twice
  141. assert.NotNil(t, err)
  142. assert.Equal(t, ErrAccountAlreadyExists, tracerr.Unwrap(err))
  143. // update accounts
  144. for i := 0; i < len(accounts); i++ {
  145. accounts[i].Nonce = accounts[i].Nonce + 1
  146. existingAccount = common.Idx(i)
  147. _, err = sdb.UpdateAccount(existingAccount, accounts[i])
  148. assert.NoError(t, err)
  149. }
  150. _, err = sdb.MTGetProof(common.Idx(1))
  151. assert.NotNil(t, err)
  152. assert.Equal(t, ErrStateDBWithoutMT, tracerr.Unwrap(err))
  153. }
  154. func TestStateDBWithMT(t *testing.T) {
  155. dir, err := ioutil.TempDir("", "tmpdb")
  156. require.NoError(t, err)
  157. defer assert.NoError(t, os.RemoveAll(dir))
  158. chainID := uint16(0)
  159. sdb, err := NewStateDB(dir, TypeSynchronizer, 32, chainID)
  160. assert.NoError(t, err)
  161. // create test accounts
  162. var accounts []*common.Account
  163. for i := 0; i < 20; i++ {
  164. accounts = append(accounts, newAccount(t, i))
  165. }
  166. // get non-existing account, expecting an error
  167. _, err = sdb.GetAccount(common.Idx(1))
  168. assert.NotNil(t, err)
  169. assert.Equal(t, db.ErrNotFound, tracerr.Unwrap(err))
  170. // add test accounts
  171. for i := 0; i < len(accounts); i++ {
  172. _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i])
  173. assert.NoError(t, err)
  174. }
  175. for i := 0; i < len(accounts); i++ {
  176. accGetted, err := sdb.GetAccount(accounts[i].Idx)
  177. assert.NoError(t, err)
  178. assert.Equal(t, accounts[i], accGetted)
  179. }
  180. // try already existing idx and get error
  181. _, err = sdb.GetAccount(common.Idx(256)) // check that exist
  182. assert.NoError(t, err)
  183. _, err = sdb.CreateAccount(common.Idx(256), accounts[1]) // check that can not be created twice
  184. assert.NotNil(t, err)
  185. assert.Equal(t, ErrAccountAlreadyExists, tracerr.Unwrap(err))
  186. _, err = sdb.MTGetProof(common.Idx(256))
  187. assert.NoError(t, err)
  188. // update accounts
  189. for i := 0; i < len(accounts); i++ {
  190. accounts[i].Nonce = accounts[i].Nonce + 1
  191. _, err = sdb.UpdateAccount(accounts[i].Idx, accounts[i])
  192. assert.NoError(t, err)
  193. }
  194. a, err := sdb.GetAccount(common.Idx(256)) // check that account value has been updated
  195. assert.NoError(t, err)
  196. assert.Equal(t, accounts[0].Nonce, a.Nonce)
  197. }
  198. func TestCheckpoints(t *testing.T) {
  199. dir, err := ioutil.TempDir("", "sdb")
  200. require.NoError(t, err)
  201. defer assert.NoError(t, os.RemoveAll(dir))
  202. chainID := uint16(0)
  203. sdb, err := NewStateDB(dir, TypeSynchronizer, 32, chainID)
  204. assert.NoError(t, err)
  205. // create test accounts
  206. var accounts []*common.Account
  207. for i := 0; i < 10; i++ {
  208. accounts = append(accounts, newAccount(t, i))
  209. }
  210. // add test accounts
  211. for i := 0; i < len(accounts); i++ {
  212. _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i])
  213. assert.NoError(t, err)
  214. }
  215. // do checkpoints and check that currentBatch is correct
  216. err = sdb.MakeCheckpoint()
  217. assert.NoError(t, err)
  218. cb, err := sdb.GetCurrentBatch()
  219. assert.NoError(t, err)
  220. assert.Equal(t, common.BatchNum(1), cb)
  221. for i := 1; i < 10; i++ {
  222. err = sdb.MakeCheckpoint()
  223. assert.NoError(t, err)
  224. cb, err = sdb.GetCurrentBatch()
  225. assert.NoError(t, err)
  226. assert.Equal(t, common.BatchNum(i+1), cb)
  227. }
  228. // printCheckpoints(t, sdb.path)
  229. // reset checkpoint
  230. err = sdb.Reset(3)
  231. assert.NoError(t, err)
  232. // check that reset can be repeated (as there exist the 'current' and
  233. // 'BatchNum3', from where the 'current' is a copy)
  234. err = sdb.Reset(3)
  235. require.NoError(t, err)
  236. // check that currentBatch is as expected after Reset
  237. cb, err = sdb.GetCurrentBatch()
  238. assert.NoError(t, err)
  239. assert.Equal(t, common.BatchNum(3), cb)
  240. // advance one checkpoint and check that currentBatch is fine
  241. err = sdb.MakeCheckpoint()
  242. assert.NoError(t, err)
  243. cb, err = sdb.GetCurrentBatch()
  244. assert.NoError(t, err)
  245. assert.Equal(t, common.BatchNum(4), cb)
  246. err = sdb.DeleteCheckpoint(common.BatchNum(9))
  247. assert.NoError(t, err)
  248. err = sdb.DeleteCheckpoint(common.BatchNum(10))
  249. assert.NoError(t, err)
  250. err = sdb.DeleteCheckpoint(common.BatchNum(9)) // does not exist, should return err
  251. assert.NotNil(t, err)
  252. err = sdb.DeleteCheckpoint(common.BatchNum(11)) // does not exist, should return err
  253. assert.NotNil(t, err)
  254. // Create a LocalStateDB from the initial StateDB
  255. dirLocal, err := ioutil.TempDir("", "ldb")
  256. require.NoError(t, err)
  257. defer assert.NoError(t, os.RemoveAll(dirLocal))
  258. ldb, err := NewLocalStateDB(dirLocal, sdb, TypeBatchBuilder, 32)
  259. assert.NoError(t, err)
  260. // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
  261. err = ldb.Reset(4, true)
  262. assert.NoError(t, err)
  263. // check that currentBatch is 4 after the Reset
  264. cb, err = ldb.GetCurrentBatch()
  265. assert.NoError(t, err)
  266. assert.Equal(t, common.BatchNum(4), cb)
  267. // advance one checkpoint in ldb
  268. err = ldb.MakeCheckpoint()
  269. assert.NoError(t, err)
  270. cb, err = ldb.GetCurrentBatch()
  271. assert.NoError(t, err)
  272. assert.Equal(t, common.BatchNum(5), cb)
  273. // Create a 2nd LocalStateDB from the initial StateDB
  274. dirLocal2, err := ioutil.TempDir("", "ldb2")
  275. require.NoError(t, err)
  276. defer assert.NoError(t, os.RemoveAll(dirLocal2))
  277. ldb2, err := NewLocalStateDB(dirLocal2, sdb, TypeBatchBuilder, 32)
  278. assert.NoError(t, err)
  279. // get checkpoint 4 from sdb (StateDB) to ldb (LocalStateDB)
  280. err = ldb2.Reset(4, true)
  281. assert.NoError(t, err)
  282. // check that currentBatch is 4 after the Reset
  283. cb, err = ldb2.GetCurrentBatch()
  284. assert.NoError(t, err)
  285. assert.Equal(t, common.BatchNum(4), cb)
  286. // advance one checkpoint in ldb2
  287. err = ldb2.MakeCheckpoint()
  288. assert.NoError(t, err)
  289. cb, err = ldb2.GetCurrentBatch()
  290. assert.NoError(t, err)
  291. assert.Equal(t, common.BatchNum(5), cb)
  292. debug := false
  293. if debug {
  294. printCheckpoints(t, sdb.path)
  295. printCheckpoints(t, ldb.path)
  296. printCheckpoints(t, ldb2.path)
  297. }
  298. }
  299. func TestStateDBGetAccounts(t *testing.T) {
  300. dir, err := ioutil.TempDir("", "tmpdb")
  301. require.NoError(t, err)
  302. chainID := uint16(0)
  303. sdb, err := NewStateDB(dir, TypeTxSelector, 0, chainID)
  304. assert.NoError(t, err)
  305. // create test accounts
  306. var accounts []common.Account
  307. for i := 0; i < 16; i++ {
  308. account := newAccount(t, i)
  309. accounts = append(accounts, *account)
  310. }
  311. // add test accounts
  312. for i := range accounts {
  313. _, err = sdb.CreateAccount(accounts[i].Idx, &accounts[i])
  314. require.NoError(t, err)
  315. }
  316. dbAccounts, err := sdb.GetAccounts()
  317. require.NoError(t, err)
  318. assert.Equal(t, accounts, dbAccounts)
  319. }
  320. func printCheckpoints(t *testing.T, path string) {
  321. files, err := ioutil.ReadDir(path)
  322. assert.NoError(t, err)
  323. fmt.Println(path)
  324. for _, f := range files {
  325. fmt.Println(" " + f.Name())
  326. }
  327. }
  328. func bigFromStr(h string, u int) *big.Int {
  329. if u == 16 {
  330. h = strings.TrimPrefix(h, "0x")
  331. }
  332. b, ok := new(big.Int).SetString(h, u)
  333. if !ok {
  334. panic("bigFromStr err")
  335. }
  336. return b
  337. }
  338. func TestCheckAccountsTreeTestVectors(t *testing.T) {
  339. dir, err := ioutil.TempDir("", "tmpdb")
  340. require.NoError(t, err)
  341. defer assert.NoError(t, os.RemoveAll(dir))
  342. chainID := uint16(0)
  343. sdb, err := NewStateDB(dir, TypeSynchronizer, 32, chainID)
  344. require.NoError(t, err)
  345. ay0 := new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(253), nil), big.NewInt(1))
  346. // test value from js version (compatibility-canary)
  347. assert.Equal(t, "1fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", (hex.EncodeToString(ay0.Bytes())))
  348. bjjPoint0Comp := babyjub.PackSignY(true, ay0)
  349. bjj0 := babyjub.PublicKeyComp(bjjPoint0Comp)
  350. ay1 := bigFromStr("00", 16)
  351. bjjPoint1Comp := babyjub.PackSignY(false, ay1)
  352. bjj1 := babyjub.PublicKeyComp(bjjPoint1Comp)
  353. ay2 := bigFromStr("21b0a1688b37f77b1d1d5539ec3b826db5ac78b2513f574a04c50a7d4f8246d7", 16)
  354. bjjPoint2Comp := babyjub.PackSignY(false, ay2)
  355. bjj2 := babyjub.PublicKeyComp(bjjPoint2Comp)
  356. ay3 := bigFromStr("0x10", 16) // 0x10=16
  357. bjjPoint3Comp := babyjub.PackSignY(false, ay3)
  358. require.NoError(t, err)
  359. bjj3 := babyjub.PublicKeyComp(bjjPoint3Comp)
  360. accounts := []*common.Account{
  361. {
  362. Idx: 1,
  363. TokenID: 0xFFFFFFFF,
  364. PublicKey: bjj0,
  365. EthAddr: ethCommon.HexToAddress("0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"),
  366. Nonce: common.Nonce(0xFFFFFFFFFF),
  367. Balance: bigFromStr("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16),
  368. },
  369. {
  370. Idx: 100,
  371. TokenID: 0,
  372. PublicKey: bjj1,
  373. EthAddr: ethCommon.HexToAddress("0x00"),
  374. Nonce: common.Nonce(0),
  375. Balance: bigFromStr("0", 10),
  376. },
  377. {
  378. Idx: 0xFFFFFFFFFFFF,
  379. TokenID: 3,
  380. PublicKey: bjj2,
  381. EthAddr: ethCommon.HexToAddress("0xA3C88ac39A76789437AED31B9608da72e1bbfBF9"),
  382. Nonce: common.Nonce(129),
  383. Balance: bigFromStr("42000000000000000000", 10),
  384. },
  385. {
  386. Idx: 10000,
  387. TokenID: 1000,
  388. PublicKey: bjj3,
  389. EthAddr: ethCommon.HexToAddress("0x64"),
  390. Nonce: common.Nonce(1900),
  391. Balance: bigFromStr("14000000000000000000", 10),
  392. },
  393. }
  394. for i := 0; i < len(accounts); i++ {
  395. _, err = accounts[i].HashValue()
  396. require.NoError(t, err)
  397. _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i])
  398. if err != nil {
  399. log.Error(err)
  400. }
  401. require.NoError(t, err)
  402. }
  403. // root value generated by js version:
  404. assert.Equal(t, "17298264051379321456969039521810887093935433569451713402227686942080129181291", sdb.mt.Root().BigInt().String())
  405. }