You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

559 lines
13 KiB

  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bytes"
  7. "crypto/rand"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "net"
  12. "reflect"
  13. "runtime"
  14. "strings"
  15. "sync"
  16. "testing"
  17. )
  18. type testChecker struct {
  19. calls []string
  20. }
  21. func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
  22. if dialAddr == "bad" {
  23. return fmt.Errorf("dialAddr is bad")
  24. }
  25. if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
  26. return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
  27. }
  28. t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
  29. return nil
  30. }
  31. // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
  32. // therefore is buffered (net.Pipe deadlocks if both sides start with
  33. // a write.)
  34. func netPipe() (net.Conn, net.Conn, error) {
  35. listener, err := net.Listen("tcp", "127.0.0.1:0")
  36. if err != nil {
  37. listener, err = net.Listen("tcp", "[::1]:0")
  38. if err != nil {
  39. return nil, nil, err
  40. }
  41. }
  42. defer listener.Close()
  43. c1, err := net.Dial("tcp", listener.Addr().String())
  44. if err != nil {
  45. return nil, nil, err
  46. }
  47. c2, err := listener.Accept()
  48. if err != nil {
  49. c1.Close()
  50. return nil, nil, err
  51. }
  52. return c1, c2, nil
  53. }
  54. // noiseTransport inserts ignore messages to check that the read loop
  55. // and the key exchange filters out these messages.
  56. type noiseTransport struct {
  57. keyingTransport
  58. }
  59. func (t *noiseTransport) writePacket(p []byte) error {
  60. ignore := []byte{msgIgnore}
  61. if err := t.keyingTransport.writePacket(ignore); err != nil {
  62. return err
  63. }
  64. debug := []byte{msgDebug, 1, 2, 3}
  65. if err := t.keyingTransport.writePacket(debug); err != nil {
  66. return err
  67. }
  68. return t.keyingTransport.writePacket(p)
  69. }
  70. func addNoiseTransport(t keyingTransport) keyingTransport {
  71. return &noiseTransport{t}
  72. }
  73. // handshakePair creates two handshakeTransports connected with each
  74. // other. If the noise argument is true, both transports will try to
  75. // confuse the other side by sending ignore and debug messages.
  76. func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
  77. a, b, err := netPipe()
  78. if err != nil {
  79. return nil, nil, err
  80. }
  81. var trC, trS keyingTransport
  82. trC = newTransport(a, rand.Reader, true)
  83. trS = newTransport(b, rand.Reader, false)
  84. if noise {
  85. trC = addNoiseTransport(trC)
  86. trS = addNoiseTransport(trS)
  87. }
  88. clientConf.SetDefaults()
  89. v := []byte("version")
  90. client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
  91. serverConf := &ServerConfig{}
  92. serverConf.AddHostKey(testSigners["ecdsa"])
  93. serverConf.AddHostKey(testSigners["rsa"])
  94. serverConf.SetDefaults()
  95. server = newServerTransport(trS, v, v, serverConf)
  96. if err := server.waitSession(); err != nil {
  97. return nil, nil, fmt.Errorf("server.waitSession: %v", err)
  98. }
  99. if err := client.waitSession(); err != nil {
  100. return nil, nil, fmt.Errorf("client.waitSession: %v", err)
  101. }
  102. return client, server, nil
  103. }
  104. func TestHandshakeBasic(t *testing.T) {
  105. if runtime.GOOS == "plan9" {
  106. t.Skip("see golang.org/issue/7237")
  107. }
  108. checker := &syncChecker{
  109. waitCall: make(chan int, 10),
  110. called: make(chan int, 10),
  111. }
  112. checker.waitCall <- 1
  113. trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
  114. if err != nil {
  115. t.Fatalf("handshakePair: %v", err)
  116. }
  117. defer trC.Close()
  118. defer trS.Close()
  119. // Let first kex complete normally.
  120. <-checker.called
  121. clientDone := make(chan int, 0)
  122. gotHalf := make(chan int, 0)
  123. const N = 20
  124. go func() {
  125. defer close(clientDone)
  126. // Client writes a bunch of stuff, and does a key
  127. // change in the middle. This should not confuse the
  128. // handshake in progress. We do this twice, so we test
  129. // that the packet buffer is reset correctly.
  130. for i := 0; i < N; i++ {
  131. p := []byte{msgRequestSuccess, byte(i)}
  132. if err := trC.writePacket(p); err != nil {
  133. t.Fatalf("sendPacket: %v", err)
  134. }
  135. if (i % 10) == 5 {
  136. <-gotHalf
  137. // halfway through, we request a key change.
  138. trC.requestKeyExchange()
  139. // Wait until we can be sure the key
  140. // change has really started before we
  141. // write more.
  142. <-checker.called
  143. }
  144. if (i % 10) == 7 {
  145. // write some packets until the kex
  146. // completes, to test buffering of
  147. // packets.
  148. checker.waitCall <- 1
  149. }
  150. }
  151. }()
  152. // Server checks that client messages come in cleanly
  153. i := 0
  154. err = nil
  155. for ; i < N; i++ {
  156. var p []byte
  157. p, err = trS.readPacket()
  158. if err != nil {
  159. break
  160. }
  161. if (i % 10) == 5 {
  162. gotHalf <- 1
  163. }
  164. want := []byte{msgRequestSuccess, byte(i)}
  165. if bytes.Compare(p, want) != 0 {
  166. t.Errorf("message %d: got %v, want %v", i, p, want)
  167. }
  168. }
  169. <-clientDone
  170. if err != nil && err != io.EOF {
  171. t.Fatalf("server error: %v", err)
  172. }
  173. if i != N {
  174. t.Errorf("received %d messages, want 10.", i)
  175. }
  176. close(checker.called)
  177. if _, ok := <-checker.called; ok {
  178. // If all went well, we registered exactly 2 key changes: one
  179. // that establishes the session, and one that we requested
  180. // additionally.
  181. t.Fatalf("got another host key checks after 2 handshakes")
  182. }
  183. }
  184. func TestForceFirstKex(t *testing.T) {
  185. // like handshakePair, but must access the keyingTransport.
  186. checker := &testChecker{}
  187. clientConf := &ClientConfig{HostKeyCallback: checker.Check}
  188. a, b, err := netPipe()
  189. if err != nil {
  190. t.Fatalf("netPipe: %v", err)
  191. }
  192. var trC, trS keyingTransport
  193. trC = newTransport(a, rand.Reader, true)
  194. // This is the disallowed packet:
  195. trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
  196. // Rest of the setup.
  197. trS = newTransport(b, rand.Reader, false)
  198. clientConf.SetDefaults()
  199. v := []byte("version")
  200. client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
  201. serverConf := &ServerConfig{}
  202. serverConf.AddHostKey(testSigners["ecdsa"])
  203. serverConf.AddHostKey(testSigners["rsa"])
  204. serverConf.SetDefaults()
  205. server := newServerTransport(trS, v, v, serverConf)
  206. defer client.Close()
  207. defer server.Close()
  208. // We setup the initial key exchange, but the remote side
  209. // tries to send serviceRequestMsg in cleartext, which is
  210. // disallowed.
  211. if err := server.waitSession(); err == nil {
  212. t.Errorf("server first kex init should reject unexpected packet")
  213. }
  214. }
  215. func TestHandshakeAutoRekeyWrite(t *testing.T) {
  216. checker := &syncChecker{
  217. called: make(chan int, 10),
  218. waitCall: nil,
  219. }
  220. clientConf := &ClientConfig{HostKeyCallback: checker.Check}
  221. clientConf.RekeyThreshold = 500
  222. trC, trS, err := handshakePair(clientConf, "addr", false)
  223. if err != nil {
  224. t.Fatalf("handshakePair: %v", err)
  225. }
  226. defer trC.Close()
  227. defer trS.Close()
  228. input := make([]byte, 251)
  229. input[0] = msgRequestSuccess
  230. done := make(chan int, 1)
  231. const numPacket = 5
  232. go func() {
  233. defer close(done)
  234. j := 0
  235. for ; j < numPacket; j++ {
  236. if p, err := trS.readPacket(); err != nil {
  237. break
  238. } else if !bytes.Equal(input, p) {
  239. t.Errorf("got packet type %d, want %d", p[0], input[0])
  240. }
  241. }
  242. if j != numPacket {
  243. t.Errorf("got %d, want 5 messages", j)
  244. }
  245. }()
  246. <-checker.called
  247. for i := 0; i < numPacket; i++ {
  248. p := make([]byte, len(input))
  249. copy(p, input)
  250. if err := trC.writePacket(p); err != nil {
  251. t.Errorf("writePacket: %v", err)
  252. }
  253. if i == 2 {
  254. // Make sure the kex is in progress.
  255. <-checker.called
  256. }
  257. }
  258. <-done
  259. }
  260. type syncChecker struct {
  261. waitCall chan int
  262. called chan int
  263. }
  264. func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
  265. c.called <- 1
  266. if c.waitCall != nil {
  267. <-c.waitCall
  268. }
  269. return nil
  270. }
  271. func TestHandshakeAutoRekeyRead(t *testing.T) {
  272. sync := &syncChecker{
  273. called: make(chan int, 2),
  274. waitCall: nil,
  275. }
  276. clientConf := &ClientConfig{
  277. HostKeyCallback: sync.Check,
  278. }
  279. clientConf.RekeyThreshold = 500
  280. trC, trS, err := handshakePair(clientConf, "addr", false)
  281. if err != nil {
  282. t.Fatalf("handshakePair: %v", err)
  283. }
  284. defer trC.Close()
  285. defer trS.Close()
  286. packet := make([]byte, 501)
  287. packet[0] = msgRequestSuccess
  288. if err := trS.writePacket(packet); err != nil {
  289. t.Fatalf("writePacket: %v", err)
  290. }
  291. // While we read out the packet, a key change will be
  292. // initiated.
  293. done := make(chan int, 1)
  294. go func() {
  295. defer close(done)
  296. if _, err := trC.readPacket(); err != nil {
  297. t.Fatalf("readPacket(client): %v", err)
  298. }
  299. }()
  300. <-done
  301. <-sync.called
  302. }
  303. // errorKeyingTransport generates errors after a given number of
  304. // read/write operations.
  305. type errorKeyingTransport struct {
  306. packetConn
  307. readLeft, writeLeft int
  308. }
  309. func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
  310. return nil
  311. }
  312. func (n *errorKeyingTransport) getSessionID() []byte {
  313. return nil
  314. }
  315. func (n *errorKeyingTransport) writePacket(packet []byte) error {
  316. if n.writeLeft == 0 {
  317. n.Close()
  318. return errors.New("barf")
  319. }
  320. n.writeLeft--
  321. return n.packetConn.writePacket(packet)
  322. }
  323. func (n *errorKeyingTransport) readPacket() ([]byte, error) {
  324. if n.readLeft == 0 {
  325. n.Close()
  326. return nil, errors.New("barf")
  327. }
  328. n.readLeft--
  329. return n.packetConn.readPacket()
  330. }
  331. func TestHandshakeErrorHandlingRead(t *testing.T) {
  332. for i := 0; i < 20; i++ {
  333. testHandshakeErrorHandlingN(t, i, -1, false)
  334. }
  335. }
  336. func TestHandshakeErrorHandlingWrite(t *testing.T) {
  337. for i := 0; i < 20; i++ {
  338. testHandshakeErrorHandlingN(t, -1, i, false)
  339. }
  340. }
  341. func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
  342. for i := 0; i < 20; i++ {
  343. testHandshakeErrorHandlingN(t, i, -1, true)
  344. }
  345. }
  346. func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
  347. for i := 0; i < 20; i++ {
  348. testHandshakeErrorHandlingN(t, -1, i, true)
  349. }
  350. }
  351. // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
  352. // handshakeTransport deadlocks, the go runtime will detect it and
  353. // panic.
  354. func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
  355. msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
  356. a, b := memPipe()
  357. defer a.Close()
  358. defer b.Close()
  359. key := testSigners["ecdsa"]
  360. serverConf := Config{RekeyThreshold: minRekeyThreshold}
  361. serverConf.SetDefaults()
  362. serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
  363. serverConn.hostKeys = []Signer{key}
  364. go serverConn.readLoop()
  365. go serverConn.kexLoop()
  366. clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
  367. clientConf.SetDefaults()
  368. clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
  369. clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
  370. clientConn.hostKeyCallback = InsecureIgnoreHostKey()
  371. go clientConn.readLoop()
  372. go clientConn.kexLoop()
  373. var wg sync.WaitGroup
  374. for _, hs := range []packetConn{serverConn, clientConn} {
  375. if !coupled {
  376. wg.Add(2)
  377. go func(c packetConn) {
  378. for i := 0; ; i++ {
  379. str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
  380. err := c.writePacket(Marshal(&serviceRequestMsg{str}))
  381. if err != nil {
  382. break
  383. }
  384. }
  385. wg.Done()
  386. c.Close()
  387. }(hs)
  388. go func(c packetConn) {
  389. for {
  390. _, err := c.readPacket()
  391. if err != nil {
  392. break
  393. }
  394. }
  395. wg.Done()
  396. }(hs)
  397. } else {
  398. wg.Add(1)
  399. go func(c packetConn) {
  400. for {
  401. _, err := c.readPacket()
  402. if err != nil {
  403. break
  404. }
  405. if err := c.writePacket(msg); err != nil {
  406. break
  407. }
  408. }
  409. wg.Done()
  410. }(hs)
  411. }
  412. }
  413. wg.Wait()
  414. }
  415. func TestDisconnect(t *testing.T) {
  416. if runtime.GOOS == "plan9" {
  417. t.Skip("see golang.org/issue/7237")
  418. }
  419. checker := &testChecker{}
  420. trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
  421. if err != nil {
  422. t.Fatalf("handshakePair: %v", err)
  423. }
  424. defer trC.Close()
  425. defer trS.Close()
  426. trC.writePacket([]byte{msgRequestSuccess, 0, 0})
  427. errMsg := &disconnectMsg{
  428. Reason: 42,
  429. Message: "such is life",
  430. }
  431. trC.writePacket(Marshal(errMsg))
  432. trC.writePacket([]byte{msgRequestSuccess, 0, 0})
  433. packet, err := trS.readPacket()
  434. if err != nil {
  435. t.Fatalf("readPacket 1: %v", err)
  436. }
  437. if packet[0] != msgRequestSuccess {
  438. t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
  439. }
  440. _, err = trS.readPacket()
  441. if err == nil {
  442. t.Errorf("readPacket 2 succeeded")
  443. } else if !reflect.DeepEqual(err, errMsg) {
  444. t.Errorf("got error %#v, want %#v", err, errMsg)
  445. }
  446. _, err = trS.readPacket()
  447. if err == nil {
  448. t.Errorf("readPacket 3 succeeded")
  449. }
  450. }
  451. func TestHandshakeRekeyDefault(t *testing.T) {
  452. clientConf := &ClientConfig{
  453. Config: Config{
  454. Ciphers: []string{"aes128-ctr"},
  455. },
  456. HostKeyCallback: InsecureIgnoreHostKey(),
  457. }
  458. trC, trS, err := handshakePair(clientConf, "addr", false)
  459. if err != nil {
  460. t.Fatalf("handshakePair: %v", err)
  461. }
  462. defer trC.Close()
  463. defer trS.Close()
  464. trC.writePacket([]byte{msgRequestSuccess, 0, 0})
  465. trC.Close()
  466. rgb := (1024 + trC.readBytesLeft) >> 30
  467. wgb := (1024 + trC.writeBytesLeft) >> 30
  468. if rgb != 64 {
  469. t.Errorf("got rekey after %dG read, want 64G", rgb)
  470. }
  471. if wgb != 64 {
  472. t.Errorf("got rekey after %dG write, want 64G", wgb)
  473. }
  474. }