1
2 package portal
3
4 import (
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "os"
11 "os/exec"
12 "path"
13 "syscall"
14 "time"
15 "unsafe"
16
17 "github.com/creack/pty"
18 "github.com/gliderlabs/ssh"
19 gossh "golang.org/x/crypto/ssh"
20 )
21
const (
	// ListenTimeout is how long New waits for the SSH server's
	// ListenAndServe to return an error before treating the server as
	// started.
	ListenTimeout = time.Second

	// IdleTimeout is assigned to the IdleTimeout field of every portal's
	// ssh.Server.
	IdleTimeout = time.Minute
)
29
30
31 type Portal struct {
32 Name string
33 Address string
34 Command []string
35 Server *ssh.Server
36 }
37
38
39 func New(name string, address string, command []string) (*Portal, error) {
40 if address == "" {
41 return nil, errors.New("no address supplied")
42 } else if command == nil || command[0] == "" {
43 return nil, errors.New("no command supplied")
44 }
45
46 server := &ssh.Server{
47 Addr: address,
48 IdleTimeout: IdleTimeout,
49 Handler: func(sshSession ssh.Session) {
50 ptyReq, winCh, isPty := sshSession.Pty()
51 if !isPty {
52 io.WriteString(sshSession, "failed to start command: non-interactive terminals are not supported\n")
53 sshSession.Exit(1)
54 return
55 }
56
57 cmdCtx, cancelCmd := context.WithCancel(sshSession.Context())
58 defer cancelCmd()
59
60 var args []string
61 if len(command) > 1 {
62 args = command[1:]
63 }
64 cmd := exec.CommandContext(cmdCtx, command[0], args...)
65
66 cmd.Env = append(sshSession.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
67
68 stderr, err := cmd.StderrPipe()
69 if err != nil {
70 log.Printf("error: failed to create stderr pipe for portal %s: %s", name, err)
71 return
72 }
73 go func() {
74 io.Copy(sshSession.Stderr(), stderr)
75 }()
76
77 f, err := pty.Start(cmd)
78 if err != nil {
79 io.WriteString(sshSession, fmt.Sprintf("failed to start command: failed to initialize pseudo-terminal: %s\n", err))
80 sshSession.Exit(1)
81 return
82 }
83 go func() {
84 for win := range winCh {
85 setWinsize(f, win.Width, win.Height)
86 }
87 }()
88
89 go func() {
90 io.Copy(f, sshSession)
91 }()
92 io.Copy(sshSession, f)
93
94 f.Close()
95 cmd.Wait()
96 },
97 PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
98 return true
99 },
100 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
101 return true
102 },
103 PasswordHandler: func(ctx ssh.Context, password string) bool {
104 return true
105 },
106 KeyboardInteractiveHandler: func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
107 return true
108 },
109 }
110
111 homeDir, err := os.UserHomeDir()
112 if err != nil {
113 return nil, fmt.Errorf("failed to retrieve user home dir: %s", err)
114 }
115
116 keyPath := path.Join(homeDir, ".ssh", "id_ed25519")
117 _, err = os.Stat(keyPath)
118 if os.IsNotExist(err) {
119 keyPath = path.Join(homeDir, ".ssh", "id_rsa")
120 _, err = os.Stat(keyPath)
121 if os.IsNotExist(err) {
122 keyPath = ""
123 log.Println("WARNING: no host key found in ~/.ssh, this will result in key verification errors")
124 }
125 }
126 if keyPath != "" {
127 err = server.SetOption(ssh.HostKeyFile(keyPath))
128 if err != nil {
129 return nil, fmt.Errorf("failed to set host key file: %s", err)
130 }
131 }
132
133 t := time.NewTimer(ListenTimeout)
134 errs := make(chan error)
135 go func() {
136 err := server.ListenAndServe()
137 if err != nil {
138 errs <- fmt.Errorf("failed to start SSH server: %s", err)
139 }
140 }()
141 select {
142 case err = <-errs:
143 return nil, err
144 case <-t.C:
145
146 }
147
148 p := Portal{Name: name, Address: address, Command: command, Server: server}
149
150 return &p, nil
151 }
152
153
154 func (p *Portal) Close() {
155 p.Server.Close()
156 }
157
158
159 func (p *Portal) Shutdown() {
160 p.Server.Shutdown(context.Background())
161 }
162
// setWinsize sets the window size of the terminal behind f to w columns by h
// rows using the TIOCSWINSZ ioctl. Both values are truncated to uint16, the
// pixel dimensions are set to zero, and any ioctl failure is ignored.
func setWinsize(f *os.File, w, h int) {
	// Field order matches the kernel's winsize: rows, cols, xpixel, ypixel.
	type winsize struct {
		rows, cols, xpixel, ypixel uint16
	}
	ws := winsize{rows: uint16(h), cols: uint16(w)}

	// The uintptr(unsafe.Pointer(...)) conversion has to appear inside the
	// Syscall argument list so that ws stays valid for the whole call.
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)),
	)
}
167
View as plain text