You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

475 lines
8.4 KiB

  1. package readline
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "fmt"
  7. "io"
  8. "net"
  9. "os"
  10. "sync"
  11. "sync/atomic"
  12. )
  13. type MsgType int16
  14. const (
  15. T_DATA = MsgType(iota)
  16. T_WIDTH
  17. T_WIDTH_REPORT
  18. T_ISTTY_REPORT
  19. T_RAW
  20. T_ERAW // exit raw
  21. T_EOF
  22. )
  23. type RemoteSvr struct {
  24. eof int32
  25. closed int32
  26. width int32
  27. reciveChan chan struct{}
  28. writeChan chan *writeCtx
  29. conn net.Conn
  30. isTerminal bool
  31. funcWidthChan func()
  32. stopChan chan struct{}
  33. dataBufM sync.Mutex
  34. dataBuf bytes.Buffer
  35. }
  36. type writeReply struct {
  37. n int
  38. err error
  39. }
  40. type writeCtx struct {
  41. msg *Message
  42. reply chan *writeReply
  43. }
  44. func newWriteCtx(msg *Message) *writeCtx {
  45. return &writeCtx{
  46. msg: msg,
  47. reply: make(chan *writeReply),
  48. }
  49. }
  50. func NewRemoteSvr(conn net.Conn) (*RemoteSvr, error) {
  51. rs := &RemoteSvr{
  52. width: -1,
  53. conn: conn,
  54. writeChan: make(chan *writeCtx),
  55. reciveChan: make(chan struct{}),
  56. stopChan: make(chan struct{}),
  57. }
  58. buf := bufio.NewReader(rs.conn)
  59. if err := rs.init(buf); err != nil {
  60. return nil, err
  61. }
  62. go rs.readLoop(buf)
  63. go rs.writeLoop()
  64. return rs, nil
  65. }
  66. func (r *RemoteSvr) init(buf *bufio.Reader) error {
  67. m, err := ReadMessage(buf)
  68. if err != nil {
  69. return err
  70. }
  71. // receive isTerminal
  72. if m.Type != T_ISTTY_REPORT {
  73. return fmt.Errorf("unexpected init message")
  74. }
  75. r.GotIsTerminal(m.Data)
  76. // receive width
  77. m, err = ReadMessage(buf)
  78. if err != nil {
  79. return err
  80. }
  81. if m.Type != T_WIDTH_REPORT {
  82. return fmt.Errorf("unexpected init message")
  83. }
  84. r.GotReportWidth(m.Data)
  85. return nil
  86. }
  87. func (r *RemoteSvr) HandleConfig(cfg *Config) {
  88. cfg.Stderr = r
  89. cfg.Stdout = r
  90. cfg.Stdin = r
  91. cfg.FuncExitRaw = r.ExitRawMode
  92. cfg.FuncIsTerminal = r.IsTerminal
  93. cfg.FuncMakeRaw = r.EnterRawMode
  94. cfg.FuncExitRaw = r.ExitRawMode
  95. cfg.FuncGetWidth = r.GetWidth
  96. cfg.FuncOnWidthChanged = func(f func()) {
  97. r.funcWidthChan = f
  98. }
  99. }
  100. func (r *RemoteSvr) IsTerminal() bool {
  101. return r.isTerminal
  102. }
  103. func (r *RemoteSvr) checkEOF() error {
  104. if atomic.LoadInt32(&r.eof) == 1 {
  105. return io.EOF
  106. }
  107. return nil
  108. }
  109. func (r *RemoteSvr) Read(b []byte) (int, error) {
  110. r.dataBufM.Lock()
  111. n, err := r.dataBuf.Read(b)
  112. r.dataBufM.Unlock()
  113. if n == 0 {
  114. if err := r.checkEOF(); err != nil {
  115. return 0, err
  116. }
  117. }
  118. if n == 0 && err == io.EOF {
  119. <-r.reciveChan
  120. r.dataBufM.Lock()
  121. n, err = r.dataBuf.Read(b)
  122. r.dataBufM.Unlock()
  123. }
  124. if n == 0 {
  125. if err := r.checkEOF(); err != nil {
  126. return 0, err
  127. }
  128. }
  129. return n, err
  130. }
  131. func (r *RemoteSvr) writeMsg(m *Message) error {
  132. ctx := newWriteCtx(m)
  133. r.writeChan <- ctx
  134. reply := <-ctx.reply
  135. return reply.err
  136. }
  137. func (r *RemoteSvr) Write(b []byte) (int, error) {
  138. ctx := newWriteCtx(NewMessage(T_DATA, b))
  139. r.writeChan <- ctx
  140. reply := <-ctx.reply
  141. return reply.n, reply.err
  142. }
  143. func (r *RemoteSvr) EnterRawMode() error {
  144. return r.writeMsg(NewMessage(T_RAW, nil))
  145. }
  146. func (r *RemoteSvr) ExitRawMode() error {
  147. return r.writeMsg(NewMessage(T_ERAW, nil))
  148. }
  149. func (r *RemoteSvr) writeLoop() {
  150. defer r.Close()
  151. loop:
  152. for {
  153. select {
  154. case ctx, ok := <-r.writeChan:
  155. if !ok {
  156. break
  157. }
  158. n, err := ctx.msg.WriteTo(r.conn)
  159. ctx.reply <- &writeReply{n, err}
  160. case <-r.stopChan:
  161. break loop
  162. }
  163. }
  164. }
  165. func (r *RemoteSvr) Close() error {
  166. if atomic.CompareAndSwapInt32(&r.closed, 0, 1) {
  167. close(r.stopChan)
  168. r.conn.Close()
  169. }
  170. return nil
  171. }
  172. func (r *RemoteSvr) readLoop(buf *bufio.Reader) {
  173. defer r.Close()
  174. for {
  175. m, err := ReadMessage(buf)
  176. if err != nil {
  177. break
  178. }
  179. switch m.Type {
  180. case T_EOF:
  181. atomic.StoreInt32(&r.eof, 1)
  182. select {
  183. case r.reciveChan <- struct{}{}:
  184. default:
  185. }
  186. case T_DATA:
  187. r.dataBufM.Lock()
  188. r.dataBuf.Write(m.Data)
  189. r.dataBufM.Unlock()
  190. select {
  191. case r.reciveChan <- struct{}{}:
  192. default:
  193. }
  194. case T_WIDTH_REPORT:
  195. r.GotReportWidth(m.Data)
  196. case T_ISTTY_REPORT:
  197. r.GotIsTerminal(m.Data)
  198. }
  199. }
  200. }
  201. func (r *RemoteSvr) GotIsTerminal(data []byte) {
  202. if binary.BigEndian.Uint16(data) == 0 {
  203. r.isTerminal = false
  204. } else {
  205. r.isTerminal = true
  206. }
  207. }
  208. func (r *RemoteSvr) GotReportWidth(data []byte) {
  209. atomic.StoreInt32(&r.width, int32(binary.BigEndian.Uint16(data)))
  210. if r.funcWidthChan != nil {
  211. r.funcWidthChan()
  212. }
  213. }
  214. func (r *RemoteSvr) GetWidth() int {
  215. return int(atomic.LoadInt32(&r.width))
  216. }
  217. // -----------------------------------------------------------------------------
  218. type Message struct {
  219. Type MsgType
  220. Data []byte
  221. }
  222. func ReadMessage(r io.Reader) (*Message, error) {
  223. m := new(Message)
  224. var length int32
  225. if err := binary.Read(r, binary.BigEndian, &length); err != nil {
  226. return nil, err
  227. }
  228. if err := binary.Read(r, binary.BigEndian, &m.Type); err != nil {
  229. return nil, err
  230. }
  231. m.Data = make([]byte, int(length)-2)
  232. if _, err := io.ReadFull(r, m.Data); err != nil {
  233. return nil, err
  234. }
  235. return m, nil
  236. }
  237. func NewMessage(t MsgType, data []byte) *Message {
  238. return &Message{t, data}
  239. }
  240. func (m *Message) WriteTo(w io.Writer) (int, error) {
  241. buf := bytes.NewBuffer(make([]byte, 0, len(m.Data)+2+4))
  242. binary.Write(buf, binary.BigEndian, int32(len(m.Data)+2))
  243. binary.Write(buf, binary.BigEndian, m.Type)
  244. buf.Write(m.Data)
  245. n, err := buf.WriteTo(w)
  246. return int(n), err
  247. }
  248. // -----------------------------------------------------------------------------
  249. type RemoteCli struct {
  250. conn net.Conn
  251. raw RawMode
  252. receiveChan chan struct{}
  253. inited int32
  254. isTerminal *bool
  255. data bytes.Buffer
  256. dataM sync.Mutex
  257. }
  258. func NewRemoteCli(conn net.Conn) (*RemoteCli, error) {
  259. r := &RemoteCli{
  260. conn: conn,
  261. receiveChan: make(chan struct{}),
  262. }
  263. return r, nil
  264. }
  265. func (r *RemoteCli) MarkIsTerminal(is bool) {
  266. r.isTerminal = &is
  267. }
  268. func (r *RemoteCli) init() error {
  269. if !atomic.CompareAndSwapInt32(&r.inited, 0, 1) {
  270. return nil
  271. }
  272. if err := r.reportIsTerminal(); err != nil {
  273. return err
  274. }
  275. if err := r.reportWidth(); err != nil {
  276. return err
  277. }
  278. // register sig for width changed
  279. DefaultOnWidthChanged(func() {
  280. r.reportWidth()
  281. })
  282. return nil
  283. }
  284. func (r *RemoteCli) writeMsg(m *Message) error {
  285. r.dataM.Lock()
  286. _, err := m.WriteTo(r.conn)
  287. r.dataM.Unlock()
  288. return err
  289. }
  290. func (r *RemoteCli) Write(b []byte) (int, error) {
  291. m := NewMessage(T_DATA, b)
  292. r.dataM.Lock()
  293. _, err := m.WriteTo(r.conn)
  294. r.dataM.Unlock()
  295. return len(b), err
  296. }
  297. func (r *RemoteCli) reportWidth() error {
  298. screenWidth := GetScreenWidth()
  299. data := make([]byte, 2)
  300. binary.BigEndian.PutUint16(data, uint16(screenWidth))
  301. msg := NewMessage(T_WIDTH_REPORT, data)
  302. if err := r.writeMsg(msg); err != nil {
  303. return err
  304. }
  305. return nil
  306. }
  307. func (r *RemoteCli) reportIsTerminal() error {
  308. var isTerminal bool
  309. if r.isTerminal != nil {
  310. isTerminal = *r.isTerminal
  311. } else {
  312. isTerminal = DefaultIsTerminal()
  313. }
  314. data := make([]byte, 2)
  315. if isTerminal {
  316. binary.BigEndian.PutUint16(data, 1)
  317. } else {
  318. binary.BigEndian.PutUint16(data, 0)
  319. }
  320. msg := NewMessage(T_ISTTY_REPORT, data)
  321. if err := r.writeMsg(msg); err != nil {
  322. return err
  323. }
  324. return nil
  325. }
  326. func (r *RemoteCli) readLoop() {
  327. buf := bufio.NewReader(r.conn)
  328. for {
  329. msg, err := ReadMessage(buf)
  330. if err != nil {
  331. break
  332. }
  333. switch msg.Type {
  334. case T_ERAW:
  335. r.raw.Exit()
  336. case T_RAW:
  337. r.raw.Enter()
  338. case T_DATA:
  339. os.Stdout.Write(msg.Data)
  340. }
  341. }
  342. }
  343. func (r *RemoteCli) ServeBy(source io.Reader) error {
  344. if err := r.init(); err != nil {
  345. return err
  346. }
  347. go func() {
  348. defer r.Close()
  349. for {
  350. n, _ := io.Copy(r, source)
  351. if n == 0 {
  352. break
  353. }
  354. }
  355. }()
  356. defer r.raw.Exit()
  357. r.readLoop()
  358. return nil
  359. }
  360. func (r *RemoteCli) Close() {
  361. r.writeMsg(NewMessage(T_EOF, nil))
  362. }
  363. func (r *RemoteCli) Serve() error {
  364. return r.ServeBy(os.Stdin)
  365. }
  366. func ListenRemote(n, addr string, cfg *Config, h func(*Instance), onListen ...func(net.Listener) error) error {
  367. ln, err := net.Listen(n, addr)
  368. if err != nil {
  369. return err
  370. }
  371. if len(onListen) > 0 {
  372. if err := onListen[0](ln); err != nil {
  373. return err
  374. }
  375. }
  376. for {
  377. conn, err := ln.Accept()
  378. if err != nil {
  379. break
  380. }
  381. go func() {
  382. defer conn.Close()
  383. rl, err := HandleConn(*cfg, conn)
  384. if err != nil {
  385. return
  386. }
  387. h(rl)
  388. }()
  389. }
  390. return nil
  391. }
  392. func HandleConn(cfg Config, conn net.Conn) (*Instance, error) {
  393. r, err := NewRemoteSvr(conn)
  394. if err != nil {
  395. return nil, err
  396. }
  397. r.HandleConfig(&cfg)
  398. rl, err := NewEx(&cfg)
  399. if err != nil {
  400. return nil, err
  401. }
  402. return rl, nil
  403. }
  404. func DialRemote(n, addr string) error {
  405. conn, err := net.Dial(n, addr)
  406. if err != nil {
  407. return err
  408. }
  409. defer conn.Close()
  410. cli, err := NewRemoteCli(conn)
  411. if err != nil {
  412. return err
  413. }
  414. return cli.Serve()
  415. }