You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

658 lines
18 KiB

  1. package api
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "math"
  10. "math/big"
  11. "net/http"
  12. "os"
  13. "sort"
  14. "strconv"
  15. "testing"
  16. "time"
  17. ethCommon "github.com/ethereum/go-ethereum/common"
  18. swagger "github.com/getkin/kin-openapi/openapi3filter"
  19. "github.com/gin-gonic/gin"
  20. "github.com/hermeznetwork/hermez-node/common"
  21. dbUtils "github.com/hermeznetwork/hermez-node/db"
  22. "github.com/hermeznetwork/hermez-node/db/historydb"
  23. "github.com/hermeznetwork/hermez-node/db/l2db"
  24. "github.com/hermeznetwork/hermez-node/db/statedb"
  25. "github.com/hermeznetwork/hermez-node/log"
  26. "github.com/hermeznetwork/hermez-node/test"
  27. "github.com/iden3/go-iden3-crypto/babyjub"
  28. "github.com/mitchellh/copystructure"
  29. "github.com/stretchr/testify/assert"
  30. "github.com/stretchr/testify/require"
  31. )
  32. const apiPort = ":4010"
  33. const apiURL = "http://localhost" + apiPort + "/"
  34. type testCommon struct {
  35. blocks []common.Block
  36. tokens []common.Token
  37. batches []common.Batch
  38. usrAddr string
  39. usrBjj string
  40. accs []common.Account
  41. usrTxs historyTxAPIs
  42. othrTxs historyTxAPIs
  43. allTxs historyTxAPIs
  44. router *swagger.Router
  45. }
  46. type historyTxAPIs []historyTxAPI
  47. func (h historyTxAPIs) Len() int { return len(h) }
  48. func (h historyTxAPIs) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  49. func (h historyTxAPIs) Less(i, j int) bool {
  50. // i not forged yet
  51. if h[i].BatchNum == nil {
  52. if h[j].BatchNum != nil { // j is already forged
  53. return false
  54. }
  55. // Both aren't forged, is i in a smaller position?
  56. return h[i].Position < h[j].Position
  57. }
  58. // i is forged
  59. if h[j].BatchNum == nil {
  60. return true // j is not forged
  61. }
  62. // Both are forged
  63. if *h[i].BatchNum == *h[j].BatchNum {
  64. // At the same batch, is i in a smaller position?
  65. return h[i].Position < h[j].Position
  66. }
  67. // At different batches, is i in a smaller batch?
  68. return *h[i].BatchNum < *h[j].BatchNum
  69. }
  70. var tc testCommon
  71. func TestMain(m *testing.M) {
  72. // Init swagger
  73. router := swagger.NewRouter().WithSwaggerFromFile("./swagger.yml")
  74. // Init DBs
  75. // HistoryDB
  76. pass := os.Getenv("POSTGRES_PASS")
  77. db, err := dbUtils.InitSQLDB(5432, "localhost", "hermez", pass, "hermez")
  78. if err != nil {
  79. panic(err)
  80. }
  81. hdb := historydb.NewHistoryDB(db)
  82. err = hdb.Reorg(-1)
  83. if err != nil {
  84. panic(err)
  85. }
  86. // StateDB
  87. dir, err := ioutil.TempDir("", "tmpdb")
  88. if err != nil {
  89. panic(err)
  90. }
  91. sdb, err := statedb.NewStateDB(dir, statedb.TypeTxSelector, 0)
  92. if err != nil {
  93. panic(err)
  94. }
  95. // L2DB
  96. l2DB := l2db.NewL2DB(db, 10, 100, 24*time.Hour)
  97. test.CleanL2DB(l2DB.DB())
  98. // Init API
  99. api := gin.Default()
  100. if err := SetAPIEndpoints(
  101. true,
  102. true,
  103. api,
  104. hdb,
  105. sdb,
  106. l2DB,
  107. ); err != nil {
  108. panic(err)
  109. }
  110. // Start server
  111. server := &http.Server{Addr: apiPort, Handler: api}
  112. go func() {
  113. if err := server.ListenAndServe(); err != nil &&
  114. err != http.ErrServerClosed {
  115. panic(err)
  116. }
  117. }()
  118. // Populate DBs
  119. // Clean DB
  120. err = h.Reorg(0)
  121. if err != nil {
  122. panic(err)
  123. }
  124. // Gen blocks and add them to DB
  125. const nBlocks = 5
  126. blocks := test.GenBlocks(1, nBlocks+1)
  127. err = h.AddBlocks(blocks)
  128. if err != nil {
  129. panic(err)
  130. }
  131. // Gen tokens and add them to DB
  132. const nTokens = 10
  133. tokens := test.GenTokens(nTokens, blocks)
  134. err = h.AddTokens(tokens)
  135. if err != nil {
  136. panic(err)
  137. }
  138. // Gen batches and add them to DB
  139. const nBatches = 10
  140. batches := test.GenBatches(nBatches, blocks)
  141. err = h.AddBatches(batches)
  142. if err != nil {
  143. panic(err)
  144. }
  145. // Gen accounts and add them to DB
  146. const totalAccounts = 40
  147. const userAccounts = 4
  148. usrAddr := ethCommon.BigToAddress(big.NewInt(4896847))
  149. privK := babyjub.NewRandPrivKey()
  150. usrBjj := privK.Public()
  151. accs := test.GenAccounts(totalAccounts, userAccounts, tokens, &usrAddr, usrBjj, batches)
  152. err = h.AddAccounts(accs)
  153. if err != nil {
  154. panic(err)
  155. }
  156. // Gen L1Txs and add them to DB
  157. const totalL1Txs = 40
  158. const userL1Txs = 4
  159. usrL1Txs, othrL1Txs := test.GenL1Txs(256, totalL1Txs, userL1Txs, &usrAddr, accs, tokens, blocks, batches)
  160. var l1Txs []common.L1Tx
  161. l1Txs = append(l1Txs, usrL1Txs...)
  162. l1Txs = append(l1Txs, othrL1Txs...)
  163. err = h.AddL1Txs(l1Txs)
  164. if err != nil {
  165. panic(err)
  166. }
  167. // Gen L2Txs and add them to DB
  168. const totalL2Txs = 20
  169. const userL2Txs = 4
  170. usrL2Txs, othrL2Txs := test.GenL2Txs(256+totalL1Txs, totalL2Txs, userL2Txs, &usrAddr, accs, tokens, blocks, batches)
  171. var l2Txs []common.L2Tx
  172. l2Txs = append(l2Txs, usrL2Txs...)
  173. l2Txs = append(l2Txs, othrL2Txs...)
  174. err = h.AddL2Txs(l2Txs)
  175. if err != nil {
  176. panic(err)
  177. }
  178. // Set test commons
  179. txsToAPITxs := func(l1Txs []common.L1Tx, l2Txs []common.L2Tx, blocks []common.Block, tokens []common.Token) historyTxAPIs {
  180. // Transform L1Txs and L2Txs to generic Txs
  181. genericTxs := []*common.Tx{}
  182. for _, l1tx := range l1Txs {
  183. genericTxs = append(genericTxs, l1tx.Tx())
  184. }
  185. for _, l2tx := range l2Txs {
  186. genericTxs = append(genericTxs, l2tx.Tx())
  187. }
  188. // Transform generic Txs to HistoryTx
  189. historyTxs := []historydb.HistoryTx{}
  190. for _, genericTx := range genericTxs {
  191. // find timestamp
  192. var timestamp time.Time
  193. for i := 0; i < len(blocks); i++ {
  194. if blocks[i].EthBlockNum == genericTx.EthBlockNum {
  195. timestamp = blocks[i].Timestamp
  196. break
  197. }
  198. }
  199. // find token
  200. var token common.Token
  201. if genericTx.IsL1 {
  202. tokenID := genericTx.TokenID
  203. found := false
  204. for i := 0; i < len(tokens); i++ {
  205. if tokens[i].TokenID == tokenID {
  206. token = tokens[i]
  207. found = true
  208. break
  209. }
  210. }
  211. if !found {
  212. panic("Token not found")
  213. }
  214. } else {
  215. token = test.GetToken(*genericTx.FromIdx, accs, tokens)
  216. }
  217. var usd, loadUSD, feeUSD *float64
  218. if token.USD != nil {
  219. noDecimalsUSD := *token.USD / math.Pow(10, float64(token.Decimals))
  220. usd = new(float64)
  221. *usd = noDecimalsUSD * genericTx.AmountFloat
  222. if genericTx.IsL1 {
  223. loadUSD = new(float64)
  224. *loadUSD = noDecimalsUSD * *genericTx.LoadAmountFloat
  225. } else {
  226. feeUSD = new(float64)
  227. *feeUSD = *usd * genericTx.Fee.Percentage()
  228. }
  229. }
  230. historyTxs = append(historyTxs, historydb.HistoryTx{
  231. IsL1: genericTx.IsL1,
  232. TxID: genericTx.TxID,
  233. Type: genericTx.Type,
  234. Position: genericTx.Position,
  235. FromIdx: genericTx.FromIdx,
  236. ToIdx: *genericTx.ToIdx,
  237. Amount: genericTx.Amount,
  238. AmountFloat: genericTx.AmountFloat,
  239. HistoricUSD: usd,
  240. BatchNum: genericTx.BatchNum,
  241. EthBlockNum: genericTx.EthBlockNum,
  242. ToForgeL1TxsNum: genericTx.ToForgeL1TxsNum,
  243. UserOrigin: genericTx.UserOrigin,
  244. FromEthAddr: genericTx.FromEthAddr,
  245. FromBJJ: genericTx.FromBJJ,
  246. LoadAmount: genericTx.LoadAmount,
  247. LoadAmountFloat: genericTx.LoadAmountFloat,
  248. HistoricLoadAmountUSD: loadUSD,
  249. Fee: genericTx.Fee,
  250. HistoricFeeUSD: feeUSD,
  251. Nonce: genericTx.Nonce,
  252. Timestamp: timestamp,
  253. TokenID: token.TokenID,
  254. TokenEthBlockNum: token.EthBlockNum,
  255. TokenEthAddr: token.EthAddr,
  256. TokenName: token.Name,
  257. TokenSymbol: token.Symbol,
  258. TokenDecimals: token.Decimals,
  259. TokenUSD: token.USD,
  260. TokenUSDUpdate: token.USDUpdate,
  261. })
  262. }
  263. return historyTxAPIs(historyTxsToAPI(historyTxs))
  264. }
  265. usrTxs := txsToAPITxs(usrL1Txs, usrL2Txs, blocks, tokens)
  266. sort.Sort(usrTxs)
  267. othrTxs := txsToAPITxs(othrL1Txs, othrL2Txs, blocks, tokens)
  268. sort.Sort(othrTxs)
  269. allTxs := append(usrTxs, othrTxs...)
  270. sort.Sort(allTxs)
  271. tc = testCommon{
  272. blocks: blocks,
  273. tokens: tokens,
  274. batches: batches,
  275. usrAddr: "hez:" + usrAddr.String(),
  276. usrBjj: bjjToString(usrBjj),
  277. accs: accs,
  278. usrTxs: usrTxs,
  279. othrTxs: othrTxs,
  280. allTxs: allTxs,
  281. router: router,
  282. }
  283. // Run tests
  284. result := m.Run()
  285. // Stop server
  286. if err := server.Shutdown(context.Background()); err != nil {
  287. panic(err)
  288. }
  289. if err := db.Close(); err != nil {
  290. panic(err)
  291. }
  292. os.Exit(result)
  293. }
  294. func TestGetHistoryTxs(t *testing.T) {
  295. endpoint := apiURL + "transactions-history"
  296. fetchedTxs := historyTxAPIs{}
  297. appendIter := func(intr interface{}) {
  298. for i := 0; i < len(intr.(*historyTxsAPI).Txs); i++ {
  299. tmp, err := copystructure.Copy(intr.(*historyTxsAPI).Txs[i])
  300. if err != nil {
  301. panic(err)
  302. }
  303. fetchedTxs = append(fetchedTxs, tmp.(historyTxAPI))
  304. }
  305. }
  306. // Get all (no filters)
  307. limit := 8
  308. path := fmt.Sprintf("%s?limit=%d&offset=", endpoint, limit)
  309. err := doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  310. assert.NoError(t, err)
  311. assertHistoryTxAPIs(t, tc.allTxs, fetchedTxs)
  312. // Get by ethAddr
  313. fetchedTxs = historyTxAPIs{}
  314. limit = 7
  315. path = fmt.Sprintf(
  316. "%s?hermezEthereumAddress=%s&limit=%d&offset=",
  317. endpoint, tc.usrAddr, limit,
  318. )
  319. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  320. assert.NoError(t, err)
  321. assertHistoryTxAPIs(t, tc.usrTxs, fetchedTxs)
  322. // Get by bjj
  323. fetchedTxs = historyTxAPIs{}
  324. limit = 6
  325. path = fmt.Sprintf(
  326. "%s?BJJ=%s&limit=%d&offset=",
  327. endpoint, tc.usrBjj, limit,
  328. )
  329. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  330. assert.NoError(t, err)
  331. assertHistoryTxAPIs(t, tc.usrTxs, fetchedTxs)
  332. // Get by tokenID
  333. fetchedTxs = historyTxAPIs{}
  334. limit = 5
  335. tokenID := tc.allTxs[0].Token.TokenID
  336. path = fmt.Sprintf(
  337. "%s?tokenId=%d&limit=%d&offset=",
  338. endpoint, tokenID, limit,
  339. )
  340. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  341. assert.NoError(t, err)
  342. tokenIDTxs := historyTxAPIs{}
  343. for i := 0; i < len(tc.allTxs); i++ {
  344. if tc.allTxs[i].Token.TokenID == tokenID {
  345. tokenIDTxs = append(tokenIDTxs, tc.allTxs[i])
  346. }
  347. }
  348. assertHistoryTxAPIs(t, tokenIDTxs, fetchedTxs)
  349. // idx
  350. fetchedTxs = historyTxAPIs{}
  351. limit = 4
  352. idx := tc.allTxs[0].ToIdx
  353. path = fmt.Sprintf(
  354. "%s?accountIndex=%s&limit=%d&offset=",
  355. endpoint, idx, limit,
  356. )
  357. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  358. assert.NoError(t, err)
  359. idxTxs := historyTxAPIs{}
  360. for i := 0; i < len(tc.allTxs); i++ {
  361. if (tc.allTxs[i].FromIdx != nil && (*tc.allTxs[i].FromIdx)[6:] == idx[6:]) ||
  362. tc.allTxs[i].ToIdx[6:] == idx[6:] {
  363. idxTxs = append(idxTxs, tc.allTxs[i])
  364. }
  365. }
  366. assertHistoryTxAPIs(t, idxTxs, fetchedTxs)
  367. // batchNum
  368. fetchedTxs = historyTxAPIs{}
  369. limit = 3
  370. batchNum := tc.allTxs[0].BatchNum
  371. path = fmt.Sprintf(
  372. "%s?batchNum=%d&limit=%d&offset=",
  373. endpoint, *batchNum, limit,
  374. )
  375. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  376. assert.NoError(t, err)
  377. batchNumTxs := historyTxAPIs{}
  378. for i := 0; i < len(tc.allTxs); i++ {
  379. if tc.allTxs[i].BatchNum != nil &&
  380. *tc.allTxs[i].BatchNum == *batchNum {
  381. batchNumTxs = append(batchNumTxs, tc.allTxs[i])
  382. }
  383. }
  384. assertHistoryTxAPIs(t, batchNumTxs, fetchedTxs)
  385. // type
  386. txTypes := []common.TxType{
  387. common.TxTypeExit,
  388. common.TxTypeTransfer,
  389. common.TxTypeDeposit,
  390. common.TxTypeCreateAccountDeposit,
  391. common.TxTypeCreateAccountDepositTransfer,
  392. common.TxTypeDepositTransfer,
  393. common.TxTypeForceTransfer,
  394. common.TxTypeForceExit,
  395. common.TxTypeTransferToEthAddr,
  396. common.TxTypeTransferToBJJ,
  397. }
  398. for _, txType := range txTypes {
  399. fetchedTxs = historyTxAPIs{}
  400. limit = 2
  401. path = fmt.Sprintf(
  402. "%s?type=%s&limit=%d&offset=",
  403. endpoint, txType, limit,
  404. )
  405. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  406. assert.NoError(t, err)
  407. txTypeTxs := historyTxAPIs{}
  408. for i := 0; i < len(tc.allTxs); i++ {
  409. if tc.allTxs[i].Type == txType {
  410. txTypeTxs = append(txTypeTxs, tc.allTxs[i])
  411. }
  412. }
  413. assertHistoryTxAPIs(t, txTypeTxs, fetchedTxs)
  414. }
  415. // Multiple filters
  416. fetchedTxs = historyTxAPIs{}
  417. limit = 1
  418. path = fmt.Sprintf(
  419. "%s?batchNum=%d&tokeId=%d&limit=%d&offset=",
  420. endpoint, *batchNum, tokenID, limit,
  421. )
  422. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  423. assert.NoError(t, err)
  424. mixedTxs := historyTxAPIs{}
  425. for i := 0; i < len(tc.allTxs); i++ {
  426. if tc.allTxs[i].BatchNum != nil {
  427. if *tc.allTxs[i].BatchNum == *batchNum && tc.allTxs[i].Token.TokenID == tokenID {
  428. mixedTxs = append(mixedTxs, tc.allTxs[i])
  429. }
  430. }
  431. }
  432. assertHistoryTxAPIs(t, mixedTxs, fetchedTxs)
  433. // All, in reverse order
  434. fetchedTxs = historyTxAPIs{}
  435. limit = 5
  436. path = fmt.Sprintf("%s?", endpoint)
  437. appendIterRev := func(intr interface{}) {
  438. tmpAll := historyTxAPIs{}
  439. for i := 0; i < len(intr.(*historyTxsAPI).Txs); i++ {
  440. tmp, err := copystructure.Copy(intr.(*historyTxsAPI).Txs[i])
  441. if err != nil {
  442. panic(err)
  443. }
  444. tmpAll = append(tmpAll, tmp.(historyTxAPI))
  445. }
  446. fetchedTxs = append(tmpAll, fetchedTxs...)
  447. }
  448. err = doGoodReqPaginatedReverse(path, &historyTxsAPI{}, appendIterRev, limit)
  449. assert.NoError(t, err)
  450. assertHistoryTxAPIs(t, tc.allTxs, fetchedTxs)
  451. // 400
  452. path = fmt.Sprintf(
  453. "%s?accountIndex=%s&hermezEthereumAddress=%s",
  454. endpoint, idx, tc.usrAddr,
  455. )
  456. err = doBadReq("GET", path, nil, 400)
  457. assert.NoError(t, err)
  458. path = fmt.Sprintf("%s?tokenId=X", endpoint)
  459. err = doBadReq("GET", path, nil, 400)
  460. assert.NoError(t, err)
  461. // 404
  462. path = fmt.Sprintf("%s?batchNum=999999", endpoint)
  463. err = doBadReq("GET", path, nil, 404)
  464. assert.NoError(t, err)
  465. path = fmt.Sprintf("%s?limit=1000&offset=1000", endpoint)
  466. err = doBadReq("GET", path, nil, 404)
  467. assert.NoError(t, err)
  468. }
  469. func assertHistoryTxAPIs(t *testing.T, expected, actual historyTxAPIs) {
  470. require.Equal(t, len(expected), len(actual))
  471. for i := 0; i < len(actual); i++ { //nolint len(actual) won't change within the loop
  472. assert.Equal(t, expected[i].Timestamp.Unix(), actual[i].Timestamp.Unix())
  473. expected[i].Timestamp = actual[i].Timestamp
  474. if expected[i].Token.USDUpdate == nil {
  475. assert.Equal(t, expected[i].Token.USDUpdate, actual[i].Token.USDUpdate)
  476. } else {
  477. assert.Equal(t, expected[i].Token.USDUpdate.Unix(), actual[i].Token.USDUpdate.Unix())
  478. expected[i].Token.USDUpdate = actual[i].Token.USDUpdate
  479. }
  480. test.AssertUSD(t, expected[i].HistoricUSD, actual[i].HistoricUSD)
  481. if expected[i].L2Info != nil {
  482. test.AssertUSD(t, expected[i].L2Info.HistoricFeeUSD, actual[i].L2Info.HistoricFeeUSD)
  483. } else {
  484. test.AssertUSD(t, expected[i].L1Info.HistoricLoadAmountUSD, actual[i].L1Info.HistoricLoadAmountUSD)
  485. }
  486. assert.Equal(t, expected[i], actual[i])
  487. }
  488. }
  489. func doGoodReqPaginated(
  490. path string,
  491. iterStruct paginationer,
  492. appendIter func(res interface{}),
  493. ) error {
  494. next := 0
  495. for {
  496. // Call API to get this iteration items
  497. if err := doGoodReq("GET", path+strconv.Itoa(next), nil, iterStruct); err != nil {
  498. return err
  499. }
  500. appendIter(iterStruct)
  501. // Keep iterating?
  502. pag := iterStruct.GetPagination()
  503. if pag.LastReturnedItem == pag.TotalItems-1 { // No
  504. break
  505. } else { // Yes
  506. next = int(pag.LastReturnedItem + 1)
  507. }
  508. }
  509. return nil
  510. }
  511. func doGoodReqPaginatedReverse(
  512. path string,
  513. iterStruct paginationer,
  514. appendIter func(res interface{}),
  515. limit int,
  516. ) error {
  517. next := 0
  518. first := true
  519. for {
  520. // Call API to get this iteration items
  521. if first {
  522. first = false
  523. pagQuery := fmt.Sprintf("last=true&limit=%d", limit)
  524. if err := doGoodReq("GET", path+pagQuery, nil, iterStruct); err != nil {
  525. return err
  526. }
  527. } else {
  528. pagQuery := fmt.Sprintf("offset=%d&limit=%d", next, limit)
  529. if err := doGoodReq("GET", path+pagQuery, nil, iterStruct); err != nil {
  530. return err
  531. }
  532. }
  533. appendIter(iterStruct)
  534. // Keep iterating?
  535. pag := iterStruct.GetPagination()
  536. if iterStruct.Len() == pag.TotalItems || pag.LastReturnedItem-iterStruct.Len() == -1 { // No
  537. break
  538. } else { // Yes
  539. prevOffset := next
  540. next = pag.LastReturnedItem - iterStruct.Len() - limit + 1
  541. if next < 0 {
  542. next = 0
  543. limit = prevOffset
  544. }
  545. }
  546. }
  547. return nil
  548. }
  549. func doGoodReq(method, path string, reqBody io.Reader, returnStruct interface{}) error {
  550. ctx := context.Background()
  551. client := &http.Client{}
  552. httpReq, _ := http.NewRequest(method, path, reqBody)
  553. route, pathParams, err := tc.router.FindRoute(httpReq.Method, httpReq.URL)
  554. if err != nil {
  555. return err
  556. }
  557. // Validate request against swagger spec
  558. requestValidationInput := &swagger.RequestValidationInput{
  559. Request: httpReq,
  560. PathParams: pathParams,
  561. Route: route,
  562. }
  563. if err := swagger.ValidateRequest(ctx, requestValidationInput); err != nil {
  564. return err
  565. }
  566. // Do API call
  567. resp, err := client.Do(httpReq)
  568. if err != nil {
  569. return err
  570. }
  571. if resp.Body == nil {
  572. return errors.New("Nil body")
  573. }
  574. //nolint
  575. defer resp.Body.Close()
  576. body, err := ioutil.ReadAll(resp.Body)
  577. if err != nil {
  578. return err
  579. }
  580. if resp.StatusCode != 200 {
  581. return fmt.Errorf("%d response: %s", resp.StatusCode, string(body))
  582. }
  583. // Unmarshal body into return struct
  584. if err := json.Unmarshal(body, returnStruct); err != nil {
  585. return err
  586. }
  587. // Validate response against swagger spec
  588. responseValidationInput := &swagger.ResponseValidationInput{
  589. RequestValidationInput: requestValidationInput,
  590. Status: resp.StatusCode,
  591. Header: resp.Header,
  592. }
  593. responseValidationInput = responseValidationInput.SetBodyBytes(body)
  594. return swagger.ValidateResponse(ctx, responseValidationInput)
  595. }
  596. func doBadReq(method, path string, reqBody io.Reader, expectedResponseCode int) error {
  597. ctx := context.Background()
  598. client := &http.Client{}
  599. httpReq, _ := http.NewRequest(method, path, reqBody)
  600. route, pathParams, err := tc.router.FindRoute(httpReq.Method, httpReq.URL)
  601. if err != nil {
  602. return err
  603. }
  604. // Validate request against swagger spec
  605. requestValidationInput := &swagger.RequestValidationInput{
  606. Request: httpReq,
  607. PathParams: pathParams,
  608. Route: route,
  609. }
  610. if err := swagger.ValidateRequest(ctx, requestValidationInput); err != nil {
  611. if expectedResponseCode != 400 {
  612. return err
  613. }
  614. log.Warn("The request does not match the API spec")
  615. }
  616. // Do API call
  617. resp, err := client.Do(httpReq)
  618. if err != nil {
  619. return err
  620. }
  621. if resp.Body == nil {
  622. return errors.New("Nil body")
  623. }
  624. //nolint
  625. defer resp.Body.Close()
  626. body, err := ioutil.ReadAll(resp.Body)
  627. if err != nil {
  628. return err
  629. }
  630. if resp.StatusCode != expectedResponseCode {
  631. return fmt.Errorf("Unexpected response code: %d", resp.StatusCode)
  632. }
  633. // Validate response against swagger spec
  634. responseValidationInput := &swagger.ResponseValidationInput{
  635. RequestValidationInput: requestValidationInput,
  636. Status: resp.StatusCode,
  637. Header: resp.Header,
  638. }
  639. responseValidationInput = responseValidationInput.SetBodyBytes(body)
  640. return swagger.ValidateResponse(ctx, responseValidationInput)
  641. }