You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

383 lines
8.1 KiB

  1. package gen
  import (
	"fmt"
	"io"
	"strings"
)
  // Code snippets and fmt verbs spliced into generated source.
const (
	errcheck    = "\nif err != nil { return }" // generated early return on a non-nil err
	lenAsUint32 = "uint32(len(%s))"             // wraps an expression as uint32(len(expr))
	literalFmt  = "%s"
	intFmt      = "%d"
	quotedFmt   = `"%s"`
)

// Name fragments; presumably combined with msgp read/write method
// prefixes by the generators (TODO confirm at the call sites).
const (
	mapHeader   = "MapHeader"
	arrayHeader = "ArrayHeader"
	mapKey      = "MapKeyPtr"
	stringTyp   = "String"
	u32         = "uint32"
)
  // Method is a bitfield representing something that the
// generator knows how to print. Each single-bit constant below selects
// one kind of output, and constants may be OR'd together.
type Method uint8

// isset reports whether every bit in f is also set in m.
// Note that m.isset(0) is always true.
func (m Method) isset(f Method) bool { return m&f == f }

// String implements fmt.Stringer.
//
// A single method prints as its name (e.g. "decode"). A combination
// prints as the names of its set bits joined with "+", in the order
// decode, encode, marshal, unmarshal, size, test (e.g.
// "decode+encode+test"). Zero, invalidmeth, and any value with none of
// those bits set print as "<invalid method>".
func (m Method) String() string {
	switch m {
	case 0, invalidmeth:
		return "<invalid method>"
	case Decode:
		return "decode"
	case Encode:
		return "encode"
	case Marshal:
		return "marshal"
	case Unmarshal:
		return "unmarshal"
	case Size:
		return "size"
	case Test:
		return "test"
	}

	// Combination of bits: join the name of each recognized mode.
	var names []string
	for _, mm := range [...]Method{Decode, Encode, Marshal, Unmarshal, Size, Test} {
		if m.isset(mm) {
			names = append(names, mm.String())
		}
	}
	if len(names) == 0 {
		// Only unrecognized bits are set (e.g. Method(128)); report it as
		// invalid rather than returning an empty string.
		return "<invalid method>"
	}
	return strings.Join(names, "+")
}

// strtoMeth parses a single method name: "encode", "decode", "marshal",
// "unmarshal", "size" or "test". Any other string, including combined
// forms such as "decode+encode", yields 0.
func strtoMeth(s string) Method {
	switch s {
	case "encode":
		return Encode
	case "decode":
		return Decode
	case "marshal":
		return Marshal
	case "unmarshal":
		return Unmarshal
	case "size":
		return Size
	case "test":
		return Test
	default:
		return 0
	}
}

