You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

109 lines
2.5 KiB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
  1. package main
  2. import (
  3. "fmt"
  4. "sort"
  5. )
  // Neighbour is a candidate nearest neighbour: the distance of one dataset
  // sample from the query input, together with that sample's label.
  type Neighbour struct {
  	Dist  float64 // distance from the query input; smaller is nearer
  	Label string  // label of the dataset sample
  }
  // euclideanDist returns the sum of squared element-wise differences between
  // img1 and img2, i.e. the *squared* Euclidean distance (no square root is
  // taken, which is enough for ranking neighbours). Iteration follows img1's
  // shape; img2 must have at least as many rows and columns wherever img1 has
  // elements, or the indexing panics.
  func euclideanDist(img1, img2 [][]float64) float64 {
  	sum := 0.0
  	for r, row := range img1 {
  		for c, v := range row {
  			diff := v - img2[r][c]
  			sum += diff * diff
  		}
  	}
  	return sum
  }
  19. func isNeighbour(neighbours []Neighbour, dist float64, label string) []Neighbour {
  20. var temp []Neighbour
  21. for i := 0; i < len(neighbours); i++ {
  22. temp = append(temp, neighbours[i])
  23. }
  24. ntemp := Neighbour{dist, label}
  25. temp = append(temp, ntemp)
  26. //now, sort the temp array
  27. sort.Slice(temp, func(i, j int) bool {
  28. return temp[i].Dist < temp[j].Dist
  29. })
  30. for i := 0; i < len(neighbours); i++ {
  31. neighbours[i] = temp[i]
  32. }
  33. return neighbours
  34. }
  35. func getMapKey(dataset map[string]ImgDataset) string {
  36. for k, _ := range dataset {
  37. return k
  38. }
  39. return ""
  40. }
  // LabelCount pairs a label with how many neighbours carry it; averageLabel
  // builds a slice of these to rank labels by frequency.
  type LabelCount struct {
  	Label string // the label being counted
  	Count int    // number of neighbours with this label
  }
  45. func averageLabel(neighbours []Neighbour) string {
  46. labels := make(map[string]int)
  47. for _, n := range neighbours {
  48. labels[n.Label]++
  49. }
  50. //create array from map
  51. var a []LabelCount
  52. for k, v := range labels {
  53. a = append(a, LabelCount{k, v})
  54. }
  55. sort.Slice(a, func(i, j int) bool {
  56. return a[i].Count > a[j].Count
  57. })
  58. fmt.Println(a)
  59. //send the most appeared neighbour in k
  60. return a[0].Label
  61. }
  62. func distNeighboursFromDataset(dataset Dataset, neighbours []Neighbour, input [][]float64) []Neighbour {
  63. //check the complete dataset, checking if each entry is a k nearest neighbour
  64. for l, v := range dataset {
  65. for i := 0; i < len(v); i++ {
  66. dNew := euclideanDist(v[i], input)
  67. neighbours = isNeighbour(neighbours, dNew, l)
  68. }
  69. }
  70. return neighbours
  71. }
  72. func knn(dataset Dataset, input [][]float64) string {
  73. k := 6
  74. var neighbours []Neighbour
  75. var neighboursED []Neighbour
  76. //get a key from map dataset, the key is a label
  77. label := getMapKey(dataset)
  78. //fill the first k neighbours
  79. for i := 0; i < k; i++ {
  80. neighbours = append(neighbours, Neighbour{euclideanDist(dataset[label][0], input), label})
  81. neighboursED = append(neighbours, Neighbour{euclideanDist(dataset[label][0], input), label})
  82. }
  83. neighbours = distNeighboursFromDataset(dataset, neighbours, input)
  84. neighboursED = distNeighboursFromDataset(datasetED, neighbours, input)
  85. neighbours = append(neighbours, neighboursED...)
  86. for i := 0; i < len(neighbours); i++ {
  87. fmt.Print(neighbours[i].Label + " - ")
  88. fmt.Println(neighbours[i].Dist)
  89. }
  90. //from the k nearest neighbours, get the more frequent neighbour
  91. r := averageLabel(neighbours)
  92. return r
  93. }