|
|
package gen
import ( "fmt" "io" )
const ( errcheck = "\nif err != nil { return }" lenAsUint32 = "uint32(len(%s))" literalFmt = "%s" intFmt = "%d" quotedFmt = `"%s"` mapHeader = "MapHeader" arrayHeader = "ArrayHeader" mapKey = "MapKeyPtr" stringTyp = "String" u32 = "uint32" )
// Method is a bitfield representing something that the
// generator knows how to print.
type Method uint8
// are the bits in 'f' set in 'm'?
func (m Method) isset(f Method) bool { return (m&f == f) }
// String implements fmt.Stringer
func (m Method) String() string { switch m { case 0, invalidmeth: return "<invalid method>" case Decode: return "decode" case Encode: return "encode" case Marshal: return "marshal" case Unmarshal: return "unmarshal" case Size: return "size" case Test: return "test" default: // return e.g. "decode+encode+test"
modes := [...]Method{Decode, Encode, Marshal, Unmarshal, Size, Test} any := false nm := "" for _, mm := range modes { if m.isset(mm) { if any { nm += "+" + mm.String() } else { nm += mm.String() any = true } } } return nm
} }
func strtoMeth(s string) Method { switch s { case "encode": return Encode case "decode": return Decode case "marshal": return Marshal case "unmarshal": return Unmarshal case "size": return Size case "test": return Test default: return 0 } }
const ( Decode Method = 1 << iota // msgp.Decodable
Encode // msgp.Encodable
Marshal // msgp.Marshaler
Unmarshal // msgp.Unmarshaler
Size // msgp.Sizer
Test // generate tests
invalidmeth // this isn't a method
encodetest = Encode | Decode | Test // tests for Encodable and Decodable
marshaltest = Marshal | Unmarshal | Test // tests for Marshaler and Unmarshaler
)
type Printer struct { gens []generator }
func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { if m.isset(Test) && tests == nil { panic("cannot print tests with 'nil' tests argument!") } gens := make([]generator, 0, 7) if m.isset(Decode) { gens = append(gens, decode(out)) } if m.isset(Encode) { gens = append(gens, encode(out)) } if m.isset(Marshal) { gens = append(gens, marshal(out)) } if m.isset(Unmarshal) { gens = append(gens, unmarshal(out)) } if m.isset(Size) { gens = append(gens, sizes(out)) } if m.isset(marshaltest) { gens = append(gens, mtest(tests)) } if m.isset(encodetest) { gens = append(gens, etest(tests)) } if len(gens) == 0 { panic("NewPrinter called with invalid method flags") } return &Printer{gens: gens} }
// TransformPass is a pass that transforms individual
// elements. (Note that if the returned is different from
// the argument, it should not point to the same objects.)
type TransformPass func(Elem) Elem
// IgnoreTypename is a pass that just ignores
// types of a given name.
func IgnoreTypename(name string) TransformPass { return func(e Elem) Elem { if e.TypeName() == name { return nil } return e } }
// ApplyDirective applies a directive to a named pass
// and all of its dependents.
func (p *Printer) ApplyDirective(pass Method, t TransformPass) { for _, g := range p.gens { if g.Method().isset(pass) { g.Add(t) } } }
// Print prints an Elem.
func (p *Printer) Print(e Elem) error { for _, g := range p.gens { // Elem.SetVarname() is called before the Print() step in parse.FileSet.PrintTo().
// Elem.SetVarname() generates identifiers as it walks the Elem. This can cause
// collisions between idents created during SetVarname and idents created during Print,
// hence the separate prefixes.
resetIdent("zb") err := g.Execute(e) resetIdent("za")
if err != nil { return err } } return nil }
// generator is the interface through
// which code is generated.
type generator interface { Method() Method Add(p TransformPass) Execute(Elem) error // execute writes the method for the provided object.
}
type passes []TransformPass
func (p *passes) Add(t TransformPass) { *p = append(*p, t) }
func (p *passes) applyall(e Elem) Elem { for _, t := range *p { e = t(e) if e == nil { return nil } } return e }
type traversal interface { gMap(*Map) gSlice(*Slice) gArray(*Array) gPtr(*Ptr) gBase(*BaseElem) gStruct(*Struct) }
// type-switch dispatch to the correct
// method given the type of 'e'
func next(t traversal, e Elem) { switch e := e.(type) { case *Map: t.gMap(e) case *Struct: t.gStruct(e) case *Slice: t.gSlice(e) case *Array: t.gArray(e) case *Ptr: t.gPtr(e) case *BaseElem: t.gBase(e) default: panic("bad element type") } }
// possibly-immutable method receiver
func imutMethodReceiver(p Elem) string { switch e := p.(type) { case *Struct: // TODO(HACK): actually do real math here.
if len(e.Fields) <= 3 { for i := range e.Fields { if be, ok := e.Fields[i].FieldElem.(*BaseElem); !ok || (be.Value == IDENT || be.Value == Bytes) { goto nope } } return p.TypeName() } nope: return "*" + p.TypeName()
// gets dereferenced automatically
case *Array: return "*" + p.TypeName()
// everything else can be
// by-value.
default: return p.TypeName() } }
// if necessary, wraps a type
// so that its method receiver
// is of the write type.
func methodReceiver(p Elem) string { switch p.(type) {
// structs and arrays are
// dereferenced automatically,
// so no need to alter varname
case *Struct, *Array: return "*" + p.TypeName() // set variable name to
// *varname
default: p.SetVarname("(*" + p.Varname() + ")") return "*" + p.TypeName() } }
func unsetReceiver(p Elem) { switch p.(type) { case *Struct, *Array: default: p.SetVarname("z") } }
// shared utility for generators
type printer struct { w io.Writer err error }
// writes "var {{name}} {{typ}};"
func (p *printer) declare(name string, typ string) { p.printf("\nvar %s %s", name, typ) }
// does:
//
// if m != nil && size > 0 {
// m = make(type, size)
// } else if len(m) > 0 {
// for key, _ := range m { delete(m, key) }
// }
//
func (p *printer) resizeMap(size string, m *Map) { vn := m.Varname() if !p.ok() { return } p.printf("\nif %s == nil && %s > 0 {", vn, size) p.printf("\n%s = make(%s, %s)", vn, m.TypeName(), size) p.printf("\n} else if len(%s) > 0 {", vn) p.clearMap(vn) p.closeblock() }
// assign key to value based on varnames
func (p *printer) mapAssign(m *Map) { if !p.ok() { return } p.printf("\n%s[%s] = %s", m.Varname(), m.Keyidx, m.Validx) }
// clear map keys
func (p *printer) clearMap(name string) { p.printf("\nfor key, _ := range %[1]s { delete(%[1]s, key) }", name) }
func (p *printer) resizeSlice(size string, s *Slice) { p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName()) }
func (p *printer) arrayCheck(want string, got string) { p.printf("\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }", got, want) }
func (p *printer) closeblock() { p.print("\n}") }
// does:
//
// for idx := range iter {
// {{generate inner}}
// }
//
func (p *printer) rangeBlock(idx string, iter string, t traversal, inner Elem) { p.printf("\n for %s := range %s {", idx, iter) next(t, inner) p.closeblock() }
func (p *printer) nakedReturn() { if p.ok() { p.print("\nreturn\n}\n") } }
func (p *printer) comment(s string) { p.print("\n// " + s) }
func (p *printer) printf(format string, args ...interface{}) { if p.err == nil { _, p.err = fmt.Fprintf(p.w, format, args...) } }
func (p *printer) print(format string) { if p.err == nil { _, p.err = io.WriteString(p.w, format) } }
func (p *printer) initPtr(pt *Ptr) { if pt.Needsinit() { vname := pt.Varname() p.printf("\nif %s == nil { %s = new(%s); }", vname, vname, pt.Value.TypeName()) } }
func (p *printer) ok() bool { return p.err == nil }
func tobaseConvert(b *BaseElem) string { return b.ToBase() + "(" + b.Varname() + ")" }
|