You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

657 lines
18 KiB

  1. package api
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "math"
  10. "math/big"
  11. "net/http"
  12. "os"
  13. "sort"
  14. "strconv"
  15. "testing"
  16. "time"
  17. ethCommon "github.com/ethereum/go-ethereum/common"
  18. swagger "github.com/getkin/kin-openapi/openapi3filter"
  19. "github.com/gin-gonic/gin"
  20. "github.com/hermeznetwork/hermez-node/common"
  21. dbUtils "github.com/hermeznetwork/hermez-node/db"
  22. "github.com/hermeznetwork/hermez-node/db/historydb"
  23. "github.com/hermeznetwork/hermez-node/db/l2db"
  24. "github.com/hermeznetwork/hermez-node/db/statedb"
  25. "github.com/hermeznetwork/hermez-node/log"
  26. "github.com/hermeznetwork/hermez-node/test"
  27. "github.com/iden3/go-iden3-crypto/babyjub"
  28. "github.com/mitchellh/copystructure"
  29. "github.com/stretchr/testify/assert"
  30. )
  31. const apiPort = ":4010"
  32. const apiURL = "http://localhost" + apiPort + "/"
  33. type testCommon struct {
  34. blocks []common.Block
  35. tokens []common.Token
  36. batches []common.Batch
  37. usrAddr string
  38. usrBjj string
  39. accs []common.Account
  40. usrTxs historyTxAPIs
  41. othrTxs historyTxAPIs
  42. allTxs historyTxAPIs
  43. router *swagger.Router
  44. }
  45. type historyTxAPIs []historyTxAPI
  46. func (h historyTxAPIs) Len() int { return len(h) }
  47. func (h historyTxAPIs) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  48. func (h historyTxAPIs) Less(i, j int) bool {
  49. // i not forged yet
  50. if h[i].BatchNum == nil {
  51. if h[j].BatchNum != nil { // j is already forged
  52. return false
  53. }
  54. // Both aren't forged, is i in a smaller position?
  55. return h[i].Position < h[j].Position
  56. }
  57. // i is forged
  58. if h[j].BatchNum == nil {
  59. return true // j is not forged
  60. }
  61. // Both are forged
  62. if *h[i].BatchNum == *h[j].BatchNum {
  63. // At the same batch, is i in a smaller position?
  64. return h[i].Position < h[j].Position
  65. }
  66. // At different batches, is i in a smaller batch?
  67. return *h[i].BatchNum < *h[j].BatchNum
  68. }
  69. var tc testCommon
  70. func TestMain(m *testing.M) {
  71. // Init swagger
  72. router := swagger.NewRouter().WithSwaggerFromFile("./swagger.yml")
  73. // Init DBs
  74. // HistoryDB
  75. pass := os.Getenv("POSTGRES_PASS")
  76. db, err := dbUtils.InitSQLDB(5432, "localhost", "hermez", pass, "hermez")
  77. if err != nil {
  78. panic(err)
  79. }
  80. hdb := historydb.NewHistoryDB(db)
  81. err = hdb.Reorg(-1)
  82. if err != nil {
  83. panic(err)
  84. }
  85. // StateDB
  86. dir, err := ioutil.TempDir("", "tmpdb")
  87. if err != nil {
  88. panic(err)
  89. }
  90. sdb, err := statedb.NewStateDB(dir, false, 0)
  91. if err != nil {
  92. panic(err)
  93. }
  94. // L2DB
  95. l2DB := l2db.NewL2DB(db, 10, 100, 24*time.Hour)
  96. test.CleanL2DB(l2DB.DB())
  97. // Init API
  98. api := gin.Default()
  99. if err := SetAPIEndpoints(
  100. true,
  101. true,
  102. api,
  103. hdb,
  104. sdb,
  105. l2DB,
  106. ); err != nil {
  107. panic(err)
  108. }
  109. // Start server
  110. server := &http.Server{Addr: apiPort, Handler: api}
  111. go func() {
  112. if err := server.ListenAndServe(); err != nil &&
  113. err != http.ErrServerClosed {
  114. panic(err)
  115. }
  116. }()
  117. // Populate DBs
  118. // Clean DB
  119. err = h.Reorg(0)
  120. if err != nil {
  121. panic(err)
  122. }
  123. // Gen blocks and add them to DB
  124. const nBlocks = 5
  125. blocks := test.GenBlocks(1, nBlocks+1)
  126. err = h.AddBlocks(blocks)
  127. if err != nil {
  128. panic(err)
  129. }
  130. // Gen tokens and add them to DB
  131. const nTokens = 10
  132. tokens := test.GenTokens(nTokens, blocks)
  133. err = h.AddTokens(tokens)
  134. if err != nil {
  135. panic(err)
  136. }
  137. // Gen batches and add them to DB
  138. const nBatches = 10
  139. batches := test.GenBatches(nBatches, blocks)
  140. err = h.AddBatches(batches)
  141. if err != nil {
  142. panic(err)
  143. }
  144. // Gen accounts and add them to DB
  145. const totalAccounts = 40
  146. const userAccounts = 4
  147. usrAddr := ethCommon.BigToAddress(big.NewInt(4896847))
  148. privK := babyjub.NewRandPrivKey()
  149. usrBjj := privK.Public()
  150. accs := test.GenAccounts(totalAccounts, userAccounts, tokens, &usrAddr, usrBjj, batches)
  151. err = h.AddAccounts(accs)
  152. if err != nil {
  153. panic(err)
  154. }
  155. // Gen L1Txs and add them to DB
  156. const totalL1Txs = 40
  157. const userL1Txs = 4
  158. usrL1Txs, othrL1Txs := test.GenL1Txs(0, totalL1Txs, userL1Txs, &usrAddr, accs, tokens, blocks, batches)
  159. var l1Txs []common.L1Tx
  160. l1Txs = append(l1Txs, usrL1Txs...)
  161. l1Txs = append(l1Txs, othrL1Txs...)
  162. err = h.AddL1Txs(l1Txs)
  163. if err != nil {
  164. panic(err)
  165. }
  166. // Gen L2Txs and add them to DB
  167. const totalL2Txs = 20
  168. const userL2Txs = 4
  169. usrL2Txs, othrL2Txs := test.GenL2Txs(totalL1Txs, totalL2Txs, userL2Txs, &usrAddr, accs, tokens, blocks, batches)
  170. var l2Txs []common.L2Tx
  171. l2Txs = append(l2Txs, usrL2Txs...)
  172. l2Txs = append(l2Txs, othrL2Txs...)
  173. err = h.AddL2Txs(l2Txs)
  174. if err != nil {
  175. panic(err)
  176. }
  177. // Set test commons
  178. txsToAPITxs := func(l1Txs []common.L1Tx, l2Txs []common.L2Tx, blocks []common.Block, tokens []common.Token) historyTxAPIs {
  179. // Transform L1Txs and L2Txs to generic Txs
  180. genericTxs := []*common.Tx{}
  181. for _, l1tx := range l1Txs {
  182. genericTxs = append(genericTxs, l1tx.Tx())
  183. }
  184. for _, l2tx := range l2Txs {
  185. genericTxs = append(genericTxs, l2tx.Tx())
  186. }
  187. // Transform generic Txs to HistoryTx
  188. historyTxs := []*historydb.HistoryTx{}
  189. for _, genericTx := range genericTxs {
  190. // find timestamp
  191. var timestamp time.Time
  192. for i := 0; i < len(blocks); i++ {
  193. if blocks[i].EthBlockNum == genericTx.EthBlockNum {
  194. timestamp = blocks[i].Timestamp
  195. break
  196. }
  197. }
  198. // find token
  199. var token common.Token
  200. if genericTx.IsL1 {
  201. tokenID := genericTx.TokenID
  202. found := false
  203. for i := 0; i < len(tokens); i++ {
  204. if tokens[i].TokenID == tokenID {
  205. token = tokens[i]
  206. found = true
  207. break
  208. }
  209. }
  210. if !found {
  211. panic("Token not found")
  212. }
  213. } else {
  214. token = test.GetToken(genericTx.FromIdx, accs, tokens)
  215. }
  216. var usd, loadUSD, feeUSD *float64
  217. if token.USD != nil {
  218. noDecimalsUSD := *token.USD / math.Pow(10, float64(token.Decimals))
  219. usd = new(float64)
  220. *usd = noDecimalsUSD * genericTx.AmountFloat
  221. if genericTx.IsL1 {
  222. loadUSD = new(float64)
  223. *loadUSD = noDecimalsUSD * *genericTx.LoadAmountFloat
  224. } else {
  225. feeUSD = new(float64)
  226. *feeUSD = *usd * genericTx.Fee.Percentage()
  227. }
  228. }
  229. historyTxs = append(historyTxs, &historydb.HistoryTx{
  230. IsL1: genericTx.IsL1,
  231. TxID: genericTx.TxID,
  232. Type: genericTx.Type,
  233. Position: genericTx.Position,
  234. FromIdx: genericTx.FromIdx,
  235. ToIdx: genericTx.ToIdx,
  236. Amount: genericTx.Amount,
  237. AmountFloat: genericTx.AmountFloat,
  238. HistoricUSD: usd,
  239. BatchNum: genericTx.BatchNum,
  240. EthBlockNum: genericTx.EthBlockNum,
  241. ToForgeL1TxsNum: genericTx.ToForgeL1TxsNum,
  242. UserOrigin: genericTx.UserOrigin,
  243. FromEthAddr: genericTx.FromEthAddr,
  244. FromBJJ: genericTx.FromBJJ,
  245. LoadAmount: genericTx.LoadAmount,
  246. LoadAmountFloat: genericTx.LoadAmountFloat,
  247. HistoricLoadAmountUSD: loadUSD,
  248. Fee: genericTx.Fee,
  249. HistoricFeeUSD: feeUSD,
  250. Nonce: genericTx.Nonce,
  251. Timestamp: timestamp,
  252. TokenID: token.TokenID,
  253. TokenEthBlockNum: token.EthBlockNum,
  254. TokenEthAddr: token.EthAddr,
  255. TokenName: token.Name,
  256. TokenSymbol: token.Symbol,
  257. TokenDecimals: token.Decimals,
  258. TokenUSD: token.USD,
  259. TokenUSDUpdate: token.USDUpdate,
  260. })
  261. }
  262. return historyTxAPIs(historyTxsToAPI(historyTxs))
  263. }
  264. usrTxs := txsToAPITxs(usrL1Txs, usrL2Txs, blocks, tokens)
  265. sort.Sort(usrTxs)
  266. othrTxs := txsToAPITxs(othrL1Txs, othrL2Txs, blocks, tokens)
  267. sort.Sort(othrTxs)
  268. allTxs := append(usrTxs, othrTxs...)
  269. sort.Sort(allTxs)
  270. tc = testCommon{
  271. blocks: blocks,
  272. tokens: tokens,
  273. batches: batches,
  274. usrAddr: "hez:" + usrAddr.String(),
  275. usrBjj: bjjToString(usrBjj),
  276. accs: accs,
  277. usrTxs: usrTxs,
  278. othrTxs: othrTxs,
  279. allTxs: allTxs,
  280. router: router,
  281. }
  282. // Run tests
  283. result := m.Run()
  284. // Stop server
  285. if err := server.Shutdown(context.Background()); err != nil {
  286. panic(err)
  287. }
  288. if err := db.Close(); err != nil {
  289. panic(err)
  290. }
  291. os.Exit(result)
  292. }
  293. func TestGetHistoryTxs(t *testing.T) {
  294. endpoint := apiURL + "transactions-history"
  295. fetchedTxs := historyTxAPIs{}
  296. appendIter := func(intr interface{}) {
  297. for i := 0; i < len(intr.(*historyTxsAPI).Txs); i++ {
  298. tmp, err := copystructure.Copy(intr.(*historyTxsAPI).Txs[i])
  299. if err != nil {
  300. panic(err)
  301. }
  302. fetchedTxs = append(fetchedTxs, tmp.(historyTxAPI))
  303. }
  304. }
  305. // Get all (no filters)
  306. limit := 8
  307. path := fmt.Sprintf("%s?limit=%d&offset=", endpoint, limit)
  308. err := doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  309. assert.NoError(t, err)
  310. assertHistoryTxAPIs(t, tc.allTxs, fetchedTxs)
  311. // Get by ethAddr
  312. fetchedTxs = historyTxAPIs{}
  313. limit = 7
  314. path = fmt.Sprintf(
  315. "%s?hermezEthereumAddress=%s&limit=%d&offset=",
  316. endpoint, tc.usrAddr, limit,
  317. )
  318. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  319. assert.NoError(t, err)
  320. assertHistoryTxAPIs(t, tc.usrTxs, fetchedTxs)
  321. // Get by bjj
  322. fetchedTxs = historyTxAPIs{}
  323. limit = 6
  324. path = fmt.Sprintf(
  325. "%s?BJJ=%s&limit=%d&offset=",
  326. endpoint, tc.usrBjj, limit,
  327. )
  328. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  329. assert.NoError(t, err)
  330. assertHistoryTxAPIs(t, tc.usrTxs, fetchedTxs)
  331. // Get by tokenID
  332. fetchedTxs = historyTxAPIs{}
  333. limit = 5
  334. tokenID := tc.allTxs[0].Token.TokenID
  335. path = fmt.Sprintf(
  336. "%s?tokenId=%d&limit=%d&offset=",
  337. endpoint, tokenID, limit,
  338. )
  339. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  340. assert.NoError(t, err)
  341. tokenIDTxs := historyTxAPIs{}
  342. for i := 0; i < len(tc.allTxs); i++ {
  343. if tc.allTxs[i].Token.TokenID == tokenID {
  344. tokenIDTxs = append(tokenIDTxs, tc.allTxs[i])
  345. }
  346. }
  347. assertHistoryTxAPIs(t, tokenIDTxs, fetchedTxs)
  348. // idx
  349. fetchedTxs = historyTxAPIs{}
  350. limit = 4
  351. idx := tc.allTxs[0].FromIdx
  352. path = fmt.Sprintf(
  353. "%s?accountIndex=%s&limit=%d&offset=",
  354. endpoint, idx, limit,
  355. )
  356. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  357. assert.NoError(t, err)
  358. idxTxs := historyTxAPIs{}
  359. for i := 0; i < len(tc.allTxs); i++ {
  360. if tc.allTxs[i].FromIdx == idx {
  361. idxTxs = append(idxTxs, tc.allTxs[i])
  362. }
  363. }
  364. assertHistoryTxAPIs(t, idxTxs, fetchedTxs)
  365. // batchNum
  366. fetchedTxs = historyTxAPIs{}
  367. limit = 3
  368. batchNum := tc.allTxs[0].BatchNum
  369. path = fmt.Sprintf(
  370. "%s?batchNum=%d&limit=%d&offset=",
  371. endpoint, *batchNum, limit,
  372. )
  373. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  374. assert.NoError(t, err)
  375. batchNumTxs := historyTxAPIs{}
  376. for i := 0; i < len(tc.allTxs); i++ {
  377. if tc.allTxs[i].BatchNum != nil &&
  378. *tc.allTxs[i].BatchNum == *batchNum {
  379. batchNumTxs = append(batchNumTxs, tc.allTxs[i])
  380. }
  381. }
  382. assertHistoryTxAPIs(t, batchNumTxs, fetchedTxs)
  383. // type
  384. txTypes := []common.TxType{
  385. common.TxTypeExit,
  386. common.TxTypeWithdrawn,
  387. common.TxTypeTransfer,
  388. common.TxTypeDeposit,
  389. common.TxTypeCreateAccountDeposit,
  390. common.TxTypeCreateAccountDepositTransfer,
  391. common.TxTypeDepositTransfer,
  392. common.TxTypeForceTransfer,
  393. common.TxTypeForceExit,
  394. common.TxTypeTransferToEthAddr,
  395. common.TxTypeTransferToBJJ,
  396. }
  397. for _, txType := range txTypes {
  398. fetchedTxs = historyTxAPIs{}
  399. limit = 2
  400. path = fmt.Sprintf(
  401. "%s?type=%s&limit=%d&offset=",
  402. endpoint, txType, limit,
  403. )
  404. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  405. assert.NoError(t, err)
  406. txTypeTxs := historyTxAPIs{}
  407. for i := 0; i < len(tc.allTxs); i++ {
  408. if tc.allTxs[i].Type == txType {
  409. txTypeTxs = append(txTypeTxs, tc.allTxs[i])
  410. }
  411. }
  412. assertHistoryTxAPIs(t, txTypeTxs, fetchedTxs)
  413. }
  414. // Multiple filters
  415. fetchedTxs = historyTxAPIs{}
  416. limit = 1
  417. path = fmt.Sprintf(
  418. "%s?batchNum=%d&tokeId=%d&limit=%d&offset=",
  419. endpoint, *batchNum, tokenID, limit,
  420. )
  421. err = doGoodReqPaginated(path, &historyTxsAPI{}, appendIter)
  422. assert.NoError(t, err)
  423. mixedTxs := historyTxAPIs{}
  424. for i := 0; i < len(tc.allTxs); i++ {
  425. if tc.allTxs[i].BatchNum != nil {
  426. if *tc.allTxs[i].BatchNum == *batchNum && tc.allTxs[i].Token.TokenID == tokenID {
  427. mixedTxs = append(mixedTxs, tc.allTxs[i])
  428. }
  429. }
  430. }
  431. assertHistoryTxAPIs(t, mixedTxs, fetchedTxs)
  432. // All, in reverse order
  433. fetchedTxs = historyTxAPIs{}
  434. limit = 5
  435. path = fmt.Sprintf("%s?", endpoint)
  436. appendIterRev := func(intr interface{}) {
  437. tmpAll := historyTxAPIs{}
  438. for i := 0; i < len(intr.(*historyTxsAPI).Txs); i++ {
  439. tmp, err := copystructure.Copy(intr.(*historyTxsAPI).Txs[i])
  440. if err != nil {
  441. panic(err)
  442. }
  443. tmpAll = append(tmpAll, tmp.(historyTxAPI))
  444. }
  445. fetchedTxs = append(tmpAll, fetchedTxs...)
  446. }
  447. err = doGoodReqPaginatedReverse(path, &historyTxsAPI{}, appendIterRev, limit)
  448. assert.NoError(t, err)
  449. assertHistoryTxAPIs(t, tc.allTxs, fetchedTxs)
  450. // 400
  451. path = fmt.Sprintf(
  452. "%s?accountIndex=%s&hermezEthereumAddress=%s",
  453. endpoint, idx, tc.usrAddr,
  454. )
  455. err = doBadReq("GET", path, nil, 400)
  456. assert.NoError(t, err)
  457. path = fmt.Sprintf("%s?tokenId=X", endpoint)
  458. err = doBadReq("GET", path, nil, 400)
  459. assert.NoError(t, err)
  460. // 404
  461. path = fmt.Sprintf("%s?batchNum=999999", endpoint)
  462. err = doBadReq("GET", path, nil, 404)
  463. assert.NoError(t, err)
  464. path = fmt.Sprintf("%s?limit=1000&offset=1000", endpoint)
  465. err = doBadReq("GET", path, nil, 404)
  466. assert.NoError(t, err)
  467. }
  468. func assertHistoryTxAPIs(t *testing.T, expected, actual historyTxAPIs) {
  469. assert.Equal(t, len(expected), len(actual))
  470. for i := 0; i < len(actual); i++ { //nolint len(actual) won't change within the loop
  471. assert.Equal(t, expected[i].Timestamp.Unix(), actual[i].Timestamp.Unix())
  472. expected[i].Timestamp = actual[i].Timestamp
  473. if expected[i].Token.USDUpdate == nil {
  474. assert.Equal(t, expected[i].Token.USDUpdate, actual[i].Token.USDUpdate)
  475. } else {
  476. assert.Equal(t, expected[i].Token.USDUpdate.Unix(), actual[i].Token.USDUpdate.Unix())
  477. expected[i].Token.USDUpdate = actual[i].Token.USDUpdate
  478. }
  479. test.AssertUSD(t, expected[i].HistoricUSD, actual[i].HistoricUSD)
  480. if expected[i].L2Info != nil {
  481. test.AssertUSD(t, expected[i].L2Info.HistoricFeeUSD, actual[i].L2Info.HistoricFeeUSD)
  482. } else {
  483. test.AssertUSD(t, expected[i].L1Info.HistoricLoadAmountUSD, actual[i].L1Info.HistoricLoadAmountUSD)
  484. }
  485. assert.Equal(t, expected[i], actual[i])
  486. }
  487. }
  488. func doGoodReqPaginated(
  489. path string,
  490. iterStruct paginationer,
  491. appendIter func(res interface{}),
  492. ) error {
  493. next := 0
  494. for {
  495. // Call API to get this iteration items
  496. if err := doGoodReq("GET", path+strconv.Itoa(next), nil, iterStruct); err != nil {
  497. return err
  498. }
  499. appendIter(iterStruct)
  500. // Keep iterating?
  501. pag := iterStruct.GetPagination()
  502. if pag.LastReturnedItem == pag.TotalItems-1 { // No
  503. break
  504. } else { // Yes
  505. next = int(pag.LastReturnedItem + 1)
  506. }
  507. }
  508. return nil
  509. }
  510. func doGoodReqPaginatedReverse(
  511. path string,
  512. iterStruct paginationer,
  513. appendIter func(res interface{}),
  514. limit int,
  515. ) error {
  516. next := 0
  517. first := true
  518. for {
  519. // Call API to get this iteration items
  520. if first {
  521. first = false
  522. pagQuery := fmt.Sprintf("last=true&limit=%d", limit)
  523. if err := doGoodReq("GET", path+pagQuery, nil, iterStruct); err != nil {
  524. return err
  525. }
  526. } else {
  527. pagQuery := fmt.Sprintf("offset=%d&limit=%d", next, limit)
  528. if err := doGoodReq("GET", path+pagQuery, nil, iterStruct); err != nil {
  529. return err
  530. }
  531. }
  532. appendIter(iterStruct)
  533. // Keep iterating?
  534. pag := iterStruct.GetPagination()
  535. if iterStruct.Len() == pag.TotalItems || pag.LastReturnedItem-iterStruct.Len() == -1 { // No
  536. break
  537. } else { // Yes
  538. prevOffset := next
  539. next = pag.LastReturnedItem - iterStruct.Len() - limit + 1
  540. if next < 0 {
  541. next = 0
  542. limit = prevOffset
  543. }
  544. }
  545. }
  546. return nil
  547. }
  548. func doGoodReq(method, path string, reqBody io.Reader, returnStruct interface{}) error {
  549. ctx := context.Background()
  550. client := &http.Client{}
  551. httpReq, _ := http.NewRequest(method, path, reqBody)
  552. route, pathParams, err := tc.router.FindRoute(httpReq.Method, httpReq.URL)
  553. if err != nil {
  554. return err
  555. }
  556. // Validate request against swagger spec
  557. requestValidationInput := &swagger.RequestValidationInput{
  558. Request: httpReq,
  559. PathParams: pathParams,
  560. Route: route,
  561. }
  562. if err := swagger.ValidateRequest(ctx, requestValidationInput); err != nil {
  563. return err
  564. }
  565. // Do API call
  566. resp, err := client.Do(httpReq)
  567. if err != nil {
  568. return err
  569. }
  570. if resp.Body == nil {
  571. return errors.New("Nil body")
  572. }
  573. //nolint
  574. defer resp.Body.Close()
  575. body, err := ioutil.ReadAll(resp.Body)
  576. if err != nil {
  577. return err
  578. }
  579. if resp.StatusCode != 200 {
  580. return fmt.Errorf("%d response: %s", resp.StatusCode, string(body))
  581. }
  582. // Unmarshal body into return struct
  583. if err := json.Unmarshal(body, returnStruct); err != nil {
  584. return err
  585. }
  586. // Validate response against swagger spec
  587. responseValidationInput := &swagger.ResponseValidationInput{
  588. RequestValidationInput: requestValidationInput,
  589. Status: resp.StatusCode,
  590. Header: resp.Header,
  591. }
  592. responseValidationInput = responseValidationInput.SetBodyBytes(body)
  593. return swagger.ValidateResponse(ctx, responseValidationInput)
  594. }
  595. func doBadReq(method, path string, reqBody io.Reader, expectedResponseCode int) error {
  596. ctx := context.Background()
  597. client := &http.Client{}
  598. httpReq, _ := http.NewRequest(method, path, reqBody)
  599. route, pathParams, err := tc.router.FindRoute(httpReq.Method, httpReq.URL)
  600. if err != nil {
  601. return err
  602. }
  603. // Validate request against swagger spec
  604. requestValidationInput := &swagger.RequestValidationInput{
  605. Request: httpReq,
  606. PathParams: pathParams,
  607. Route: route,
  608. }
  609. if err := swagger.ValidateRequest(ctx, requestValidationInput); err != nil {
  610. if expectedResponseCode != 400 {
  611. return err
  612. }
  613. log.Warn("The request does not match the API spec")
  614. }
  615. // Do API call
  616. resp, err := client.Do(httpReq)
  617. if err != nil {
  618. return err
  619. }
  620. if resp.Body == nil {
  621. return errors.New("Nil body")
  622. }
  623. //nolint
  624. defer resp.Body.Close()
  625. body, err := ioutil.ReadAll(resp.Body)
  626. if err != nil {
  627. return err
  628. }
  629. if resp.StatusCode != expectedResponseCode {
  630. return fmt.Errorf("Unexpected response code: %d", resp.StatusCode)
  631. }
  632. // Validate response against swagger spec
  633. responseValidationInput := &swagger.ResponseValidationInput{
  634. RequestValidationInput: requestValidationInput,
  635. Status: resp.StatusCode,
  636. Header: resp.Header,
  637. }
  638. responseValidationInput = responseValidationInput.SetBodyBytes(body)
  639. return swagger.ValidateResponse(ctx, responseValidationInput)
  640. }