You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

665 lines
14 KiB

  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "crypto/rand"
  8. "fmt"
  9. "io"
  10. "log"
  11. "net"
  12. "net/http"
  13. "net/http/httptest"
  14. "net/url"
  15. "reflect"
  16. "runtime"
  17. "strings"
  18. "sync"
  19. "testing"
  20. "time"
  21. )
  22. var serverAddr string
  23. var once sync.Once
  24. func echoServer(ws *Conn) {
  25. defer ws.Close()
  26. io.Copy(ws, ws)
  27. }
  28. type Count struct {
  29. S string
  30. N int
  31. }
  32. func countServer(ws *Conn) {
  33. defer ws.Close()
  34. for {
  35. var count Count
  36. err := JSON.Receive(ws, &count)
  37. if err != nil {
  38. return
  39. }
  40. count.N++
  41. count.S = strings.Repeat(count.S, count.N)
  42. err = JSON.Send(ws, count)
  43. if err != nil {
  44. return
  45. }
  46. }
  47. }
  48. type testCtrlAndDataHandler struct {
  49. hybiFrameHandler
  50. }
  51. func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
  52. h.hybiFrameHandler.conn.wio.Lock()
  53. defer h.hybiFrameHandler.conn.wio.Unlock()
  54. w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
  55. if err != nil {
  56. return 0, err
  57. }
  58. n, err := w.Write(b)
  59. w.Close()
  60. return n, err
  61. }
  62. func ctrlAndDataServer(ws *Conn) {
  63. defer ws.Close()
  64. h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
  65. ws.frameHandler = h
  66. go func() {
  67. for i := 0; ; i++ {
  68. var b []byte
  69. if i%2 != 0 { // with or without payload
  70. b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
  71. }
  72. if _, err := h.WritePing(b); err != nil {
  73. break
  74. }
  75. if _, err := h.WritePong(b); err != nil { // unsolicited pong
  76. break
  77. }
  78. time.Sleep(10 * time.Millisecond)
  79. }
  80. }()
  81. b := make([]byte, 128)
  82. for {
  83. n, err := ws.Read(b)
  84. if err != nil {
  85. break
  86. }
  87. if _, err := ws.Write(b[:n]); err != nil {
  88. break
  89. }
  90. }
  91. }
  92. func subProtocolHandshake(config *Config, req *http.Request) error {
  93. for _, proto := range config.Protocol {
  94. if proto == "chat" {
  95. config.Protocol = []string{proto}
  96. return nil
  97. }
  98. }
  99. return ErrBadWebSocketProtocol
  100. }
  101. func subProtoServer(ws *Conn) {
  102. for _, proto := range ws.Config().Protocol {
  103. io.WriteString(ws, proto)
  104. }
  105. }
  106. func startServer() {
  107. http.Handle("/echo", Handler(echoServer))
  108. http.Handle("/count", Handler(countServer))
  109. http.Handle("/ctrldata", Handler(ctrlAndDataServer))
  110. subproto := Server{
  111. Handshake: subProtocolHandshake,
  112. Handler: Handler(subProtoServer),
  113. }
  114. http.Handle("/subproto", subproto)
  115. server := httptest.NewServer(nil)
  116. serverAddr = server.Listener.Addr().String()
  117. log.Print("Test WebSocket server listening on ", serverAddr)
  118. }
  119. func newConfig(t *testing.T, path string) *Config {
  120. config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
  121. return config
  122. }
  123. func TestEcho(t *testing.T) {
  124. once.Do(startServer)
  125. // websocket.Dial()
  126. client, err := net.Dial("tcp", serverAddr)
  127. if err != nil {
  128. t.Fatal("dialing", err)
  129. }
  130. conn, err := NewClient(newConfig(t, "/echo"), client)
  131. if err != nil {
  132. t.Errorf("WebSocket handshake error: %v", err)
  133. return
  134. }
  135. msg := []byte("hello, world\n")
  136. if _, err := conn.Write(msg); err != nil {
  137. t.Errorf("Write: %v", err)
  138. }
  139. var actual_msg = make([]byte, 512)
  140. n, err := conn.Read(actual_msg)
  141. if err != nil {
  142. t.Errorf("Read: %v", err)
  143. }
  144. actual_msg = actual_msg[0:n]
  145. if !bytes.Equal(msg, actual_msg) {
  146. t.Errorf("Echo: expected %q got %q", msg, actual_msg)
  147. }
  148. conn.Close()
  149. }
  150. func TestAddr(t *testing.T) {
  151. once.Do(startServer)
  152. // websocket.Dial()
  153. client, err := net.Dial("tcp", serverAddr)
  154. if err != nil {
  155. t.Fatal("dialing", err)
  156. }
  157. conn, err := NewClient(newConfig(t, "/echo"), client)
  158. if err != nil {
  159. t.Errorf("WebSocket handshake error: %v", err)
  160. return
  161. }
  162. ra := conn.RemoteAddr().String()
  163. if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
  164. t.Errorf("Bad remote addr: %v", ra)
  165. }
  166. la := conn.LocalAddr().String()
  167. if !strings.HasPrefix(la, "http://") {
  168. t.Errorf("Bad local addr: %v", la)
  169. }
  170. conn.Close()
  171. }
  172. func TestCount(t *testing.T) {
  173. once.Do(startServer)
  174. // websocket.Dial()
  175. client, err := net.Dial("tcp", serverAddr)
  176. if err != nil {
  177. t.Fatal("dialing", err)
  178. }
  179. conn, err := NewClient(newConfig(t, "/count"), client)
  180. if err != nil {
  181. t.Errorf("WebSocket handshake error: %v", err)
  182. return
  183. }
  184. var count Count
  185. count.S = "hello"
  186. if err := JSON.Send(conn, count); err != nil {
  187. t.Errorf("Write: %v", err)
  188. }
  189. if err := JSON.Receive(conn, &count); err != nil {
  190. t.Errorf("Read: %v", err)
  191. }
  192. if count.N != 1 {
  193. t.Errorf("count: expected %d got %d", 1, count.N)
  194. }
  195. if count.S != "hello" {
  196. t.Errorf("count: expected %q got %q", "hello", count.S)
  197. }
  198. if err := JSON.Send(conn, count); err != nil {
  199. t.Errorf("Write: %v", err)
  200. }
  201. if err := JSON.Receive(conn, &count); err != nil {
  202. t.Errorf("Read: %v", err)
  203. }
  204. if count.N != 2 {
  205. t.Errorf("count: expected %d got %d", 2, count.N)
  206. }
  207. if count.S != "hellohello" {
  208. t.Errorf("count: expected %q got %q", "hellohello", count.S)
  209. }
  210. conn.Close()
  211. }
  212. func TestWithQuery(t *testing.T) {
  213. once.Do(startServer)
  214. client, err := net.Dial("tcp", serverAddr)
  215. if err != nil {
  216. t.Fatal("dialing", err)
  217. }
  218. config := newConfig(t, "/echo")
  219. config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
  220. if err != nil {
  221. t.Fatal("location url", err)
  222. }
  223. ws, err := NewClient(config, client)
  224. if err != nil {
  225. t.Errorf("WebSocket handshake: %v", err)
  226. return
  227. }
  228. ws.Close()
  229. }
  230. func testWithProtocol(t *testing.T, subproto []string) (string, error) {
  231. once.Do(startServer)
  232. client, err := net.Dial("tcp", serverAddr)
  233. if err != nil {
  234. t.Fatal("dialing", err)
  235. }
  236. config := newConfig(t, "/subproto")
  237. config.Protocol = subproto
  238. ws, err := NewClient(config, client)
  239. if err != nil {
  240. return "", err
  241. }
  242. msg := make([]byte, 16)
  243. n, err := ws.Read(msg)
  244. if err != nil {
  245. return "", err
  246. }
  247. ws.Close()
  248. return string(msg[:n]), nil
  249. }
  250. func TestWithProtocol(t *testing.T) {
  251. proto, err := testWithProtocol(t, []string{"chat"})
  252. if err != nil {
  253. t.Errorf("SubProto: unexpected error: %v", err)
  254. }
  255. if proto != "chat" {
  256. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  257. }
  258. }
  259. func TestWithTwoProtocol(t *testing.T) {
  260. proto, err := testWithProtocol(t, []string{"test", "chat"})
  261. if err != nil {
  262. t.Errorf("SubProto: unexpected error: %v", err)
  263. }
  264. if proto != "chat" {
  265. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  266. }
  267. }
  268. func TestWithBadProtocol(t *testing.T) {
  269. _, err := testWithProtocol(t, []string{"test"})
  270. if err != ErrBadStatus {
  271. t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
  272. }
  273. }
  274. func TestHTTP(t *testing.T) {
  275. once.Do(startServer)
  276. // If the client did not send a handshake that matches the protocol
  277. // specification, the server MUST return an HTTP response with an
  278. // appropriate error code (such as 400 Bad Request)
  279. resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
  280. if err != nil {
  281. t.Errorf("Get: error %#v", err)
  282. return
  283. }
  284. if resp == nil {
  285. t.Error("Get: resp is null")
  286. return
  287. }
  288. if resp.StatusCode != http.StatusBadRequest {
  289. t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
  290. }
  291. }
  292. func TestTrailingSpaces(t *testing.T) {
  293. // http://code.google.com/p/go/issues/detail?id=955
  294. // The last runs of this create keys with trailing spaces that should not be
  295. // generated by the client.
  296. once.Do(startServer)
  297. config := newConfig(t, "/echo")
  298. for i := 0; i < 30; i++ {
  299. // body
  300. ws, err := DialConfig(config)
  301. if err != nil {
  302. t.Errorf("Dial #%d failed: %v", i, err)
  303. break
  304. }
  305. ws.Close()
  306. }
  307. }
  308. func TestDialConfigBadVersion(t *testing.T) {
  309. once.Do(startServer)
  310. config := newConfig(t, "/echo")
  311. config.Version = 1234
  312. _, err := DialConfig(config)
  313. if dialerr, ok := err.(*DialError); ok {
  314. if dialerr.Err != ErrBadProtocolVersion {
  315. t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
  316. }
  317. }
  318. }
  319. func TestDialConfigWithDialer(t *testing.T) {
  320. once.Do(startServer)
  321. config := newConfig(t, "/echo")
  322. config.Dialer = &net.Dialer{
  323. Deadline: time.Now().Add(-time.Minute),
  324. }
  325. _, err := DialConfig(config)
  326. dialerr, ok := err.(*DialError)
  327. if !ok {
  328. t.Fatalf("DialError expected, got %#v", err)
  329. }
  330. neterr, ok := dialerr.Err.(*net.OpError)
  331. if !ok {
  332. t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
  333. }
  334. if !neterr.Timeout() {
  335. t.Fatalf("expected timeout error, got %#v", neterr)
  336. }
  337. }
  338. func TestSmallBuffer(t *testing.T) {
  339. // http://code.google.com/p/go/issues/detail?id=1145
  340. // Read should be able to handle reading a fragment of a frame.
  341. once.Do(startServer)
  342. // websocket.Dial()
  343. client, err := net.Dial("tcp", serverAddr)
  344. if err != nil {
  345. t.Fatal("dialing", err)
  346. }
  347. conn, err := NewClient(newConfig(t, "/echo"), client)
  348. if err != nil {
  349. t.Errorf("WebSocket handshake error: %v", err)
  350. return
  351. }
  352. msg := []byte("hello, world\n")
  353. if _, err := conn.Write(msg); err != nil {
  354. t.Errorf("Write: %v", err)
  355. }
  356. var small_msg = make([]byte, 8)
  357. n, err := conn.Read(small_msg)
  358. if err != nil {
  359. t.Errorf("Read: %v", err)
  360. }
  361. if !bytes.Equal(msg[:len(small_msg)], small_msg) {
  362. t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
  363. }
  364. var second_msg = make([]byte, len(msg))
  365. n, err = conn.Read(second_msg)
  366. if err != nil {
  367. t.Errorf("Read: %v", err)
  368. }
  369. second_msg = second_msg[0:n]
  370. if !bytes.Equal(msg[len(small_msg):], second_msg) {
  371. t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
  372. }
  373. conn.Close()
  374. }
  375. var parseAuthorityTests = []struct {
  376. in *url.URL
  377. out string
  378. }{
  379. {
  380. &url.URL{
  381. Scheme: "ws",
  382. Host: "www.google.com",
  383. },
  384. "www.google.com:80",
  385. },
  386. {
  387. &url.URL{
  388. Scheme: "wss",
  389. Host: "www.google.com",
  390. },
  391. "www.google.com:443",
  392. },
  393. {
  394. &url.URL{
  395. Scheme: "ws",
  396. Host: "www.google.com:80",
  397. },
  398. "www.google.com:80",
  399. },
  400. {
  401. &url.URL{
  402. Scheme: "wss",
  403. Host: "www.google.com:443",
  404. },
  405. "www.google.com:443",
  406. },
  407. // some invalid ones for parseAuthority. parseAuthority doesn't
  408. // concern itself with the scheme unless it actually knows about it
  409. {
  410. &url.URL{
  411. Scheme: "http",
  412. Host: "www.google.com",
  413. },
  414. "www.google.com",
  415. },
  416. {
  417. &url.URL{
  418. Scheme: "http",
  419. Host: "www.google.com:80",
  420. },
  421. "www.google.com:80",
  422. },
  423. {
  424. &url.URL{
  425. Scheme: "asdf",
  426. Host: "127.0.0.1",
  427. },
  428. "127.0.0.1",
  429. },
  430. {
  431. &url.URL{
  432. Scheme: "asdf",
  433. Host: "www.google.com",
  434. },
  435. "www.google.com",
  436. },
  437. }
  438. func TestParseAuthority(t *testing.T) {
  439. for _, tt := range parseAuthorityTests {
  440. out := parseAuthority(tt.in)
  441. if out != tt.out {
  442. t.Errorf("got %v; want %v", out, tt.out)
  443. }
  444. }
  445. }
  446. type closerConn struct {
  447. net.Conn
  448. closed int // count of the number of times Close was called
  449. }
  450. func (c *closerConn) Close() error {
  451. c.closed++
  452. return c.Conn.Close()
  453. }
  454. func TestClose(t *testing.T) {
  455. if runtime.GOOS == "plan9" {
  456. t.Skip("see golang.org/issue/11454")
  457. }
  458. once.Do(startServer)
  459. conn, err := net.Dial("tcp", serverAddr)
  460. if err != nil {
  461. t.Fatal("dialing", err)
  462. }
  463. cc := closerConn{Conn: conn}
  464. client, err := NewClient(newConfig(t, "/echo"), &cc)
  465. if err != nil {
  466. t.Fatalf("WebSocket handshake: %v", err)
  467. }
  468. // set the deadline to ten minutes ago, which will have expired by the time
  469. // client.Close sends the close status frame.
  470. conn.SetDeadline(time.Now().Add(-10 * time.Minute))
  471. if err := client.Close(); err == nil {
  472. t.Errorf("ws.Close(): expected error, got %v", err)
  473. }
  474. if cc.closed < 1 {
  475. t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
  476. }
  477. }
  478. var originTests = []struct {
  479. req *http.Request
  480. origin *url.URL
  481. }{
  482. {
  483. req: &http.Request{
  484. Header: http.Header{
  485. "Origin": []string{"http://www.example.com"},
  486. },
  487. },
  488. origin: &url.URL{
  489. Scheme: "http",
  490. Host: "www.example.com",
  491. },
  492. },
  493. {
  494. req: &http.Request{},
  495. },
  496. }
  497. func TestOrigin(t *testing.T) {
  498. conf := newConfig(t, "/echo")
  499. conf.Version = ProtocolVersionHybi13
  500. for i, tt := range originTests {
  501. origin, err := Origin(conf, tt.req)
  502. if err != nil {
  503. t.Error(err)
  504. continue
  505. }
  506. if !reflect.DeepEqual(origin, tt.origin) {
  507. t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
  508. continue
  509. }
  510. }
  511. }
  512. func TestCtrlAndData(t *testing.T) {
  513. once.Do(startServer)
  514. c, err := net.Dial("tcp", serverAddr)
  515. if err != nil {
  516. t.Fatal(err)
  517. }
  518. ws, err := NewClient(newConfig(t, "/ctrldata"), c)
  519. if err != nil {
  520. t.Fatal(err)
  521. }
  522. defer ws.Close()
  523. h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
  524. ws.frameHandler = h
  525. b := make([]byte, 128)
  526. for i := 0; i < 2; i++ {
  527. data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
  528. if _, err := ws.Write(data); err != nil {
  529. t.Fatalf("#%d: %v", i, err)
  530. }
  531. var ctrl []byte
  532. if i%2 != 0 { // with or without payload
  533. ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
  534. }
  535. if _, err := h.WritePing(ctrl); err != nil {
  536. t.Fatalf("#%d: %v", i, err)
  537. }
  538. n, err := ws.Read(b)
  539. if err != nil {
  540. t.Fatalf("#%d: %v", i, err)
  541. }
  542. if !bytes.Equal(b[:n], data) {
  543. t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
  544. }
  545. }
  546. }
  547. func TestCodec_ReceiveLimited(t *testing.T) {
  548. const limit = 2048
  549. var payloads [][]byte
  550. for _, size := range []int{
  551. 1024,
  552. 2048,
  553. 4096, // receive of this message would be interrupted due to limit
  554. 2048, // this one is to make sure next receive recovers discarding leftovers
  555. } {
  556. b := make([]byte, size)
  557. rand.Read(b)
  558. payloads = append(payloads, b)
  559. }
  560. handlerDone := make(chan struct{})
  561. limitedHandler := func(ws *Conn) {
  562. defer close(handlerDone)
  563. ws.MaxPayloadBytes = limit
  564. defer ws.Close()
  565. for i, p := range payloads {
  566. t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
  567. var recv []byte
  568. err := Message.Receive(ws, &recv)
  569. switch err {
  570. case nil:
  571. case ErrFrameTooLarge:
  572. if len(p) <= limit {
  573. t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
  574. }
  575. continue
  576. default:
  577. t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
  578. }
  579. if len(recv) > limit {
  580. t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
  581. }
  582. if !bytes.Equal(p, recv) {
  583. t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
  584. }
  585. }
  586. }
  587. server := httptest.NewServer(Handler(limitedHandler))
  588. defer server.CloseClientConnections()
  589. defer server.Close()
  590. addr := server.Listener.Addr().String()
  591. ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
  592. if err != nil {
  593. t.Fatal(err)
  594. }
  595. defer ws.Close()
  596. for i, p := range payloads {
  597. if err := Message.Send(ws, p); err != nil {
  598. t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
  599. }
  600. }
  601. <-handlerDone
  602. }