You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
2.0 KiB

  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "sync"
  8. "testing"
  9. )
  10. // An in-memory packetConn. It is safe to call Close and writePacket
  11. // from different goroutines.
  12. type memTransport struct {
  13. eof bool
  14. pending [][]byte
  15. write *memTransport
  16. sync.Mutex
  17. *sync.Cond
  18. }
  19. func (t *memTransport) readPacket() ([]byte, error) {
  20. t.Lock()
  21. defer t.Unlock()
  22. for {
  23. if len(t.pending) > 0 {
  24. r := t.pending[0]
  25. t.pending = t.pending[1:]
  26. return r, nil
  27. }
  28. if t.eof {
  29. return nil, io.EOF
  30. }
  31. t.Cond.Wait()
  32. }
  33. }
  34. func (t *memTransport) closeSelf() error {
  35. t.Lock()
  36. defer t.Unlock()
  37. if t.eof {
  38. return io.EOF
  39. }
  40. t.eof = true
  41. t.Cond.Broadcast()
  42. return nil
  43. }
  44. func (t *memTransport) Close() error {
  45. err := t.write.closeSelf()
  46. t.closeSelf()
  47. return err
  48. }
  49. func (t *memTransport) writePacket(p []byte) error {
  50. t.write.Lock()
  51. defer t.write.Unlock()
  52. if t.write.eof {
  53. return io.EOF
  54. }
  55. c := make([]byte, len(p))
  56. copy(c, p)
  57. t.write.pending = append(t.write.pending, c)
  58. t.write.Cond.Signal()
  59. return nil
  60. }
  61. func memPipe() (a, b packetConn) {
  62. t1 := memTransport{}
  63. t2 := memTransport{}
  64. t1.write = &t2
  65. t2.write = &t1
  66. t1.Cond = sync.NewCond(&t1.Mutex)
  67. t2.Cond = sync.NewCond(&t2.Mutex)
  68. return &t1, &t2
  69. }
  70. func TestMemPipe(t *testing.T) {
  71. a, b := memPipe()
  72. if err := a.writePacket([]byte{42}); err != nil {
  73. t.Fatalf("writePacket: %v", err)
  74. }
  75. if err := a.Close(); err != nil {
  76. t.Fatal("Close: ", err)
  77. }
  78. p, err := b.readPacket()
  79. if err != nil {
  80. t.Fatal("readPacket: ", err)
  81. }
  82. if len(p) != 1 || p[0] != 42 {
  83. t.Fatalf("got %v, want {42}", p)
  84. }
  85. p, err = b.readPacket()
  86. if err != io.EOF {
  87. t.Fatalf("got %v, %v, want EOF", p, err)
  88. }
  89. }
  90. func TestDoubleClose(t *testing.T) {
  91. a, _ := memPipe()
  92. err := a.Close()
  93. if err != nil {
  94. t.Errorf("Close: %v", err)
  95. }
  96. err = a.Close()
  97. if err != io.EOF {
  98. t.Errorf("expect EOF on double close.")
  99. }
  100. }