{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/asus/anaconda3/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
" \"This module will be removed in 0.20.\", DeprecationWarning)\n"
]
}
],
"source": [
"from PIL import Image, ImageOps\n",
"import numpy, os\n",
"from sklearn.ensemble import AdaBoostClassifier\n",
"from sklearn.cross_validation import cross_val_score  # deprecated since sklearn 0.18 (see warning); sklearn.model_selection provides the same function\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"path=\"dataset/\"\n",
"Xlist=[]\n",
"Ylist=[]\n",
"size = 100, 100"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
  49. "dataset/object/27.png\n",
  50. "dataset/object/82.png\n",
  51. "dataset/object/83.png\n",
  52. "dataset/object/100.png\n",
  53. "dataset/object/0.png\n",
  54. "dataset/object/13.png\n",
  55. "dataset/object/45.png\n",
  56. "dataset/object/64.png\n",
  57. "dataset/object/19.png\n",
  58. "dataset/object/101.png\n",
  59. "dataset/object/40.png\n",
  60. "dataset/object/97.png\n",
  61. "dataset/object/41.png\n",
  62. "dataset/object/7.png\n",
  63. "dataset/object/66.png\n",
  64. "dataset/object/55.png\n",
  65. "dataset/object/56.png\n",
  66. "dataset/object/65.png\n",
  67. "dataset/object/18.png\n",
  68. "dataset/object/24.png\n",
  69. "dataset/object/105.png\n",
  70. "dataset/object/116.png\n",
  71. "dataset/object/117.png\n",
  72. "dataset/object/104.png\n",
  73. "dataset/object/63.png\n",
  74. "dataset/object/38.png\n",
  75. "dataset/object/58.png\n",
  76. "dataset/object/103.png\n",
  77. "dataset/object/112.png\n",
  78. "dataset/object/33.png\n",
  79. "dataset/object/76.png\n",
  80. "dataset/object/59.png\n",
  81. "dataset/object/96.png\n",
  82. "dataset/object/91.png\n",
  83. "dataset/object/57.png\n",
  84. "dataset/object/2.png\n",
  85. "dataset/object/75.png\n",
  86. "dataset/object/107.png\n",
  87. "dataset/object/50.png\n",
  88. "dataset/object/16.png\n",
  89. "dataset/object/32.png\n",
  90. "dataset/object/15.png\n",
  91. "dataset/object/5.png\n",
  92. "dataset/object/72.png\n",
  93. "dataset/object/52.png\n",
  94. "dataset/object/4.png\n",
  95. "dataset/object/28.png\n",
  96. "dataset/object/43.png\n",
  97. "dataset/object/87.png\n",
  98. "dataset/object/98.png\n",
  99. "dataset/object/71.png\n",
  100. "dataset/object/102.png\n",
  101. "dataset/object/62.png\n",
  102. "dataset/object/9.png\n",
  103. "dataset/object/6.png\n",
  104. "dataset/object/85.png\n",
  105. "dataset/object/70.png\n",
  106. "dataset/object/42.png\n",
  107. "dataset/object/34.png\n",
  108. "dataset/object/81.png\n",
  109. "dataset/object/94.png\n",
  110. "dataset/object/26.png\n",
  111. "dataset/object/90.png\n",
  112. "dataset/object/44.png\n",
  113. "dataset/object/60.png\n",
  114. "dataset/object/17.png\n",
  115. "dataset/object/10.png\n",
  116. "dataset/object/53.png\n",
  117. "dataset/object/25.png\n",
  118. "dataset/object/21.png\n",
  119. "dataset/object/22.png\n",
  120. "dataset/object/30.png\n",
  121. "dataset/object/78.png\n",
  122. "dataset/object/118.png\n",
  123. "dataset/object/110.png\n",
  124. "dataset/object/79.png\n",
  125. "dataset/object/77.png\n",
  126. "dataset/object/12.png\n",
  127. "dataset/object/115.png\n",
  128. "dataset/object/67.png\n",
  129. "dataset/object/84.png\n",
  130. "dataset/object/11.png\n",
  131. "dataset/object/86.png\n",
  132. "dataset/object/89.png\n",
  133. "dataset/object/113.png\n",
  134. "dataset/noobject/image_0056.jpg\n",
  135. "dataset/noobject/image_0181.jpg\n",
  136. "dataset/noobject/image_0127.jpg\n",
  137. "dataset/noobject/image_0142.jpg\n",
  138. "dataset/noobject/image_0025.jpg\n",
  139. "dataset/noobject/image_0065.jpg\n",
  140. "dataset/noobject/image_0174.jpg\n",
  141. "dataset/noobject/image_0091.jpg\n",
  142. "dataset/noobject/image_0124.jpg\n",
  143. "dataset/noobject/image_0086.jpg\n",
  144. "dataset/noobject/image_0079.jpg\n",
  145. "dataset/noobject/image_0058.jpg\n",
  146. "dataset/noobject/image_0060.jpg\n",
  147. "dataset/noobject/image_0119.jpg\n",
  148. "dataset/noobject/image_0023.jpg\n",
  149. "dataset/noobject/image_0075.jpg\n",
  150. "dataset/noobject/image_0020.jpg\n",
  151. "dataset/noobject/image_0013.jpg\n",
  152. "dataset/noobject/image_0126.jpg\n",
  153. "dataset/noobject/image_0012.jpg\n",
  154. "dataset/noobject/image_0055.jpg\n",
  155. "dataset/noobject/image_0176.jpg\n",
  156. "dataset/noobject/image_0144.jpg\n",
  157. "dataset/noobject/image_0048.jpg\n",
  158. "dataset/noobject/image_0121.jpg\n",
  159. "dataset/noobject/image_0070.jpg\n",
  160. "dataset/noobject/image_0082.jpg\n",
  161. "dataset/noobject/image_0095.jpg\n",
  162. "dataset/noobject/image_0022.jpg\n",
  163. "dataset/noobject/image_0120.jpg\n",
  164. "dataset/noobject/image_0139.jpg\n",
  165. "dataset/noobject/image_0073.jpg\n",
  166. "dataset/noobject/image_0090.jpg\n",
  167. "dataset/noobject/image_0145.jpg\n",
  168. "dataset/noobject/image_0173.jpg\n",
  169. "dataset/noobject/image_0078.jpg\n",
  170. "dataset/noobject/image_0085.jpg\n",
  171. "dataset/noobject/image_0083.jpg\n",
  172. "dataset/noobject/image_0179.jpg\n",
  173. "dataset/noobject/image_0050.jpg\n",
  174. "dataset/noobject/image_0076.jpg\n",
  175. "dataset/noobject/image_0014.jpg\n",
  176. "dataset/noobject/image_0054.jpg\n",
  177. "dataset/noobject/image_0066.jpg\n",
  178. "dataset/noobject/image_0001.jpg\n",
  179. "dataset/noobject/image_0047.jpg\n",
  180. "dataset/noobject/image_0077.jpg\n",
  181. "dataset/noobject/image_0122.jpg\n",
  182. "dataset/noobject/image_0068.jpg\n",
  183. "dataset/noobject/image_0049.jpg\n",
  184. "dataset/noobject/image_0092.jpg\n",
  185. "dataset/noobject/image_0138.jpg\n",
  186. "dataset/noobject/image_0072.jpg\n",
  187. "dataset/noobject/image_0146.jpg\n",
  188. "dataset/noobject/image_0061.jpg\n",
  189. "dataset/noobject/image_0011.jpg\n",
  190. "dataset/noobject/image_0002.jpg\n",
  191. "dataset/noobject/image_0143.jpg\n",
  192. "dataset/noobject/image_0088.jpg\n",
  193. "dataset/noobject/image_0062.jpg\n",
  194. "dataset/noobject/image_0089.jpg\n",
  195. "dataset/noobject/image_0018.jpg\n",
  196. "dataset/noobject/image_0024.jpg\n",
  197. "dataset/noobject/image_0064.jpg\n",
  198. "dataset/noobject/image_0074.jpg\n",
  199. "dataset/noobject/image_0052.jpg\n",
  200. "dataset/noobject/image_0096.jpg\n",
  201. "dataset/noobject/image_0178.jpg\n",
  202. "dataset/noobject/image_0067.jpg\n",
  203. "dataset/noobject/image_0140.jpg\n",
  204. "dataset/noobject/image_0084.jpg\n",
  205. "dataset/noobject/image_0010.jpg\n",
  206. "dataset/noobject/image_0081.jpg\n",
  207. "dataset/noobject/image_0059.jpg\n",
  208. "dataset/noobject/image_0016.jpg\n",
  209. "dataset/noobject/image_0175.jpg\n",
  210. "dataset/noobject/image_0094.jpg\n",
  211. "dataset/noobject/image_0071.jpg\n",
  212. "dataset/noobject/image_0080.jpg\n",
  213. "dataset/noobject/image_0125.jpg\n",
  214. "dataset/noobject/image_0008.jpg\n",
  215. "dataset/noobject/image_0019.jpg\n",
  216. "dataset/noobject/image_0017.jpg\n",
  217. "dataset/noobject/image_0180.jpg\n"
]
}
],
"source": [
"for directory in os.listdir(path):\n",
"    for file in os.listdir(path+directory):\n",
"        print(path+directory+\"/\"+file)\n",
"        img=Image.open(path+directory+\"/\"+file)\n",
"        #resize\n",
"        thumb = ImageOps.fit(img, size, Image.ANTIALIAS)\n",
"        image_data = np.array(thumb).flatten()[:100]\n",
"        #image_data=numpy.array(img).flatten()[:50] #in my case the images don't have the same dimensions, so [:50] only takes the first 50 values\n",
"        Xlist.append(image_data)\n",
"        Ylist.append(directory)"
]
},
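{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an added sketch (not part of the original run): a quick sanity check that `Xlist` holds one flattened 100-value vector per image and that `Ylist` holds the matching directory label, using the `numpy` and `pandas` imports from the first cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Added sketch, not executed in the original run: inspect the loaded data\n",
"print(np.array(Xlist).shape)              # expected (n_images, 100)\n",
"print(pd.Series(Ylist).value_counts())    # images per class ('object' vs. 'noobject')"
]
},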
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(Xlist, Ylist, test_size=0.2)"
]
},
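{
"cell_type": "markdown",
"metadata": {},
"source": [
"An added aside (a sketch, not executed here): with two classes of slightly different sizes, passing `stratify=Ylist` to `train_test_split` keeps the class ratio the same in the train and test splits. Distinct names are used so the split from the previous cell is not overwritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Added sketch, not executed in the original run: a stratified variant of the split above\n",
"Xs_train, Xs_test, ys_train, ys_test = train_test_split(\n",
"    Xlist, Ylist, test_size=0.2, stratify=Ylist, random_state=0)"
]
},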
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### AdaBoostClassifier"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf=AdaBoostClassifier(n_estimators=100)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, X_train, y_train, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.77037037037\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### GaussianNB"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.naive_bayes import GaussianNB"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = GaussianNB()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.721908939014\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### KNeighborsClassifier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = KNeighborsClassifier(n_neighbors=10)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.751357560568\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LinearSVC"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVC"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = LinearSVC()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.638575605681\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVC"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = SVC()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.668650793651\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### GaussianProcessClassifier"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.gaussian_process import GaussianProcessClassifier"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = GaussianProcessClassifier()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.491228070175\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RandomForestClassifier"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = RandomForestClassifier()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"scores = cross_val_score(clf, Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.710317460317\n"
]
}
],
"source": [
"print(scores.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
  638. "# Hyperparameters Tuning using sklearn pipeline and gridsearch"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.model_selection import RandomizedSearchCV"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"pipe = [Pipeline([\n",
"        ('clf', RandomForestClassifier()),\n",
"    ]),\n",
"    Pipeline([\n",
"        ('clf', KNeighborsClassifier()),\n",
"    ]),\n",
"    Pipeline([\n",
"        ('clf', GaussianProcessClassifier()),\n",
"    ]),\n",
"    Pipeline([\n",
"        ('clf', AdaBoostClassifier()),\n",
"    ]),\n",
"    Pipeline([\n",
"        ('clf', SVC()),\n",
"    ]),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"param_grid = [dict(clf__n_estimators=[3, 10, 100]),\n",
"              dict(clf__n_neighbors=[3,10]),\n",
"              dict(clf__n_restarts_optimizer=[0,1]),\n",
"              dict(clf__n_estimators=[3, 10, 100]),\n",
"              dict(clf__C=[3, 10, 100]),\n",
"              ]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Note: GridSearchCV expects a single estimator, not the whole pipe list;\n",
"# the working per-pipeline searches are built in the loop further below.\n",
"grid_search = GridSearchCV(pipe, param_grid=param_grid, n_jobs=-1, verbose=1, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"#grid_search.fit(Xlist, Ylist)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
  726. "# Utility function to report best scores\n",
  727. "def report(results, n_top=10):\n",
  728. " for i in range(1, n_top + 1):\n",
  729. " candidates = np.flatnonzero(results['rank_test_score'] == i)\n",
  730. " for candidate in candidates:\n",
  731. " print(\"Model with rank: {0}\".format(i))\n",
  732. " print(\"Mean validation score: {0:.3f} (std: {1:.3f})\".format(\n",
  733. " results['mean_test_score'][candidate],\n",
  734. " results['std_test_score'][candidate]))\n",
  735. " print(\"Parameters: {0}\".format(results['params'][candidate]))\n",
  736. " print(\"\")"
  737. ]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-----\n",
"classifier:\n",
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
" max_depth=None, max_features='auto', max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n",
" oob_score=False, random_state=None, verbose=0,\n",
" warm_start=False)\n",
"Fitting 3 folds for each of 3 candidates, totalling 9 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Done 9 out of 9 | elapsed: 1.5s finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"GridSearchCV took 2.38 seconds for 3 candidate parameter settings.\n",
"finished GridSearch\n",
"Model with rank: 1\n",
"Mean validation score: 0.815 (std: 0.073)\n",
"Parameters: {'clf__n_estimators': 100}\n",
"\n",
"Model with rank: 2\n",
"Mean validation score: 0.763 (std: 0.093)\n",
"Parameters: {'clf__n_estimators': 10}\n",
"\n",
"Model with rank: 3\n",
"Mean validation score: 0.756 (std: 0.110)\n",
"Parameters: {'clf__n_estimators': 3}\n",
"\n",
"-----\n",
"classifier:\n",
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n",
" weights='uniform')\n",
"Fitting 3 folds for each of 2 candidates, totalling 6 fits\n",
"GridSearchCV took 0.23 seconds for 2 candidate parameter settings.\n",
"finished GridSearch\n",
"Model with rank: 1\n",
"Mean validation score: 0.778 (std: 0.048)\n",
"Parameters: {'clf__n_neighbors': 3}\n",
"\n",
"Model with rank: 2\n",
"Mean validation score: 0.704 (std: 0.010)\n",
"Parameters: {'clf__n_neighbors': 10}\n",
"\n",
"-----\n",
"classifier:\n",
"GaussianProcessClassifier(copy_X_train=True, kernel=None,\n",
" max_iter_predict=100, multi_class='one_vs_rest', n_jobs=1,\n",
" n_restarts_optimizer=0, optimizer='fmin_l_bfgs_b',\n",
" random_state=None, warm_start=False)\n",
"Fitting 3 folds for each of 2 candidates, totalling 6 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Done 6 out of 6 | elapsed: 0.1s remaining: 0.0s\n",
"[Parallel(n_jobs=-1)]: Done 6 out of 6 | elapsed: 0.1s finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"GridSearchCV took 0.36 seconds for 2 candidate parameter settings.\n",
"finished GridSearch\n",
"Model with rank: 1\n",
"Mean validation score: 0.489 (std: 0.000)\n",
"Parameters: {'clf__n_restarts_optimizer': 0}\n",
"\n",
"Model with rank: 1\n",
"Mean validation score: 0.489 (std: 0.000)\n",
"Parameters: {'clf__n_restarts_optimizer': 1}\n",
"\n",
"-----\n",
"classifier:\n",
"AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n",
" learning_rate=1.0, n_estimators=50, random_state=None)\n",
"Fitting 3 folds for each of 3 candidates, totalling 9 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Done 6 out of 6 | elapsed: 0.2s remaining: 0.0s\n",
"[Parallel(n_jobs=-1)]: Done 6 out of 6 | elapsed: 0.2s finished\n",
"[Parallel(n_jobs=-1)]: Done 9 out of 9 | elapsed: 0.9s finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"GridSearchCV took 1.16 seconds for 3 candidate parameter settings.\n",
"finished GridSearch\n",
"Model with rank: 1\n",
"Mean validation score: 0.807 (std: 0.093)\n",
"Parameters: {'clf__n_estimators': 3}\n",
"\n",
"Model with rank: 2\n",
"Mean validation score: 0.756 (std: 0.048)\n",
"Parameters: {'clf__n_estimators': 100}\n",
"\n",
"Model with rank: 3\n",
"Mean validation score: 0.733 (std: 0.054)\n",
"Parameters: {'clf__n_estimators': 10}\n",
"\n",
"-----\n",
"classifier:\n",
"SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
" decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',\n",
" max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False)\n",
"Fitting 3 folds for each of 3 candidates, totalling 9 fits\n",
"GridSearchCV took 0.35 seconds for 3 candidate parameter settings.\n",
"finished GridSearch\n",
"Model with rank: 1\n",
"Mean validation score: 0.689 (std: 0.031)\n",
"Parameters: {'clf__C': 3}\n",
"\n",
"Model with rank: 1\n",
"Mean validation score: 0.689 (std: 0.031)\n",
"Parameters: {'clf__C': 10}\n",
"\n",
"Model with rank: 1\n",
"Mean validation score: 0.689 (std: 0.031)\n",
"Parameters: {'clf__C': 100}\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Done 9 out of 9 | elapsed: 0.1s finished\n"
]
}
],
"source": [
  900. "from time import time\n",
  901. "\n",
  902. "for i in range(len(pipe)):\n",
  903. " start = time()\n",
  904. " print(\"-----\")\n",
  905. " print(\"classifier:\")\n",
  906. " print(pipe[i].named_steps['clf'])\n",
  907. " grid_search = GridSearchCV(pipe[i], param_grid[i], n_jobs=-1, verbose=1, cv=3)\n",
  908. " grid_search.fit(X_train, y_train)\n",
  909. " print(\"GridSearchCV took %.2f seconds for %d candidate parameter settings.\"\n",
  910. " % (time() - start, len(grid_search.cv_results_['params'])))\n",
  911. " print(\"finished GridSearch\")\n",
  912. " report(grid_search.cv_results_)"
]
},
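{
"cell_type": "markdown",
"metadata": {},
"source": [
"`RandomizedSearchCV` is imported above but never used. The cell below is an added sketch (not executed here) of how it could stand in for `GridSearchCV` once the parameter space grows: it samples a fixed number of candidates from a distribution instead of trying every combination. The parameter range is an assumption for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Added sketch, not executed in the original run: randomized search over the RandomForest pipeline\n",
"from scipy.stats import randint\n",
"\n",
"random_search = RandomizedSearchCV(pipe[0],\n",
"                                   param_distributions={'clf__n_estimators': randint(3, 200)},\n",
"                                   n_iter=10, n_jobs=-1, verbose=1, cv=3)\n",
"#random_search.fit(X_train, y_train)\n",
"#report(random_search.cv_results_)"
]
},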
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}