You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

308 lines
6.8 KiB

  1. package main
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/parser"
  6. "go/token"
  7. "io/ioutil"
  8. "os"
  9. "path/filepath"
  10. "reflect"
  11. "sort"
  12. "testing"
  13. "text/template"
  14. "github.com/tinylib/msgp/gen"
  15. )
  16. // When stuff's going wrong, you'll be glad this is here!
  17. const debugTemp = false
  18. // Ensure that consistent identifiers are generated on a per-method basis by msgp.
  19. //
  20. // Also ensure that no duplicate identifiers appear in a method.
  21. //
  22. // structs are currently processed alphabetically by msgp. this test relies on
  23. // that property.
  24. //
  25. func TestIssue185Idents(t *testing.T) {
  26. var identCases = []struct {
  27. tpl *template.Template
  28. expectedChanged []string
  29. }{
  30. {tpl: issue185IdentsTpl, expectedChanged: []string{"Test1"}},
  31. {tpl: issue185ComplexIdentsTpl, expectedChanged: []string{"Test2"}},
  32. }
  33. methods := []string{"DecodeMsg", "EncodeMsg", "Msgsize", "MarshalMsg", "UnmarshalMsg"}
  34. for idx, identCase := range identCases {
  35. // generate the code, extract the generated variable names, mapped to function name
  36. var tplData issue185TplData
  37. varsBefore, err := loadVars(identCase.tpl, tplData)
  38. if err != nil {
  39. t.Fatalf("%d: could not extract before vars: %v", idx, err)
  40. }
  41. // regenerate the code with extra field(s), extract the generated variable
  42. // names, mapped to function name
  43. tplData.Extra = true
  44. varsAfter, err := loadVars(identCase.tpl, tplData)
  45. if err != nil {
  46. t.Fatalf("%d: could not extract after vars: %v", idx, err)
  47. }
  48. // ensure that all declared variable names inside each of the methods we
  49. // expect to change have actually changed
  50. for _, stct := range identCase.expectedChanged {
  51. for _, method := range methods {
  52. fn := fmt.Sprintf("%s.%s", stct, method)
  53. bv, av := varsBefore.Value(fn), varsAfter.Value(fn)
  54. if len(bv) > 0 && len(av) > 0 && reflect.DeepEqual(bv, av) {
  55. t.Fatalf("%d vars identical! expected vars to change for %s", idx, fn)
  56. }
  57. delete(varsBefore, fn)
  58. delete(varsAfter, fn)
  59. }
  60. }
  61. // all of the remaining keys should not have changed
  62. for bmethod, bvars := range varsBefore {
  63. avars := varsAfter.Value(bmethod)
  64. if !reflect.DeepEqual(bvars, avars) {
  65. t.Fatalf("%d: vars changed! expected vars identical for %s", idx, bmethod)
  66. }
  67. delete(varsBefore, bmethod)
  68. delete(varsAfter, bmethod)
  69. }
  70. if len(varsBefore) > 0 || len(varsAfter) > 0 {
  71. t.Fatalf("%d: unexpected methods remaining", idx)
  72. }
  73. }
  74. }
  75. type issue185TplData struct {
  76. Extra bool
  77. }
  78. func TestIssue185Overlap(t *testing.T) {
  79. var overlapCases = []struct {
  80. tpl *template.Template
  81. data issue185TplData
  82. }{
  83. {tpl: issue185IdentsTpl, data: issue185TplData{Extra: false}},
  84. {tpl: issue185IdentsTpl, data: issue185TplData{Extra: true}},
  85. {tpl: issue185ComplexIdentsTpl, data: issue185TplData{Extra: false}},
  86. {tpl: issue185ComplexIdentsTpl, data: issue185TplData{Extra: true}},
  87. }
  88. for idx, o := range overlapCases {
  89. // regenerate the code with extra field(s), extract the generated variable
  90. // names, mapped to function name
  91. mvars, err := loadVars(o.tpl, o.data)
  92. if err != nil {
  93. t.Fatalf("%d: could not extract after vars: %v", idx, err)
  94. }
  95. identCnt := 0
  96. for fn, vars := range mvars {
  97. sort.Strings(vars)
  98. // Loose sanity check to make sure the tests expectations aren't broken.
  99. // If the prefix ever changes, this needs to change.
  100. for _, v := range vars {
  101. if v[0] == 'z' {
  102. identCnt++
  103. }
  104. }
  105. for i := 0; i < len(vars)-1; i++ {
  106. if vars[i] == vars[i+1] {
  107. t.Fatalf("%d: duplicate var %s in function %s", idx, vars[i], fn)
  108. }
  109. }
  110. }
  111. // one last sanity check: if there aren't any vars that start with 'z',
  112. // this test's expectations are unsatisfiable.
  113. if identCnt == 0 {
  114. t.Fatalf("%d: no generated identifiers found", idx)
  115. }
  116. }
  117. }
  118. func loadVars(tpl *template.Template, tplData interface{}) (vars extractedVars, err error) {
  119. tempDir, err := ioutil.TempDir("", "msgp-")
  120. if err != nil {
  121. err = fmt.Errorf("could not create temp dir: %v", err)
  122. return
  123. }
  124. if !debugTemp {
  125. defer os.RemoveAll(tempDir)
  126. } else {
  127. fmt.Println(tempDir)
  128. }
  129. tfile := filepath.Join(tempDir, "msg.go")
  130. genFile := newFilename(tfile, "")
  131. if err = goGenerateTpl(tempDir, tfile, tpl, tplData); err != nil {
  132. err = fmt.Errorf("could not generate code: %v", err)
  133. return
  134. }
  135. vars, err = extractVars(genFile)
  136. if err != nil {
  137. err = fmt.Errorf("could not extract after vars: %v", err)
  138. return
  139. }
  140. return
  141. }
  142. type varVisitor struct {
  143. vars []string
  144. fset *token.FileSet
  145. }
  146. func (v *varVisitor) Visit(node ast.Node) (w ast.Visitor) {
  147. gen, ok := node.(*ast.GenDecl)
  148. if !ok {
  149. return v
  150. }
  151. for _, spec := range gen.Specs {
  152. if vspec, ok := spec.(*ast.ValueSpec); ok {
  153. for _, n := range vspec.Names {
  154. v.vars = append(v.vars, n.Name)
  155. }
  156. }
  157. }
  158. return v
  159. }
  160. type extractedVars map[string][]string
  161. func (e extractedVars) Value(key string) []string {
  162. if v, ok := e[key]; ok {
  163. return v
  164. }
  165. panic(fmt.Errorf("unknown key %s", key))
  166. }
  167. func extractVars(file string) (extractedVars, error) {
  168. fset := token.NewFileSet()
  169. f, err := parser.ParseFile(fset, file, nil, 0)
  170. if err != nil {
  171. return nil, err
  172. }
  173. vars := make(map[string][]string)
  174. for _, d := range f.Decls {
  175. switch d := d.(type) {
  176. case *ast.FuncDecl:
  177. sn := ""
  178. switch rt := d.Recv.List[0].Type.(type) {
  179. case *ast.Ident:
  180. sn = rt.Name
  181. case *ast.StarExpr:
  182. sn = rt.X.(*ast.Ident).Name
  183. default:
  184. panic("unknown receiver type")
  185. }
  186. key := fmt.Sprintf("%s.%s", sn, d.Name.Name)
  187. vis := &varVisitor{fset: fset}
  188. ast.Walk(vis, d.Body)
  189. vars[key] = vis.vars
  190. }
  191. }
  192. return vars, nil
  193. }
  194. func goGenerateTpl(cwd, tfile string, tpl *template.Template, tplData interface{}) error {
  195. outf, err := os.OpenFile(tfile, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0600)
  196. if err != nil {
  197. return err
  198. }
  199. defer outf.Close()
  200. if err := tpl.Execute(outf, tplData); err != nil {
  201. return err
  202. }
  203. mode := gen.Encode | gen.Decode | gen.Size | gen.Marshal | gen.Unmarshal
  204. return Run(tfile, mode, false)
  205. }
