You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

64 lines
2.0 KiB

  1. from PIL import Image, ImageOps
  2. import numpy, os
  3. from sklearn.feature_extraction import image
  4. from sklearn.model_selection import KFold, cross_val_score
  5. import numpy as np
  6. import pandas as pd
  7. from time import time
  8. import pickle
  9. from sklearn.pipeline import Pipeline
  10. from sklearn.model_selection import GridSearchCV
  11. from sklearn.model_selection import RandomizedSearchCV
  12. path="dataset/"
  13. Xlist=[]
  14. Ylist=[]
  15. size = 100, 100
  16. #load images from dataset
  17. for directory in os.listdir(path):
  18. for file in os.listdir(path+directory):
  19. print(path+directory+"/"+file)
  20. img=Image.open(path+directory+"/"+file)
  21. #resize
  22. thumb = ImageOps.fit(img, size, Image.ANTIALIAS)
  23. image_data = np.array(thumb).flatten()[:100]
  24. Xlist.append(image_data)
  25. Ylist.append(directory)
  26. from sklearn.ensemble import RandomForestClassifier
  27. pipe = Pipeline([
  28. ('clf', RandomForestClassifier()),
  29. ])
  30. param_grid = dict(clf__n_estimators=[100])
  31. grid_search = GridSearchCV(pipe, param_grid=param_grid, n_jobs=-1, verbose=1, cv=3)
  32. # Utility function to report best scores
  33. def report(results, n_top=10):
  34. for i in range(1, n_top + 1):
  35. candidates = np.flatnonzero(results['rank_test_score'] == i)
  36. for candidate in candidates:
  37. print("Model with rank: {0}".format(i))
  38. print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
  39. results['mean_test_score'][candidate],
  40. results['std_test_score'][candidate]))
  41. print("Parameters: {0}".format(results['params'][candidate]))
  42. print("")
  43. start = time()
  44. grid_search = GridSearchCV(pipe, param_grid, n_jobs=-1, verbose=1, cv=3)
  45. grid_search.fit(Xlist, Ylist)
  46. print("GridSearchCV took %.2f seconds for %d candidate parameter settings."
  47. % (time() - start, len(grid_search.cv_results_['params'])))
  48. print("finished GridSearch")
  49. report(grid_search.cv_results_)
  50. pickle.dump(grid_search, open('model.pkl', 'wb'))
  51. print("pipeline model saved to model.pkl")