You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

269 lines
4.7 KiB

  1. package msgpack
  2. import (
  3. "fmt"
  4. "reflect"
  5. "github.com/vmihailenco/msgpack/codes"
  6. )
  7. const mapElemsAllocLimit = 1e4
  8. var mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
  9. var mapStringStringType = mapStringStringPtrType.Elem()
  10. var mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
  11. var mapStringInterfaceType = mapStringInterfacePtrType.Elem()
  12. func decodeMapValue(d *Decoder, v reflect.Value) error {
  13. n, err := d.DecodeMapLen()
  14. if err != nil {
  15. return err
  16. }
  17. typ := v.Type()
  18. if n == -1 {
  19. v.Set(reflect.Zero(typ))
  20. return nil
  21. }
  22. if v.IsNil() {
  23. v.Set(reflect.MakeMap(typ))
  24. }
  25. keyType := typ.Key()
  26. valueType := typ.Elem()
  27. for i := 0; i < n; i++ {
  28. mk := reflect.New(keyType).Elem()
  29. if err := d.DecodeValue(mk); err != nil {
  30. return err
  31. }
  32. mv := reflect.New(valueType).Elem()
  33. if err := d.DecodeValue(mv); err != nil {
  34. return err
  35. }
  36. v.SetMapIndex(mk, mv)
  37. }
  38. return nil
  39. }
  40. func decodeMap(d *Decoder) (interface{}, error) {
  41. n, err := d.DecodeMapLen()
  42. if err != nil {
  43. return nil, err
  44. }
  45. if n == -1 {
  46. return nil, nil
  47. }
  48. m := make(map[string]interface{}, min(n, mapElemsAllocLimit))
  49. for i := 0; i < n; i++ {
  50. mk, err := d.DecodeString()
  51. if err != nil {
  52. return nil, err
  53. }
  54. mv, err := d.DecodeInterface()
  55. if err != nil {
  56. return nil, err
  57. }
  58. m[mk] = mv
  59. }
  60. return m, nil
  61. }
  62. func (d *Decoder) DecodeMapLen() (int, error) {
  63. c, err := d.readCode()
  64. if err != nil {
  65. return 0, err
  66. }
  67. if codes.IsExt(c) {
  68. if err = d.skipExtHeader(c); err != nil {
  69. return 0, err
  70. }
  71. c, err = d.readCode()
  72. if err != nil {
  73. return 0, err
  74. }
  75. }
  76. return d.mapLen(c)
  77. }
  78. func (d *Decoder) mapLen(c codes.Code) (int, error) {
  79. if c == codes.Nil {
  80. return -1, nil
  81. }
  82. if c >= codes.FixedMapLow && c <= codes.FixedMapHigh {
  83. return int(c & codes.FixedMapMask), nil
  84. }
  85. if c == codes.Map16 {
  86. n, err := d.uint16()
  87. return int(n), err
  88. }
  89. if c == codes.Map32 {
  90. n, err := d.uint32()
  91. return int(n), err
  92. }
  93. return 0, fmt.Errorf("msgpack: invalid code=%x decoding map length", c)
  94. }
  95. func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
  96. mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
  97. return d.decodeMapStringStringPtr(mptr)
  98. }
  99. func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
  100. n, err := d.DecodeMapLen()
  101. if err != nil {
  102. return err
  103. }
  104. if n == -1 {
  105. *ptr = nil
  106. return nil
  107. }
  108. m := *ptr
  109. if m == nil {
  110. *ptr = make(map[string]string, min(n, mapElemsAllocLimit))
  111. m = *ptr
  112. }
  113. for i := 0; i < n; i++ {
  114. mk, err := d.DecodeString()
  115. if err != nil {
  116. return err
  117. }
  118. mv, err := d.DecodeString()
  119. if err != nil {
  120. return err
  121. }
  122. m[mk] = mv
  123. }
  124. return nil
  125. }
  126. func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
  127. ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
  128. return d.decodeMapStringInterfacePtr(ptr)
  129. }
  130. func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
  131. n, err := d.DecodeMapLen()
  132. if err != nil {
  133. return err
  134. }
  135. if n == -1 {
  136. *ptr = nil
  137. return nil
  138. }
  139. m := *ptr
  140. if m == nil {
  141. *ptr = make(map[string]interface{}, min(n, mapElemsAllocLimit))
  142. m = *ptr
  143. }
  144. for i := 0; i < n; i++ {
  145. mk, err := d.DecodeString()
  146. if err != nil {
  147. return err
  148. }
  149. mv, err := d.DecodeInterface()
  150. if err != nil {
  151. return err
  152. }
  153. m[mk] = mv
  154. }
  155. return nil
  156. }
  157. func (d *Decoder) DecodeMap() (interface{}, error) {
  158. return d.decodeMapFunc(d)
  159. }
  160. func (d *Decoder) skipMap(c codes.Code) error {
  161. n, err := d.mapLen(c)
  162. if err != nil {
  163. return err
  164. }
  165. for i := 0; i < n; i++ {
  166. if err := d.Skip(); err != nil {
  167. return err
  168. }
  169. if err := d.Skip(); err != nil {
  170. return err
  171. }
  172. }
  173. return nil
  174. }
  175. func decodeStructValue(d *Decoder, v reflect.Value) error {
  176. c, err := d.readCode()
  177. if err != nil {
  178. return err
  179. }
  180. var isArray bool
  181. n, err := d.mapLen(c)
  182. if err != nil {
  183. var err2 error
  184. n, err2 = d.arrayLen(c)
  185. if err2 != nil {
  186. return err
  187. }
  188. isArray = true
  189. }
  190. if n == -1 {
  191. if err = mustSet(v); err != nil {
  192. return err
  193. }
  194. v.Set(reflect.Zero(v.Type()))
  195. return nil
  196. }
  197. fields := structs.Fields(v.Type())
  198. if isArray {
  199. for i, f := range fields.List {
  200. if i >= n {
  201. break
  202. }
  203. if err := f.DecodeValue(d, v); err != nil {
  204. return err
  205. }
  206. }
  207. // Skip extra values.
  208. for i := len(fields.List); i < n; i++ {
  209. if err := d.Skip(); err != nil {
  210. return err
  211. }
  212. }
  213. return nil
  214. }
  215. for i := 0; i < n; i++ {
  216. name, err := d.DecodeString()
  217. if err != nil {
  218. return err
  219. }
  220. if f := fields.Table[name]; f != nil {
  221. if err := f.DecodeValue(d, v); err != nil {
  222. return err
  223. }
  224. } else {
  225. if err := d.Skip(); err != nil {
  226. return err
  227. }
  228. }
  229. }
  230. return nil
  231. }