|
|
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import ( "io" "io/ioutil" "sync" "testing" )
func muxPair() (*mux, *mux) { a, b := memPipe()
s := newMux(a) c := newMux(b)
return s, c }
// Returns both ends of a channel, and the mux for the the 2nd
// channel.
func channelPair(t *testing.T) (*channel, *channel, *mux) { c, s := muxPair()
res := make(chan *channel, 1) go func() { newCh, ok := <-s.incomingChannels if !ok { t.Fatalf("No incoming channel") } if newCh.ChannelType() != "chan" { t.Fatalf("got type %q want chan", newCh.ChannelType()) } ch, _, err := newCh.Accept() if err != nil { t.Fatalf("Accept %v", err) } res <- ch.(*channel) }()
ch, err := c.openChannel("chan", nil) if err != nil { t.Fatalf("OpenChannel: %v", err) }
return <-res, ch, c }
// Test that stderr and stdout can be addressed from different
// goroutines. This is intended for use with the race detector.
func TestMuxChannelExtendedThreadSafety(t *testing.T) { writer, reader, mux := channelPair(t) defer writer.Close() defer reader.Close() defer mux.Close()
var wr, rd sync.WaitGroup magic := "hello world"
wr.Add(2) go func() { io.WriteString(writer, magic) wr.Done() }() go func() { io.WriteString(writer.Stderr(), magic) wr.Done() }()
rd.Add(2) go func() { c, err := ioutil.ReadAll(reader) if string(c) != magic { t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) } rd.Done() }() go func() { c, err := ioutil.ReadAll(reader.Stderr()) if string(c) != magic { t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) } rd.Done() }()
wr.Wait() writer.CloseWrite() rd.Wait() }
func TestMuxReadWrite(t *testing.T) { s, c, mux := channelPair(t) defer s.Close() defer c.Close() defer mux.Close()
magic := "hello world" magicExt := "hello stderr" go func() { _, err := s.Write([]byte(magic)) if err != nil { t.Fatalf("Write: %v", err) } _, err = s.Extended(1).Write([]byte(magicExt)) if err != nil { t.Fatalf("Write: %v", err) } err = s.Close() if err != nil { t.Fatalf("Close: %v", err) } }()
var buf [1024]byte n, err := c.Read(buf[:]) if err != nil { t.Fatalf("server Read: %v", err) } got := string(buf[:n]) if got != magic { t.Fatalf("server: got %q want %q", got, magic) }
n, err = c.Extended(1).Read(buf[:]) if err != nil { t.Fatalf("server Read: %v", err) }
got = string(buf[:n]) if got != magicExt { t.Fatalf("server: got %q want %q", got, magic) } }
func TestMuxChannelOverflow(t *testing.T) { reader, writer, mux := channelPair(t) defer reader.Close() defer writer.Close() defer mux.Close()
wDone := make(chan int, 1) go func() { if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } writer.Write(make([]byte, 1)) wDone <- 1 }() writer.remoteWin.waitWriterBlocked()
// Send 1 byte.
packet := make([]byte, 1+4+4+1) packet[0] = msgChannelData marshalUint32(packet[1:], writer.remoteId) marshalUint32(packet[5:], uint32(1)) packet[9] = 42
if err := writer.mux.conn.writePacket(packet); err != nil { t.Errorf("could not send packet") } if _, err := reader.SendRequest("hello", true, nil); err == nil { t.Errorf("SendRequest succeeded.") } <-wDone }
func TestMuxChannelCloseWriteUnblock(t *testing.T) { reader, writer, mux := channelPair(t) defer reader.Close() defer writer.Close() defer mux.Close()
wDone := make(chan int, 1) go func() { if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } if _, err := writer.Write(make([]byte, 1)); err != io.EOF { t.Errorf("got %v, want EOF for unblock write", err) } wDone <- 1 }()
writer.remoteWin.waitWriterBlocked() reader.Close() <-wDone }
func TestMuxConnectionCloseWriteUnblock(t *testing.T) { reader, writer, mux := channelPair(t) defer reader.Close() defer writer.Close() defer mux.Close()
wDone := make(chan int, 1) go func() { if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } if _, err := writer.Write(make([]byte, 1)); err != io.EOF { t.Errorf("got %v, want EOF for unblock write", err) } wDone <- 1 }()
writer.remoteWin.waitWriterBlocked() mux.Close() <-wDone }
func TestMuxReject(t *testing.T) { client, server := muxPair() defer server.Close() defer client.Close()
go func() { ch, ok := <-server.incomingChannels if !ok { t.Fatalf("Accept") } if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) } ch.Reject(RejectionReason(42), "message") }()
ch, err := client.openChannel("ch", []byte("extra")) if ch != nil { t.Fatal("openChannel not rejected") }
ocf, ok := err.(*OpenChannelError) if !ok { t.Errorf("got %#v want *OpenChannelError", err) } else if ocf.Reason != 42 || ocf.Message != "message" { t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") }
want := "ssh: rejected: unknown reason 42 (message)" if err.Error() != want { t.Errorf("got %q, want %q", err.Error(), want) } }
func TestMuxChannelRequest(t *testing.T) { client, server, mux := channelPair(t) defer server.Close() defer client.Close() defer mux.Close()
var received int var wg sync.WaitGroup wg.Add(1) go func() { for r := range server.incomingRequests { received++ r.Reply(r.Type == "yes", nil) } wg.Done() }() _, err := client.SendRequest("yes", false, nil) if err != nil { t.Fatalf("SendRequest: %v", err) } ok, err := client.SendRequest("yes", true, nil) if err != nil { t.Fatalf("SendRequest: %v", err) }
if !ok { t.Errorf("SendRequest(yes): %v", ok)
}
ok, err = client.SendRequest("no", true, nil) if err != nil { t.Fatalf("SendRequest: %v", err) } if ok { t.Errorf("SendRequest(no): %v", ok)
}
client.Close() wg.Wait()
if received != 3 { t.Errorf("got %d requests, want %d", received, 3) } }
func TestMuxGlobalRequest(t *testing.T) { clientMux, serverMux := muxPair() defer serverMux.Close() defer clientMux.Close()
var seen bool go func() { for r := range serverMux.incomingRequests { seen = seen || r.Type == "peek" if r.WantReply { err := r.Reply(r.Type == "yes", append([]byte(r.Type), r.Payload...)) if err != nil { t.Errorf("AckRequest: %v", err) } } } }()
_, _, err := clientMux.SendRequest("peek", false, nil) if err != nil { t.Errorf("SendRequest: %v", err) }
ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) if !ok || string(data) != "yesa" || err != nil { t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", ok, data, err) } if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", ok, data, err) }
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", ok, data, err) }
if !seen { t.Errorf("never saw 'peek' request") } }
func TestMuxGlobalRequestUnblock(t *testing.T) { clientMux, serverMux := muxPair() defer serverMux.Close() defer clientMux.Close()
result := make(chan error, 1) go func() { _, _, err := clientMux.SendRequest("hello", true, nil) result <- err }()
<-serverMux.incomingRequests serverMux.conn.Close() err := <-result
if err != io.EOF { t.Errorf("want EOF, got %v", io.EOF) } }
func TestMuxChannelRequestUnblock(t *testing.T) { a, b, connB := channelPair(t) defer a.Close() defer b.Close() defer connB.Close()
result := make(chan error, 1) go func() { _, err := a.SendRequest("hello", true, nil) result <- err }()
<-b.incomingRequests connB.conn.Close() err := <-result
if err != io.EOF { t.Errorf("want EOF, got %v", err) } }
func TestMuxCloseChannel(t *testing.T) { r, w, mux := channelPair(t) defer mux.Close() defer r.Close() defer w.Close()
result := make(chan error, 1) go func() { var b [1024]byte _, err := r.Read(b[:]) result <- err }() if err := w.Close(); err != nil { t.Errorf("w.Close: %v", err) }
if _, err := w.Write([]byte("hello")); err != io.EOF { t.Errorf("got err %v, want io.EOF after Close", err) }
if err := <-result; err != io.EOF { t.Errorf("got %v (%T), want io.EOF", err, err) } }
func TestMuxCloseWriteChannel(t *testing.T) { r, w, mux := channelPair(t) defer mux.Close()
result := make(chan error, 1) go func() { var b [1024]byte _, err := r.Read(b[:]) result <- err }() if err := w.CloseWrite(); err != nil { t.Errorf("w.CloseWrite: %v", err) }
if _, err := w.Write([]byte("hello")); err != io.EOF { t.Errorf("got err %v, want io.EOF after CloseWrite", err) }
if err := <-result; err != io.EOF { t.Errorf("got %v (%T), want io.EOF", err, err) } }
func TestMuxInvalidRecord(t *testing.T) { a, b := muxPair() defer a.Close() defer b.Close()
packet := make([]byte, 1+4+4+1) packet[0] = msgChannelData marshalUint32(packet[1:], 29348723 /* invalid channel id */) marshalUint32(packet[5:], 1) packet[9] = 42
a.conn.writePacket(packet) go a.SendRequest("hello", false, nil) // 'a' wrote an invalid packet, so 'b' has exited.
req, ok := <-b.incomingRequests if ok { t.Errorf("got request %#v after receiving invalid packet", req) } }
func TestZeroWindowAdjust(t *testing.T) { a, b, mux := channelPair(t) defer a.Close() defer b.Close() defer mux.Close()
go func() { io.WriteString(a, "hello") // bogus adjust.
a.sendMessage(windowAdjustMsg{}) io.WriteString(a, "world") a.Close() }()
want := "helloworld" c, _ := ioutil.ReadAll(b) if string(c) != want { t.Errorf("got %q want %q", c, want) } }
func TestMuxMaxPacketSize(t *testing.T) { a, b, mux := channelPair(t) defer a.Close() defer b.Close() defer mux.Close()
large := make([]byte, a.maxRemotePayload+1) packet := make([]byte, 1+4+4+1+len(large)) packet[0] = msgChannelData marshalUint32(packet[1:], a.remoteId) marshalUint32(packet[5:], uint32(len(large))) packet[9] = 42
if err := a.mux.conn.writePacket(packet); err != nil { t.Errorf("could not send packet") }
go a.SendRequest("hello", false, nil)
_, ok := <-b.incomingRequests if ok { t.Errorf("connection still alive after receiving large packet.") } }
// Don't ship code with debug=true.
func TestDebug(t *testing.T) { if debugMux { t.Error("mux debug switched on") } if debugHandshake { t.Error("handshake debug switched on") } if debugTransport { t.Error("transport debug switched on") } }
|