You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

129 lines
3.2 KiB

  1. package printer
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "strings"
  8. "github.com/tinylib/msgp/gen"
  9. "github.com/tinylib/msgp/parse"
  10. "github.com/ttacon/chalk"
  11. "golang.org/x/tools/imports"
  12. )
  13. func infof(s string, v ...interface{}) {
  14. fmt.Printf(chalk.Magenta.Color(s), v...)
  15. }
  16. // PrintFile prints the methods for the provided list
  17. // of elements to the given file name and canonical
  18. // package path.
  19. func PrintFile(file string, f *parse.FileSet, mode gen.Method) error {
  20. out, tests, err := generate(f, mode)
  21. if err != nil {
  22. return err
  23. }
  24. // we'll run goimports on the main file
  25. // in another goroutine, and run it here
  26. // for the test file. empirically, this
  27. // takes about the same amount of time as
  28. // doing them in serial when GOMAXPROCS=1,
  29. // and faster otherwise.
  30. res := goformat(file, out.Bytes())
  31. if tests != nil {
  32. testfile := strings.TrimSuffix(file, ".go") + "_test.go"
  33. err = format(testfile, tests.Bytes())
  34. if err != nil {
  35. return err
  36. }
  37. infof(">>> Wrote and formatted \"%s\"\n", testfile)
  38. }
  39. err = <-res
  40. if err != nil {
  41. return err
  42. }
  43. return nil
  44. }
  45. func format(file string, data []byte) error {
  46. out, err := imports.Process(file, data, nil)
  47. if err != nil {
  48. return err
  49. }
  50. return ioutil.WriteFile(file, out, 0600)
  51. }
  52. func goformat(file string, data []byte) <-chan error {
  53. out := make(chan error, 1)
  54. go func(file string, data []byte, end chan error) {
  55. end <- format(file, data)
  56. infof(">>> Wrote and formatted \"%s\"\n", file)
  57. }(file, data, out)
  58. return out
  59. }
  60. func dedupImports(imp []string) []string {
  61. m := make(map[string]struct{})
  62. for i := range imp {
  63. m[imp[i]] = struct{}{}
  64. }
  65. r := []string{}
  66. for k := range m {
  67. r = append(r, k)
  68. }
  69. return r
  70. }
  71. func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) {
  72. outbuf := bytes.NewBuffer(make([]byte, 0, 4096))
  73. writePkgHeader(outbuf, f.Package)
  74. myImports := []string{"github.com/tinylib/msgp/msgp"}
  75. for _, imp := range f.Imports {
  76. if imp.Name != nil {
  77. // have an alias, include it.
  78. myImports = append(myImports, imp.Name.Name+` `+imp.Path.Value)
  79. } else {
  80. myImports = append(myImports, imp.Path.Value)
  81. }
  82. }
  83. dedup := dedupImports(myImports)
  84. writeImportHeader(outbuf, dedup...)
  85. var testbuf *bytes.Buffer
  86. var testwr io.Writer
  87. if mode&gen.Test == gen.Test {
  88. testbuf = bytes.NewBuffer(make([]byte, 0, 4096))
  89. writePkgHeader(testbuf, f.Package)
  90. if mode&(gen.Encode|gen.Decode) != 0 {
  91. writeImportHeader(testbuf, "bytes", "github.com/tinylib/msgp/msgp", "testing")
  92. } else {
  93. writeImportHeader(testbuf, "github.com/tinylib/msgp/msgp", "testing")
  94. }
  95. testwr = testbuf
  96. }
  97. return outbuf, testbuf, f.PrintTo(gen.NewPrinter(mode, outbuf, testwr))
  98. }
  99. func writePkgHeader(b *bytes.Buffer, name string) {
  100. b.WriteString("package ")
  101. b.WriteString(name)
  102. b.WriteByte('\n')
  103. b.WriteString("// NOTE: THIS FILE WAS PRODUCED BY THE\n// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp)\n// DO NOT EDIT\n\n")
  104. }
  105. func writeImportHeader(b *bytes.Buffer, imports ...string) {
  106. b.WriteString("import (\n")
  107. for _, im := range imports {
  108. if im[len(im)-1] == '"' {
  109. // support aliased imports
  110. fmt.Fprintf(b, "\t%s\n", im)
  111. } else {
  112. fmt.Fprintf(b, "\t%q\n", im)
  113. }
  114. }
  115. b.WriteString(")\n\n")
  116. }