// Method bits, one per generated interface, plus the combinations that
// enable each kind of generated test.
const (
	Decode      Method = 1 << iota              // msgp.Decodable
	Encode                                      // msgp.Encodable
	Marshal                                     // msgp.Marshaler
	Unmarshal                                   // msgp.Unmarshaler
	Size                                        // msgp.Sizer
	Test                                        // generate tests
	invalidmeth                                 // this isn't a method
	encodetest  = Encode | Decode | Test        // tests for Encodable and Decodable
	marshaltest = Marshal | Unmarshal | Test    // tests for Marshaler and Unmarshaler
)
  87. type Printer struct {
  88. gens []generator
  89. }
  90. func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
  91. if m.isset(Test) && tests == nil {
  92. panic("cannot print tests with 'nil' tests argument!")
  93. }
  94. gens := make([]generator, 0, 7)
  95. if m.isset(Decode) {
  96. gens = append(gens, decode(out))
  97. }
  98. if m.isset(Encode) {
  99. gens = append(gens, encode(out))
  100. }
  101. if m.isset(Marshal) {
  102. gens = append(gens, marshal(out))
  103. }
  104. if m.isset(Unmarshal) {
  105. gens = append(gens, unmarshal(out))
  106. }
  107. if m.isset(Size) {
  108. gens = append(gens, sizes(out))
  109. }
  110. if m.isset(marshaltest) {
  111. gens = append(gens, mtest(tests))
  112. }
  113. if m.isset(encodetest) {
  114. gens = append(gens, etest(tests))
  115. }
  116. if len(gens) == 0 {
  117. panic("NewPrinter called with invalid method flags")
  118. }
  119. return &Printer{gens: gens}
  120. }
  121. // TransformPass is a pass that transforms individual
  122. // elements. (Note that if the returned is different from
  123. // the argument, it should not point to the same objects.)
  124. type TransformPass func(Elem) Elem
  125. // IgnoreTypename is a pass that just ignores
  126. // types of a given name.
  127. func IgnoreTypename(name string) TransformPass {
  128. return func(e Elem) Elem {
  129. if e.TypeName() == name {
  130. return nil
  131. }
  132. return e
  133. }
  134. }
  135. // ApplyDirective applies a directive to a named pass
  136. // and all of its dependents.
  137. func (p *Printer) ApplyDirective(pass Method, t TransformPass) {
  138. for _, g := range p.gens {
  139. if g.Method().isset(pass) {
  140. g.Add(t)
  141. }
  142. }
  143. }
  144. // Print prints an Elem.
  145. func (p *Printer) Print(e Elem) error {
  146. for _, g := range p.gens {
  147. // Elem.SetVarname() is called before the Print() step in parse.FileSet.PrintTo().
  148. // Elem.SetVarname() generates identifiers as it walks the Elem. This can cause
  149. // collisions between idents created during SetVarname and idents created during Print,
  150. // hence the separate prefixes.
  151. resetIdent("zb")
  152. err := g.Execute(e)
  153. resetIdent("za")
  154. if err != nil {
  155. return err
  156. }
  157. }
  158. return nil
  159. }
  160. // generator is the interface through
  161. // which code is generated.
  162. type generator interface {
  163. Method() Method
  164. Add(p TransformPass)
  165. Execute(Elem) error // execute writes the method for the provided object.
  166. }
  167. type passes []TransformPass
  168. func (p *passes) Add(t TransformPass) {
  169. *p = append(*p, t)
  170. }
  171. func (p *passes) applyall(e Elem) Elem {
  172. for _, t := range *p {
  173. e = t(e)
  174. if e == nil {
  175. return nil
  176. }
  177. }
  178. return e
  179. }
  180. type traversal interface {
  181. gMap(*Map)
  182. gSlice(*Slice)
  183. gArray(*Array)
  184. gPtr(*Ptr)
  185. gBase(*BaseElem)
  186. gStruct(*Struct)
  187. }
  188. // type-switch dispatch to the correct
  189. // method given the type of 'e'
  190. func next(t traversal, e Elem) {
  191. switch e := e.(type) {
  192. case *Map:
  193. t.gMap(e)
  194. case *Struct:
  195. t.gStruct(e)
  196. case *Slice:
  197. t.gSlice(e)
  198. case *Array:
  199. t.gArray(e)
  200. case *Ptr:
  201. t.gPtr(e)
  202. case *BaseElem:
  203. t.gBase(e)
  204. default:
  205. panic("bad element type")
  206. }
  207. }
  208. // possibly-immutable method receiver
  209. func imutMethodReceiver(p Elem) string {
  210. switch e := p.(type) {
  211. case *Struct:
  212. // TODO(HACK): actually do real math here.
  213. if len(e.Fields) <= 3 {
  214. for i := range e.Fields {
  215. if be, ok := e.Fields[i].FieldElem.(*BaseElem); !ok || (be.Value == IDENT || be.Value == Bytes) {
  216. goto nope
  217. }
  218. }
  219. return p.TypeName()
  220. }
  221. nope:
  222. return "*" + p.TypeName()
  223. // gets dereferenced automatically
  224. case *Array:
  225. return "*" + p.TypeName()
  226. // everything else can be
  227. // by-value.
  228. default:
  229. return p.TypeName()
  230. }
  231. }
  232. // if necessary, wraps a type
  233. // so that its method receiver
  234. // is of the write type.
  235. func methodReceiver(p Elem) string {
  236. switch p.(type) {
  237. // structs and arrays are
  238. // dereferenced automatically,
  239. // so no need to alter varname
  240. case *Struct, *Array:
  241. return "*" + p.TypeName()
  242. // set variable name to
  243. // *varname
  244. default:
  245. p.SetVarname("(*" + p.Varname() + ")")
  246. return "*" + p.TypeName()
  247. }
  248. }
  249. func unsetReceiver(p Elem) {
  250. switch p.(type) {
  251. case *Struct, *Array:
  252. default:
  253. p.SetVarname("z")
  254. }
  255. }
  // printer is the output helper shared by the generators. Its printf and
