You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

62 lines
2.0 KiB

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from random import randint
  4. import pickle
  5. from sknn.mlp import Classifier, Layer, Convolution
  6. def datasetToTrainAndTestData(dataset, numtest):
  7. np.random.shuffle(dataset)
  8. print "length total data:" + str(len(dataset))
  9. traindata = np.copy(dataset)
  10. testdata = []
  11. for i in range(numtest):
  12. #get random integer between 0 and the total amount of images in the dataset
  13. n = randint(0, len(traindata))
  14. testdata.append(dataset[n])
  15. #delete the n image (dataset[n]) of the traindata
  16. traindata = np.delete(traindata, n, axis=0)
  17. testdataNP = np.array(testdata)
  18. return traindata, testdataNP
  19. #read the dataset made with the 'imagesToDataset' repository
  20. dataset = np.load('dataset.npy')
  21. traindata, testdata = datasetToTrainAndTestData(dataset, 10)
  22. print "length traindata: " + str(len(traindata))
  23. print "length testdata: " + str(len(testdata))
  24. #traindataAttributes contains all the pixels of each image
  25. traindataAttributes = traindata[:,0]
  26. traindataAttributes = np.array([[row] for row in traindataAttributes])
  27. #traindataLabels contains each label of each image
  28. traindataLabels = traindata[:,1]
  29. traindataLabels = traindataLabels.astype('int')
  30. #testdataAttributes contains the pixels of the test images
  31. testdataAttributes = testdata[:,0]
  32. testdataAttributes = np.array([[row] for row in testdataAttributes])
  33. #testdataLabels contains each label of each image
  34. testdataLabels = testdata[:,1]
  35. testdataLabels = testdataLabels.astype('int')
  36. #default: units=100, learning_rate=0.001, n_iter=25
  37. nn = Classifier(
  38. layers=[
  39. Layer("Sigmoid", units=10),
  40. Layer("Softmax")],
  41. learning_rate=0.001,
  42. n_iter=20,
  43. verbose=True)
  44. nn.fit(traindataAttributes, traindataLabels)
  45. print('\nTRAIN SCORE', nn.score(traindataAttributes, traindataLabels))
  46. print('TEST SCORE', nn.score(testdataAttributes, testdataLabels))
  47. #save the neural network configuration
  48. pickle.dump(nn, open('nn.pkl', 'wb'))