You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

509 lines
10 KiB

  1. package msgpack
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "reflect"
  9. "time"
  10. "github.com/vmihailenco/msgpack/codes"
  11. )
  12. const bytesAllocLimit = 1024 * 1024 // 1mb
  13. type bufReader interface {
  14. io.Reader
  15. io.ByteScanner
  16. }
  17. func newBufReader(r io.Reader) bufReader {
  18. if br, ok := r.(bufReader); ok {
  19. return br
  20. }
  21. return bufio.NewReader(r)
  22. }
  23. func makeBuffer() []byte {
  24. return make([]byte, 0, 64)
  25. }
  26. // Unmarshal decodes the MessagePack-encoded data and stores the result
  27. // in the value pointed to by v.
  28. func Unmarshal(data []byte, v ...interface{}) error {
  29. return NewDecoder(bytes.NewReader(data)).Decode(v...)
  30. }
  31. type Decoder struct {
  32. r io.Reader
  33. s io.ByteScanner
  34. buf []byte
  35. extLen int
  36. rec []byte // accumulates read data if not nil
  37. decodeMapFunc func(*Decoder) (interface{}, error)
  38. }
  39. // NewDecoder returns a new decoder that reads from r.
  40. //
  41. // The decoder introduces its own buffering and may read data from r
  42. // beyond the MessagePack values requested. Buffering can be disabled
  43. // by passing a reader that implements io.ByteScanner interface.
  44. func NewDecoder(r io.Reader) *Decoder {
  45. d := &Decoder{
  46. decodeMapFunc: decodeMap,
  47. buf: makeBuffer(),
  48. }
  49. d.resetReader(r)
  50. return d
  51. }
  52. func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) {
  53. d.decodeMapFunc = fn
  54. }
  55. func (d *Decoder) Reset(r io.Reader) error {
  56. d.resetReader(r)
  57. return nil
  58. }
  59. func (d *Decoder) resetReader(r io.Reader) {
  60. reader := newBufReader(r)
  61. d.r = reader
  62. d.s = reader
  63. }
  64. func (d *Decoder) Decode(v ...interface{}) error {
  65. for _, vv := range v {
  66. if err := d.decode(vv); err != nil {
  67. return err
  68. }
  69. }
  70. return nil
  71. }
  72. func (d *Decoder) decode(dst interface{}) error {
  73. var err error
  74. switch v := dst.(type) {
  75. case *string:
  76. if v != nil {
  77. *v, err = d.DecodeString()
  78. return err
  79. }
  80. case *[]byte:
  81. if v != nil {
  82. return d.decodeBytesPtr(v)
  83. }
  84. case *int:
  85. if v != nil {
  86. *v, err = d.DecodeInt()
  87. return err
  88. }
  89. case *int8:
  90. if v != nil {
  91. *v, err = d.DecodeInt8()
  92. return err
  93. }
  94. case *int16:
  95. if v != nil {
  96. *v, err = d.DecodeInt16()
  97. return err
  98. }
  99. case *int32:
  100. if v != nil {
  101. *v, err = d.DecodeInt32()
  102. return err
  103. }
  104. case *int64:
  105. if v != nil {
  106. *v, err = d.DecodeInt64()
  107. return err
  108. }
  109. case *uint:
  110. if v != nil {
  111. *v, err = d.DecodeUint()
  112. return err
  113. }
  114. case *uint8:
  115. if v != nil {
  116. *v, err = d.DecodeUint8()
  117. return err
  118. }
  119. case *uint16:
  120. if v != nil {
  121. *v, err = d.DecodeUint16()
  122. return err
  123. }
  124. case *uint32:
  125. if v != nil {
  126. *v, err = d.DecodeUint32()
  127. return err
  128. }
  129. case *uint64:
  130. if v != nil {
  131. *v, err = d.DecodeUint64()
  132. return err
  133. }
  134. case *bool:
  135. if v != nil {
  136. *v, err = d.DecodeBool()
  137. return err
  138. }
  139. case *float32:
  140. if v != nil {
  141. *v, err = d.DecodeFloat32()
  142. return err
  143. }
  144. case *float64:
  145. if v != nil {
  146. *v, err = d.DecodeFloat64()
  147. return err
  148. }
  149. case *[]string:
  150. return d.decodeStringSlicePtr(v)
  151. case *map[string]string:
  152. return d.decodeMapStringStringPtr(v)
  153. case *map[string]interface{}:
  154. return d.decodeMapStringInterfacePtr(v)
  155. case *time.Duration:
  156. if v != nil {
  157. vv, err := d.DecodeInt64()
  158. *v = time.Duration(vv)
  159. return err
  160. }
  161. case *time.Time:
  162. if v != nil {
  163. *v, err = d.DecodeTime()
  164. return err
  165. }
  166. }
  167. v := reflect.ValueOf(dst)
  168. if !v.IsValid() {
  169. return errors.New("msgpack: Decode(nil)")
  170. }
  171. if v.Kind() != reflect.Ptr {
  172. return fmt.Errorf("msgpack: Decode(nonsettable %T)", dst)
  173. }
  174. v = v.Elem()
  175. if !v.IsValid() {
  176. return fmt.Errorf("msgpack: Decode(nonsettable %T)", dst)
  177. }
  178. return d.DecodeValue(v)
  179. }
  180. func (d *Decoder) DecodeValue(v reflect.Value) error {
  181. decode := getDecoder(v.Type())
  182. return decode(d, v)
  183. }
  184. func (d *Decoder) DecodeNil() error {
  185. c, err := d.readCode()
  186. if err != nil {
  187. return err
  188. }
  189. if c != codes.Nil {
  190. return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
  191. }
  192. return nil
  193. }
  194. func (d *Decoder) decodeNilValue(v reflect.Value) error {
  195. err := d.DecodeNil()
  196. if v.IsNil() {
  197. return err
  198. }
  199. if v.Kind() == reflect.Ptr {
  200. v = v.Elem()
  201. }
  202. v.Set(reflect.Zero(v.Type()))
  203. return err
  204. }
  205. func (d *Decoder) DecodeBool() (bool, error) {
  206. c, err := d.readCode()
  207. if err != nil {
  208. return false, err
  209. }
  210. return d.bool(c)
  211. }
  212. func (d *Decoder) bool(c codes.Code) (bool, error) {
  213. if c == codes.False {
  214. return false, nil
  215. }
  216. if c == codes.True {
  217. return true, nil
  218. }
  219. return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
  220. }
  221. // DecodeInterface decodes value into interface. Possible value types are:
  222. // - nil,
  223. // - bool,
  224. // - int8, int16, int32, int64,
  225. // - uint8, uint16, uint32, uint64,
  226. // - float32 and float64,
  227. // - string,
  228. // - []byte,
  229. // - slices of any of the above,
  230. // - maps of any of the above.
  231. func (d *Decoder) DecodeInterface() (interface{}, error) {
  232. c, err := d.readCode()
  233. if err != nil {
  234. return nil, err
  235. }
  236. if codes.IsFixedNum(c) {
  237. return int8(c), nil
  238. }
  239. if codes.IsFixedMap(c) {
  240. d.s.UnreadByte()
  241. return d.DecodeMap()
  242. }
  243. if codes.IsFixedArray(c) {
  244. return d.decodeSlice(c)
  245. }
  246. if codes.IsFixedString(c) {
  247. return d.string(c)
  248. }
  249. switch c {
  250. case codes.Nil:
  251. return nil, nil
  252. case codes.False, codes.True:
  253. return d.bool(c)
  254. case codes.Float:
  255. return d.float32(c)
  256. case codes.Double:
  257. return d.float64(c)
  258. case codes.Uint8:
  259. return d.uint8()
  260. case codes.Uint16:
  261. return d.uint16()
  262. case codes.Uint32:
  263. return d.uint32()
  264. case codes.Uint64:
  265. return d.uint64()
  266. case codes.Int8:
  267. return d.int8()
  268. case codes.Int16:
  269. return d.int16()
  270. case codes.Int32:
  271. return d.int32()
  272. case codes.Int64:
  273. return d.int64()
  274. case codes.Bin8, codes.Bin16, codes.Bin32:
  275. return d.bytes(c, nil)
  276. case codes.Str8, codes.Str16, codes.Str32:
  277. return d.string(c)
  278. case codes.Array16, codes.Array32:
  279. return d.decodeSlice(c)
  280. case codes.Map16, codes.Map32:
  281. d.s.UnreadByte()
  282. return d.DecodeMap()
  283. case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
  284. codes.Ext8, codes.Ext16, codes.Ext32:
  285. return d.extInterface(c)
  286. }
  287. return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
  288. }
  289. // DecodeInterfaceLoose is like DecodeInterface except that:
  290. // - int8, int16, and int32 are converted to int64,
  291. // - uint8, uint16, and uint32 are converted to uint64,
  292. // - float32 is converted to float64.
  293. func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
  294. c, err := d.readCode()
  295. if err != nil {
  296. return nil, err
  297. }
  298. if codes.IsFixedNum(c) {
  299. return int64(c), nil
  300. }
  301. if codes.IsFixedMap(c) {
  302. d.s.UnreadByte()
  303. return d.DecodeMap()
  304. }
  305. if codes.IsFixedArray(c) {
  306. return d.decodeSlice(c)
  307. }
  308. if codes.IsFixedString(c) {
  309. return d.string(c)
  310. }
  311. switch c {
  312. case codes.Nil:
  313. return nil, nil
  314. case codes.False, codes.True:
  315. return d.bool(c)
  316. case codes.Float, codes.Double:
  317. return d.float64(c)
  318. case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64:
  319. return d.uint(c)
  320. case codes.Int8, codes.Int16, codes.Int32, codes.Int64:
  321. return d.int(c)
  322. case codes.Bin8, codes.Bin16, codes.Bin32:
  323. return d.bytes(c, nil)
  324. case codes.Str8, codes.Str16, codes.Str32:
  325. return d.string(c)
  326. case codes.Array16, codes.Array32:
  327. return d.decodeSlice(c)
  328. case codes.Map16, codes.Map32:
  329. d.s.UnreadByte()
  330. return d.DecodeMap()
  331. case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
  332. codes.Ext8, codes.Ext16, codes.Ext32:
  333. return d.extInterface(c)
  334. }
  335. return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
  336. }
  337. // Skip skips next value.
  338. func (d *Decoder) Skip() error {
  339. c, err := d.readCode()
  340. if err != nil {
  341. return err
  342. }
  343. if codes.IsFixedNum(c) {
  344. return nil
  345. } else if codes.IsFixedMap(c) {
  346. return d.skipMap(c)
  347. } else if codes.IsFixedArray(c) {
  348. return d.skipSlice(c)
  349. } else if codes.IsFixedString(c) {
  350. return d.skipBytes(c)
  351. }
  352. switch c {
  353. case codes.Nil, codes.False, codes.True:
  354. return nil
  355. case codes.Uint8, codes.Int8:
  356. return d.skipN(1)
  357. case codes.Uint16, codes.Int16:
  358. return d.skipN(2)
  359. case codes.Uint32, codes.Int32, codes.Float:
  360. return d.skipN(4)
  361. case codes.Uint64, codes.Int64, codes.Double:
  362. return d.skipN(8)
  363. case codes.Bin8, codes.Bin16, codes.Bin32:
  364. return d.skipBytes(c)
  365. case codes.Str8, codes.Str16, codes.Str32:
  366. return d.skipBytes(c)
  367. case codes.Array16, codes.Array32:
  368. return d.skipSlice(c)
  369. case codes.Map16, codes.Map32:
  370. return d.skipMap(c)
  371. case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
  372. codes.Ext8, codes.Ext16, codes.Ext32:
  373. return d.skipExt(c)
  374. }
  375. return fmt.Errorf("msgpack: unknown code %x", c)
  376. }
  377. // PeekCode returns the next MessagePack code without advancing the reader.
  378. // Subpackage msgpack/codes contains list of available codes.
  379. func (d *Decoder) PeekCode() (codes.Code, error) {
  380. c, err := d.s.ReadByte()
  381. if err != nil {
  382. return 0, err
  383. }
  384. return codes.Code(c), d.s.UnreadByte()
  385. }
  386. func (d *Decoder) hasNilCode() bool {
  387. code, err := d.PeekCode()
  388. return err == nil && code == codes.Nil
  389. }
  390. func (d *Decoder) readCode() (codes.Code, error) {
  391. d.extLen = 0
  392. c, err := d.s.ReadByte()
  393. if err != nil {
  394. return 0, err
  395. }
  396. if d.rec != nil {
  397. d.rec = append(d.rec, c)
  398. }
  399. return codes.Code(c), nil
  400. }
  401. func (d *Decoder) readFull(b []byte) error {
  402. _, err := io.ReadFull(d.r, b)
  403. if err != nil {
  404. return err
  405. }
  406. if d.rec != nil {
  407. d.rec = append(d.rec, b...)
  408. }
  409. return nil
  410. }
  411. func (d *Decoder) readN(n int) ([]byte, error) {
  412. buf, err := readN(d.r, d.buf, n)
  413. if err != nil {
  414. return nil, err
  415. }
  416. d.buf = buf
  417. if d.rec != nil {
  418. d.rec = append(d.rec, buf...)
  419. }
  420. return buf, nil
  421. }
  422. func readN(r io.Reader, b []byte, n int) ([]byte, error) {
  423. if b == nil {
  424. if n == 0 {
  425. return make([]byte, 0), nil
  426. }
  427. if n <= bytesAllocLimit {
  428. b = make([]byte, n)
  429. } else {
  430. b = make([]byte, bytesAllocLimit)
  431. }
  432. }
  433. if n <= cap(b) {
  434. b = b[:n]
  435. _, err := io.ReadFull(r, b)
  436. return b, err
  437. }
  438. b = b[:cap(b)]
  439. var pos int
  440. for {
  441. alloc := n - len(b)
  442. if alloc > bytesAllocLimit {
  443. alloc = bytesAllocLimit
  444. }
  445. b = append(b, make([]byte, alloc)...)
  446. _, err := io.ReadFull(r, b[pos:])
  447. if err != nil {
  448. return nil, err
  449. }
  450. if len(b) == n {
  451. break
  452. }
  453. pos = len(b)
  454. }
  455. return b, nil
  456. }
  457. func min(a, b int) int {
  458. if a <= b {
  459. return a
  460. }
  461. return b
  462. }