// print methods keep the first write error in err and do nothing once it
// is set.
type printer struct {
	w   io.Writer // destination for generated code
	err error     // first error returned while writing to w, if any
}
  261. // writes "var {{name}} {{typ}};"
  262. func (p *printer) declare(name string, typ string) {
  263. p.printf("\nvar %s %s", name, typ)
  264. }
  265. // does:
  266. //
  267. // if m != nil && size > 0 {
  268. // m = make(type, size)
  269. // } else if len(m) > 0 {
  270. // for key, _ := range m { delete(m, key) }
  271. // }
  272. //
  273. func (p *printer) resizeMap(size string, m *Map) {
  274. vn := m.Varname()
  275. if !p.ok() {
  276. return
  277. }
  278. p.printf("\nif %s == nil && %s > 0 {", vn, size)
  279. p.printf("\n%s = make(%s, %s)", vn, m.TypeName(), size)
  280. p.printf("\n} else if len(%s) > 0 {", vn)
  281. p.clearMap(vn)
  282. p.closeblock()
  283. }
  284. // assign key to value based on varnames
  285. func (p *printer) mapAssign(m *Map) {
  286. if !p.ok() {
  287. return
  288. }
  289. p.printf("\n%s[%s] = %s", m.Varname(), m.Keyidx, m.Validx)
  290. }
  291. // clear map keys
  292. func (p *printer) clearMap(name string) {
  293. p.printf("\nfor key, _ := range %[1]s { delete(%[1]s, key) }", name)
  294. }
  295. func (p *printer) resizeSlice(size string, s *Slice) {
  296. p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
  297. }
  298. func (p *printer) arrayCheck(want string, got string) {
  299. p.printf("\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }", got, want)
  300. }
  301. func (p *printer) closeblock() { p.print("\n}") }
  302. // does:
  303. //
  304. // for idx := range iter {
  305. // {{generate inner}}
  306. // }
  307. //
  308. func (p *printer) rangeBlock(idx string, iter string, t traversal, inner Elem) {
  309. p.printf("\n for %s := range %s {", idx, iter)
  310. next(t, inner)
  311. p.closeblock()
  312. }
  313. func (p *printer) nakedReturn() {
  314. if p.ok() {
  315. p.print("\nreturn\n}\n")
  316. }
  317. }
  318. func (p *printer) comment(s string) {
  319. p.print("\n// " + s)
  320. }
  321. func (p *printer) printf(format string, args ...interface{}) {
  322. if p.err == nil {
  323. _, p.err = fmt.Fprintf(p.w, format, args...)
  324. }
  325. }
  326. func (p *printer) print(format string) {
  327. if p.err == nil {
  328. _, p.err = io.WriteString(p.w, format)
  329. }
  330. }
  331. func (p *printer) initPtr(pt *Ptr) {
  332. if pt.Needsinit() {
  333. vname := pt.Varname()
  334. p.printf("\nif %s == nil { %s = new(%s); }", vname, vname, pt.Value.TypeName())
  335. }
  336. }
  337. func (p *printer) ok() bool { return p.err == nil }
  338. func tobaseConvert(b *BaseElem) string {
  339. return b.ToBase() + "(" + b.Varname() + ")"
  340. }