
Source file src/code.rocketnine.space/tslocum/sshtargate/portal/portal.go

Documentation: code.rocketnine.space/tslocum/sshtargate/portal

     1  // Package portal provides SSH portals to applications.
     2  package portal
     4  import (
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"os/exec"
    12  	"path"
    13  	"syscall"
    14  	"time"
    15  	"unsafe"
    17  	"github.com/creack/pty"
    18  	"github.com/gliderlabs/ssh"
    19  	gossh "golang.org/x/crypto/ssh"
    20  )
    22  const (
    23  	// ListenTimeout is the maximum time to start listening on an address.
    24  	ListenTimeout = 1 * time.Second
    26  	// IdleTimeout is the maximum time for a connection to be inactive.
    27  	IdleTimeout = 1 * time.Minute
    28  )
// Portal is an SSH portal to an application.
//
// A Portal is normally obtained from New, which fills in every field and
// starts Server listening on Address.
type Portal struct {
	// Name identifies the portal; New uses it in log messages.
	Name string
	// Address is the listen address given to the SSH server.
	Address string
	// Command is the program (first element) and its arguments that are
	// started for each SSH session.
	Command []string
	// Server is the underlying SSH server; Close and Shutdown act on it.
	Server *ssh.Server
}
    38  // New opens an SSH portal to an application.
    39  func New(name string, address string, command []string) (*Portal, error) {
    40  	if address == "" {
    41  		return nil, errors.New("no address supplied")
    42  	} else if command == nil || command[0] == "" {
    43  		return nil, errors.New("no command supplied")
    44  	}
    46  	server := &ssh.Server{
    47  		Addr:        address,
    48  		IdleTimeout: IdleTimeout,
    49  		Handler: func(sshSession ssh.Session) {
    50  			ptyReq, winCh, isPty := sshSession.Pty()
    51  			if !isPty {
    52  				io.WriteString(sshSession, "failed to start command: non-interactive terminals are not supported\n")
    53  				sshSession.Exit(1)
    54  				return
    55  			}
    57  			cmdCtx, cancelCmd := context.WithCancel(sshSession.Context())
    58  			defer cancelCmd()
    60  			var args []string
    61  			if len(command) > 1 {
    62  				args = command[1:]
    63  			}
    64  			cmd := exec.CommandContext(cmdCtx, command[0], args...)
    66  			cmd.Env = append(sshSession.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
    68  			stderr, err := cmd.StderrPipe()
    69  			if err != nil {
    70  				log.Printf("error: failed to create stderr pipe for portal %s: %s", name, err)
    71  				return
    72  			}
    73  			go func() {
    74  				io.Copy(sshSession.Stderr(), stderr)
    75  			}()
    77  			f, err := pty.Start(cmd)
    78  			if err != nil {
    79  				io.WriteString(sshSession, fmt.Sprintf("failed to start command: failed to initialize pseudo-terminal: %s\n", err))
    80  				sshSession.Exit(1)
    81  				return
    82  			}
    83  			go func() {
    84  				for win := range winCh {
    85  					setWinsize(f, win.Width, win.Height)
    86  				}
    87  			}()
    89  			go func() {
    90  				io.Copy(f, sshSession)
    91  			}()
    92  			io.Copy(sshSession, f)
    94  			f.Close()
    95  			cmd.Wait()
    96  		},
    97  		PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
    98  			return true
    99  		},
   100  		PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
   101  			return true
   102  		},
   103  		PasswordHandler: func(ctx ssh.Context, password string) bool {
   104  			return true
   105  		},
   106  		KeyboardInteractiveHandler: func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
   107  			return true
   108  		},
   109  	}
   111  	homeDir, err := os.UserHomeDir()
   112  	if err != nil {
   113  		return nil, fmt.Errorf("failed to retrieve user home dir: %s", err)
   114  	}
   116  	keyPath := path.Join(homeDir, ".ssh", "id_ed25519")
   117  	_, err = os.Stat(keyPath)
   118  	if os.IsNotExist(err) {
   119  		keyPath = path.Join(homeDir, ".ssh", "id_rsa")
   120  		_, err = os.Stat(keyPath)
   121  		if os.IsNotExist(err) {
   122  			keyPath = ""
   123  			log.Println("WARNING: no host key found in ~/.ssh, this will result in key verification errors")
   124  		}
   125  	}
   126  	if keyPath != "" {
   127  		err = server.SetOption(ssh.HostKeyFile(keyPath))
   128  		if err != nil {
   129  			return nil, fmt.Errorf("failed to set host key file: %s", err)
   130  		}
   131  	}
   133  	t := time.NewTimer(ListenTimeout)
   134  	errs := make(chan error)
   135  	go func() {
   136  		err := server.ListenAndServe()
   137  		if err != nil {
   138  			errs <- fmt.Errorf("failed to start SSH server: %s", err)
   139  		}
   140  	}()
   141  	select {
   142  	case err = <-errs:
   143  		return nil, err
   144  	case <-t.C:
   145  		// Server started
   146  	}
   148  	p := Portal{Name: name, Address: address, Command: command, Server: server}
   150  	return &p, nil
   151  }
   153  // Close closes the portal immediately.
   154  func (p *Portal) Close() {
   155  	p.Server.Close()
   156  }
   158  // Shutdown closes the portal without interrupting active connections.
   159  func (p *Portal) Shutdown() {
   160  	p.Server.Shutdown(context.Background())
   161  }
   163  func setWinsize(f *os.File, w, h int) {
   164  	syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
   165  		uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
   166  }

View as plain text