// issue185IdentsTpl renders a package with two structs, Test1 and Test2.
// When .Extra is true, Test1 gains a `Baz []string` field between Bar and
// Qux; Test2 is identical either way. TestIssue185Idents therefore expects
// only Test1's generated methods to declare different variables.
var issue185IdentsTpl = template.Must(template.New("").Parse(`
package issue185
//go:generate msgp
type Test1 struct {
Foo string
Bar string
{{ if .Extra }}Baz []string{{ end }}
Qux string
}
type Test2 struct {
Foo string
Bar string
Baz string
}
`))
// issue185ComplexIdentsTpl renders a package with three structs. Test2 mixes
// slices, maps, nested maps and several levels of anonymous nested structs.
// When .Extra is true, an `Extra []string` field is added to its innermost
// Quack struct. Test1 and Test3 are identical either way, so
// TestIssue185Idents expects only Test2's generated methods to declare
// different variables.
var issue185ComplexIdentsTpl = template.Must(template.New("").Parse(`
package issue185
//go:generate msgp
type Test1 struct {
Foo string
Bar string
Baz string
}
type Test2 struct {
Foo string
Bar string
Baz []string
Qux map[string]string
Yep map[string]map[string]string
Quack struct {
Quack struct {
Quack struct {
{{ if .Extra }}Extra []string{{ end }}
Quack string
}
}
}
Nup struct {
Foo string
Bar string
Baz []string
Qux map[string]string
Yep map[string]map[string]string
}
Ding struct {
Dong struct {
Dung struct {
Thing string
}
}
}
}
type Test3 struct {
Foo string
Bar string
Baz string
}
`))