You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

100 lines
2.0 KiB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
  1. package main
  2. import (
  3. "fmt"
  4. "sort"
  5. )
  // euclideanDist returns the squared Euclidean distance between two images:
  // the sum of squared element-wise differences. No square root is taken, which
  // preserves ordering between distances. The shape of img1 drives iteration,
  // so img2 must be at least as large wherever img1 has elements; any extra
  // rows or columns in img2 are ignored.
  func euclideanDist(img1, img2 [][]float64) float64 {
  	sum := 0.0
  	for r, row := range img1 {
  		for c, px := range row {
  			diff := px - img2[r][c]
  			sum += diff * diff
  		}
  	}
  	return sum
  }
  // Neighbour is one candidate in a k-nearest-neighbour search: the distance
  // from the query image to a training image, and that training image's label.
  type Neighbour struct {
  	Dist  float64
  	Label string
  }
  
  // isNeighbour offers the candidate (dist, label) to a fixed-length list of
  // the best neighbours found so far and returns the updated list. The list
  // is modified in place and keeps its length.
  //
  // If the candidate is strictly closer than the current worst entry, it is
  // inserted in ascending-distance order and the worst entry falls off the
  // end. Otherwise the list is left as it is. A candidate that ties existing
  // entries is placed after them, so earlier entries win ties. An empty list
  // is returned unchanged.
  func isNeighbour(neighbours []Neighbour, dist float64, label string) []Neighbour {
  	n := len(neighbours)
  	if n == 0 {
  		return neighbours
  	}
  	byDist := func(i, j int) bool { return neighbours[i].Dist < neighbours[j].Dist }
  	// Callers that maintain the list through this function always pass it
  	// sorted, so normally this is just a linear check. The stable sort only
  	// runs for an unordered list; it keeps the result equal to "the n smallest
  	// of list+candidate" in every case.
  	if !sort.SliceIsSorted(neighbours, byDist) {
  		sort.SliceStable(neighbours, byDist)
  	}
  	// Insert at the first slot holding a strictly larger distance. That keeps
  	// the list ordered and leaves equal-distance entries ahead of the candidate.
  	pos := sort.Search(n, func(i int) bool { return neighbours[i].Dist > dist })
  	if pos == n {
  		return neighbours
  	}
  	copy(neighbours[pos+1:], neighbours[pos:n-1])
  	neighbours[pos] = Neighbour{Dist: dist, Label: label}
  	return neighbours
  }
  35. func getMapKey(dataset map[string]ImgDataset) string {
  36. for k, _ := range dataset {
  37. return k
  38. }
  39. return ""
  40. }
  41. type LabelCount struct {
  42. Label string
  43. Count int
  44. }
  45. func averageLabel(neighbours []Neighbour) string {
  46. labels := make(map[string]int)
  47. for _, n := range neighbours {
  48. labels[n.Label]++
  49. }
  50. //create array from map
  51. var a []LabelCount
  52. for k, v := range labels {
  53. a = append(a, LabelCount{k, v})
  54. }
  55. sort.Slice(a, func(i, j int) bool {
  56. return a[i].Count > a[j].Count
  57. })
  58. fmt.Println(a)
  59. //send the most appeared neighbour in k
  60. return a[0].Label
  61. }
  62. func knn(dataset Dataset, input [][]float64) string {
  63. k := 10
  64. var neighbours []Neighbour
  65. label := getMapKey(dataset)
  66. for i := 0; i < k; i++ {
  67. /*neighbours[i].Dist = euclideanDist(dataset["leopard"][0], input)
  68. neighbours[i].Label = "leopard"*/
  69. neighbours = append(neighbours, Neighbour{euclideanDist(dataset[label][0], input), label})
  70. }
  71. for l, v := range dataset {
  72. for i := 0; i < len(v); i++ {
  73. dNew := euclideanDist(v[i], input)
  74. /*if dNew < d {
  75. d = dNew
  76. label = l
  77. }*/
  78. neighbours = isNeighbour(neighbours, dNew, l)
  79. }
  80. }
  81. for i := 0; i < len(neighbours); i++ {
  82. fmt.Print(neighbours[i].Label + " - ")
  83. fmt.Println(neighbours[i].Dist)
  84. }
  85. r := averageLabel(neighbours)
  86. return r
  87. }