diff --git a/cmd/sst/tunnel.go b/cmd/sst/tunnel.go index a03cb18786..4f65bc1e21 100644 --- a/cmd/sst/tunnel.go +++ b/cmd/sst/tunnel.go @@ -1,10 +1,12 @@ package main import ( + "encoding/json" "fmt" "io" "log/slog" "os" + "os/exec" "os/user" "strings" @@ -30,7 +32,7 @@ var CmdTunnel = &cli.Command{ "```", "", "If your app has a VPC with `bastion` enabled, you can use this to connect to it.", - "This will forward traffic from the following ranges over SSH:", + "This will forward traffic from the following ranges using either SSH or SSM, depending on your bastion configuration:", "- `10.0.4.0/22`", "- `10.0.12.0/22`", "- `10.0.0.0/22`", @@ -39,7 +41,7 @@ var CmdTunnel = &cli.Command{ "The tunnel allows your local machine to access resources that are in the VPC.", "", ":::note", - "The tunnel is only available for apps that have a VPC with `bastion` enabled.", + "The tunnel is only available for apps that have a VPC with `bastion` enabled, or apps that have a Bastion component", ":::", "", "If you are running `sst dev`, this tunnel will be started automatically under the", @@ -57,6 +59,11 @@ var CmdTunnel = &cli.Command{ "", "This needs a network interface on your local machine. You can create this", "with the `sst tunnel install` command.", + "", + ":::note", + "When using the Bastion component in SSM mode, the tunnel requires the AWS Session Manager Plugin to be installed.", + "https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-install-plugin.html", + ":::", }, "\n"), }, Run: func(c *cli.Cli) error { @@ -104,41 +111,117 @@ var CmdTunnel = &cli.Command{ if len(completed.Tunnels) == 0 { return util.NewReadableError(nil, "No tunnels found for stage "+stage) } - var tun project.Tunnel - for _, item := range completed.Tunnels { - tun = item + + var ssmConfigs []tunnel.SSMConfig + var sshConfigs []tunnel.SSHConfig + var allSubnets []string + + for name, tun := range completed.Tunnels { + mode := tun.Mode + + // backwards compatible for vpc bastion v1 + if mode == "" { + mode = "ssh" + } + + if mode == "ssm" { + if tun.InstanceID == "" { + slog.Warn("SSM tunnel missing instance ID, skipping", "name", name) + continue + } + if tun.Region == "" { + slog.Warn("SSM tunnel missing region, skipping", "name", name) + continue + } + ssmConfigs = append(ssmConfigs, tunnel.SSMConfig{ + InstanceID: tun.InstanceID, + Region: tun.Region, + Subnets: tun.Subnets, + }) + } else if mode == "ssh" { + if tun.IP == "" || tun.PrivateKey == "" { + slog.Warn("SSH tunnel missing IP or private key, skipping", "name", name) + continue + } + sshConfigs = append(sshConfigs, tunnel.SSHConfig{ + Host: tun.IP, + Username: tun.Username, + PrivateKey: tun.PrivateKey, + Subnets: tun.Subnets, + }) + } + + allSubnets = append(allSubnets, tun.Subnets...) + } + + if len(ssmConfigs) == 0 && len(sshConfigs) == 0 { + return util.NewReadableError(nil, "No tunnels found. Make sure you have a bastion deployed.") } - subnets := strings.Join(tun.Subnets, ",") + + if len(ssmConfigs) > 0 { + if _, err := exec.LookPath("session-manager-plugin"); err != nil { + return util.NewReadableError(nil, "AWS Session Manager Plugin is required for SSM tunnels but was not found.\n\nInstall it from: https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-install-plugin.html\n\nAlternatively, you can use SSH mode by setting `ssm: false` on your Bastion component.") + } + } + + args := []string{ + "-n", "-E", + tunnel.BINARY_PATH, "tunnel", "start", + "--subnets", strings.Join(allSubnets, ","), + "--print-logs", + } + + if len(ssmConfigs) > 0 { + ssmJSON, err := json.Marshal(ssmConfigs) + if err != nil { + return fmt.Errorf("failed to serialize SSM config: %w", err) + } + args = append(args, "--ssm-config", string(ssmJSON)) + } + + if len(sshConfigs) > 0 { + sshJSON, err := json.Marshal(sshConfigs) + if err != nil { + return fmt.Errorf("failed to serialize SSH config: %w", err) + } + args = append(args, "--ssh-config", string(sshJSON)) + } + // run as root tunnelCmd := process.CommandContext( c.Context, - "sudo", "-n", "-E", - tunnel.BINARY_PATH, "tunnel", "start", - "--subnets", subnets, - "--host", tun.IP, - "--user", tun.Username, - "--print-logs", + "sudo", + args..., ) tunnelCmd.Env = append( os.Environ(), "SST_SKIP_LOCAL=true", "SST_SKIP_DEPENDENCY_CHECK=true", - "SSH_PRIVATE_KEY="+tun.PrivateKey, "SST_LOG="+strings.ReplaceAll(os.Getenv("SST_LOG"), ".log", "_sudo.log"), ) tunnelCmd.Stdout = os.Stdout slog.Info("starting tunnel", "cmd", tunnelCmd.Args) fmt.Println(ui.TEXT_HIGHLIGHT_BOLD.Render("Tunnel")) fmt.Println() - fmt.Print(ui.TEXT_HIGHLIGHT_BOLD.Render("▤")) - fmt.Println(ui.TEXT_NORMAL.Render(" " + tun.IP)) - fmt.Println() - fmt.Print(ui.TEXT_SUCCESS_BOLD.Render("➜")) - fmt.Println(ui.TEXT_NORMAL.Render(" Ranges")) - for _, subnet := range tun.Subnets { - fmt.Println(ui.TEXT_DIM.Render(" " + subnet)) + + for _, cfg := range ssmConfigs { + fmt.Print(ui.TEXT_HIGHLIGHT_BOLD.Render("▤")) + fmt.Println(ui.TEXT_NORMAL.Render(" " + cfg.InstanceID + " (SSM, " + cfg.Region + ")")) + for _, subnet := range cfg.Subnets { + fmt.Println(ui.TEXT_DIM.Render(" " + subnet)) + } + fmt.Println() } - fmt.Println() + + for _, cfg := range sshConfigs { + fmt.Print(ui.TEXT_HIGHLIGHT_BOLD.Render("▤")) + fmt.Println(ui.TEXT_NORMAL.Render(" " + cfg.Host + " (SSH)")) + for _, subnet := range cfg.Subnets { + fmt.Println(ui.TEXT_DIM.Render(" " + subnet)) + } + fmt.Println() + } + fmt.Println(ui.TEXT_DIM.Render("Waiting for connections...")) fmt.Println() stderr, _ := tunnelCmd.StderrPipe() @@ -201,56 +284,61 @@ var CmdTunnel = &cli.Command{ Name: "subnets", Type: "string", Description: cli.Description{ - Short: "The subnet to use for the tunnel", - Long: "The subnet to use for the tunnel", - }, - }, - { - Name: "host", - Type: "string", - Description: cli.Description{ - Short: "The host to use for the tunnel", - Long: "The host to use for the tunnel", + Short: "The subnets to route through the tunnel", + Long: "The subnets to route through the tunnel", }, }, { - Name: "port", + Name: "ssm-config", Type: "string", Description: cli.Description{ - Short: "The port to use for the tunnel", - Long: "The port to use for the tunnel", + Short: "JSON-encoded SSM tunnel configurations", + Long: "JSON-encoded SSM tunnel configurations", }, }, { - Name: "user", + Name: "ssh-config", Type: "string", Description: cli.Description{ - Short: "The user to use for the tunnel", - Long: "The user to use for the tunnel", + Short: "JSON-encoded SSH tunnel configurations", + Long: "JSON-encoded SSH tunnel configurations", }, }, }, Run: func(c *cli.Cli) error { subnets := strings.Split(c.String("subnets"), ",") - host := c.String("host") - port := c.String("port") - user := c.String("user") - if port == "" { - port = "22" + ssmJSON := c.String("ssm-config") + sshJSON := c.String("ssh-config") + + var ssmConfigs []tunnel.SSMConfig + var sshConfigs []tunnel.SSHConfig + + if ssmJSON != "" { + if err := json.Unmarshal([]byte(ssmJSON), &ssmConfigs); err != nil { + return util.NewReadableError(err, "failed to parse SSM configuration") + } + } + + if sshJSON != "" { + if err := json.Unmarshal([]byte(sshJSON), &sshConfigs); err != nil { + return util.NewReadableError(err, "failed to parse SSH configuration") + } } - slog.Info("starting tunnel", "subnet", subnets, "host", host, "port", port) + + if len(ssmConfigs) == 0 && len(sshConfigs) == 0 { + return util.NewReadableError(nil, "at least one SSM or SSH tunnel is required") + } + + slog.Info("starting tunnel", "subnets", subnets, "ssm", len(ssmConfigs), "ssh", len(sshConfigs)) err := tunnel.Start(subnets...) + defer tunnel.Stop() if err != nil { return err } - defer tunnel.Stop() slog.Info("tunnel started") - err = tunnel.StartProxy( - c.Context, - user, - host+":"+port, - []byte(os.Getenv("SSH_PRIVATE_KEY")), - ) + + err = tunnel.StartProxy(c.Context, ssmConfigs, sshConfigs) + if err != nil { slog.Error("failed to start tunnel", "error", err) } diff --git a/pkg/project/completed.go b/pkg/project/completed.go index 7a7d25871b..ff4a728708 100644 --- a/pkg/project/completed.go +++ b/pkg/project/completed.go @@ -96,18 +96,32 @@ func getCompletedEvent(ctx context.Context, passphrase string, workdir *PulumiWo } if match, ok := outputs["_tunnel"].(map[string]interface{}); ok { - ip, ipOk := match["ip"].(string) - username, usernameOk := match["username"].(string) - privateKey, privateKeyOk := match["privateKey"].(string) - if !ipOk || !usernameOk || !privateKeyOk { - continue + mode, _ := match["mode"].(string) + if mode == "" { + mode = "ssh" } + tunnel := Tunnel{ - IP: ip, - Username: username, - PrivateKey: privateKey, - Subnets: []string{}, + Mode: mode, + Subnets: []string{}, + } + + if ip, ok := match["ip"].(string); ok { + tunnel.IP = ip + } + if username, ok := match["username"].(string); ok { + tunnel.Username = username } + if privateKey, ok := match["privateKey"].(string); ok { + tunnel.PrivateKey = privateKey + } + if instanceId, ok := match["instanceId"].(string); ok { + tunnel.InstanceID = instanceId + } + if region, ok := match["region"].(string); ok { + tunnel.Region = region + } + if subnets, ok := match["subnets"].([]interface{}); ok { for _, subnet := range subnets { if s, ok := subnet.(string); ok { diff --git a/pkg/project/stack.go b/pkg/project/stack.go index 0dd0100548..3738dd2a6a 100644 --- a/pkg/project/stack.go +++ b/pkg/project/stack.go @@ -86,6 +86,9 @@ type Tunnel struct { Username string `json:"username"` PrivateKey string `json:"privateKey"` Subnets []string `json:"subnets"` + InstanceID string `json:"instanceId"` + Region string `json:"region"` + Mode string `json:"mode"` // "ssh" or "ssm" } type ImportDiff struct { diff --git a/pkg/tunnel/proxy.go b/pkg/tunnel/proxy.go index 3540c670ca..b8ddb36df6 100644 --- a/pkg/tunnel/proxy.go +++ b/pkg/tunnel/proxy.go @@ -2,44 +2,379 @@ package tunnel import ( "context" + "encoding/json" "fmt" + "io" + "log/slog" "net" + "os/exec" + "strconv" + "sync" + "sync/atomic" + "time" "github.com/armon/go-socks5" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/sst/sst/v3/cmd/sst/mosaic/ui" "golang.org/x/crypto/ssh" ) -func StartProxy(ctx context.Context, username string, host string, key []byte) error { - signer, err := ssh.ParsePrivateKey(key) +type SSMConfig struct { + InstanceID string `json:"instanceId"` + Region string `json:"region"` + Subnets []string `json:"subnets"` +} + +type SSHConfig struct { + Host string `json:"host"` + Username string `json:"username"` + PrivateKey string `json:"privateKey"` + Subnets []string `json:"subnets"` +} + +type proxy struct { + ssm []*ssmEntry + ssh []*sshEntry + nextPort atomic.Int32 + ctx context.Context +} + +type ssmEntry struct { + config SSMConfig + manager *ssmSessionManager + networks []*net.IPNet +} + +type sshEntry struct { + config SSHConfig + networks []*net.IPNet + sshClient *ssh.Client + mu sync.Mutex +} + +type ssmSessionManager struct { + client *ssm.Client + instanceID string + region string + sessions sync.Map + nextPort *atomic.Int32 + ctx context.Context +} + +type ssmSession struct { + localPort int + cmd *exec.Cmd + sessionID string + cancel context.CancelFunc + lastUsed time.Time + mu sync.Mutex +} + +func newProxy(ctx context.Context, ssmConfigs []SSMConfig, sshConfigs []SSHConfig) (*proxy, error) { + p := &proxy{ + ssm: make([]*ssmEntry, 0, len(ssmConfigs)), + ssh: make([]*sshEntry, 0, len(sshConfigs)), + ctx: ctx, + } + p.nextPort.Store(10000) + + for _, cfg := range ssmConfigs { + awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(cfg.Region)) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config for region %s: %w", cfg.Region, err) + } + + networks := parseCIDRs(cfg.Subnets) + manager := &ssmSessionManager{ + client: ssm.NewFromConfig(awsCfg), + instanceID: cfg.InstanceID, + region: cfg.Region, + nextPort: &p.nextPort, + ctx: ctx, + } + + p.ssm = append(p.ssm, &ssmEntry{ + config: cfg, + manager: manager, + networks: networks, + }) + + go manager.cleanupIdleSessions() + } + + for _, cfg := range sshConfigs { + p.ssh = append(p.ssh, &sshEntry{ + config: cfg, + networks: parseCIDRs(cfg.Subnets), + }) + } + + return p, nil +} + +func parseCIDRs(subnets []string) []*net.IPNet { + networks := make([]*net.IPNet, 0, len(subnets)) + for _, cidr := range subnets { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + slog.Warn("failed to parse CIDR", "cidr", cidr, "error", err) + continue + } + networks = append(networks, network) + } + return networks +} + +func (p *proxy) dial(ctx context.Context, network, addr string) (net.Conn, error) { + host, portStr, err := net.SplitHostPort(addr) if err != nil { - return err + return nil, fmt.Errorf("invalid address %s: %w", addr, err) + } + + ip := net.ParseIP(host) + if ip == nil { + ips, err := net.LookupIP(host) + if err != nil || len(ips) == 0 { + return nil, fmt.Errorf("failed to resolve hostname %s: %w", host, err) + } + ip = ips[0] + } + + for _, entry := range p.ssm { + for _, net := range entry.networks { + if net.Contains(ip) { + port, _ := strconv.Atoi(portStr) + fmt.Println(ui.TEXT_INFO_BOLD.Render("| "), ui.TEXT_NORMAL.Render("Tunneling", network, addr, "via", entry.config.InstanceID)) + return entry.manager.dialThroughSSM(ctx, host, port) + } + } + } + + for _, entry := range p.ssh { + for _, net := range entry.networks { + if net.Contains(ip) { + fmt.Println(ui.TEXT_INFO_BOLD.Render("| "), ui.TEXT_NORMAL.Render("Tunneling", network, addr, "via SSH", entry.config.Host)) + sshClient, err := entry.getOrCreateSSHClient() + if err != nil { + return nil, err + } + return sshClient.Dial(network, addr) + } + } + } + + return nil, fmt.Errorf("no tunnel found for IP %s", ip) +} + +func (e *sshEntry) getOrCreateSSHClient() (*ssh.Client, error) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.sshClient != nil { + _, _, err := e.sshClient.SendRequest("keepalive@openssh.com", true, nil) + if err == nil { + return e.sshClient, nil + } + e.sshClient.Close() + e.sshClient = nil + } + + signer, err := ssh.ParsePrivateKey([]byte(e.config.PrivateKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) } - config := &ssh.ClientConfig{ - User: username, + + sshConfig := &ssh.ClientConfig{ + User: e.config.Username, Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, + } + + host := e.config.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = net.JoinHostPort(host, "22") } - sshClient, err := ssh.Dial("tcp", host, config) + + client, err := ssh.Dial("tcp", host, sshConfig) + if err != nil { + return nil, fmt.Errorf("failed to connect to SSH server %s: %w", host, err) + } + + e.sshClient = client + return client, nil +} + +func (m *ssmSessionManager) cleanupIdleSessions() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + m.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*ssmSession); ok { + m.terminateSession(session) + } + return true + }) + return + case <-ticker.C: + now := time.Now() + m.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*ssmSession); ok { + session.mu.Lock() + if now.Sub(session.lastUsed) > 5*time.Minute { + m.sessions.Delete(key) + go m.terminateSession(session) + } + session.mu.Unlock() + } + return true + }) + } + } +} + +func (m *ssmSessionManager) terminateSession(session *ssmSession) { + if session.cancel != nil { + session.cancel() + } + if session.cmd != nil && session.cmd.Process != nil { + session.cmd.Process.Kill() + } + if session.sessionID != "" { + m.client.TerminateSession(context.Background(), &ssm.TerminateSessionInput{ + SessionId: aws.String(session.sessionID), + }) + } +} + +func (m *ssmSessionManager) dialThroughSSM(ctx context.Context, remoteHost string, remotePort int) (net.Conn, error) { + key := fmt.Sprintf("%s:%d", remoteHost, remotePort) + + if val, ok := m.sessions.Load(key); ok { + session := val.(*ssmSession) + session.mu.Lock() + session.lastUsed = time.Now() + session.mu.Unlock() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.localPort)) + if err == nil { + return conn, nil + } + m.sessions.Delete(key) + m.terminateSession(session) + } + + localPort := int(m.nextPort.Add(1)) + + input := &ssm.StartSessionInput{ + Target: aws.String(m.instanceID), + DocumentName: aws.String("AWS-StartPortForwardingSessionToRemoteHost"), + Parameters: map[string][]string{ + "host": {remoteHost}, + "portNumber": {strconv.Itoa(remotePort)}, + "localPortNumber": {strconv.Itoa(localPort)}, + }, + } + + output, err := m.client.StartSession(ctx, input) + if err != nil { + return nil, fmt.Errorf("failed to start SSM session: %w", err) + } + + slog.Info("SSM session started", "sessionId", *output.SessionId, "remoteHost", remoteHost, "remotePort", remotePort, "localPort", localPort) + + sessionJSON, _ := json.Marshal(output) + inputJSON, _ := json.Marshal(input) + + pluginCtx, cancel := context.WithCancel(m.ctx) + pluginCmd := exec.CommandContext(pluginCtx, "session-manager-plugin", + string(sessionJSON), + m.region, + "StartSession", + "", + string(inputJSON), + fmt.Sprintf("https://ssm.%s.amazonaws.com", m.region), + ) + pluginCmd.Stdout = io.Discard + pluginCmd.Stderr = io.Discard + + if err := pluginCmd.Start(); err != nil { + cancel() + m.client.TerminateSession(context.Background(), &ssm.TerminateSessionInput{ + SessionId: output.SessionId, + }) + return nil, fmt.Errorf("failed to start session-manager-plugin: %w", err) + } + + session := &ssmSession{ + localPort: localPort, + cmd: pluginCmd, + sessionID: *output.SessionId, + cancel: cancel, + lastUsed: time.Now(), + } + + if err := m.waitForPort(ctx, localPort); err != nil { + m.terminateSession(session) + return nil, err + } + + m.sessions.Store(key, session) + + go func() { + pluginCmd.Wait() + m.sessions.Delete(key) + }() + + return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort)) +} + +func (m *ssmSessionManager) waitForPort(ctx context.Context, port int) error { + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 100*time.Millisecond) + if err == nil { + conn.Close() + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(100 * time.Millisecond): + } + } + return fmt.Errorf("timeout waiting for local port %d", port) +} + +func StartProxy(ctx context.Context, ssmConfigs []SSMConfig, sshConfigs []SSHConfig) error { + if len(ssmConfigs) == 0 && len(sshConfigs) == 0 { + return fmt.Errorf("no tunnels configured") + } + + p, err := newProxy(ctx, ssmConfigs, sshConfigs) if err != nil { return err } - defer sshClient.Close() + server, err := socks5.New(&socks5.Config{ - Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { - fmt.Println(ui.TEXT_INFO_BOLD.Render(("| "), ui.TEXT_NORMAL.Render("Tunneling", network, addr))) - return sshClient.Dial(network, addr) - }, + Dial: p.dial, }) if err != nil { - return err + return fmt.Errorf("failed to create SOCKS5 server: %w", err) } + errChan := make(chan error, 1) go func() { - errChan <- server.ListenAndServe("tcp", fmt.Sprintf("%s:%d", "127.0.0.1", 1080)) + errChan <- server.ListenAndServe("tcp", "127.0.0.1:1080") }() + select { case err := <-errChan: return err diff --git a/pkg/tunnel/tunnel.go b/pkg/tunnel/tunnel.go index 82e72ec6c2..ea3fc8da4b 100644 --- a/pkg/tunnel/tunnel.go +++ b/pkg/tunnel/tunnel.go @@ -4,12 +4,16 @@ import ( "fmt" "log/slog" "os" + "strings" "github.com/xjasonlyu/tun2socks/v2/engine" "github.com/sst/sst/v3/pkg/process" ) +// Version of the tunnel binary. Bump this when tunnel code changes and needs to be re-installed. +const Version = "2" + func IsRunning() bool { return impl.isRunning() } @@ -25,12 +29,25 @@ type tunnelPlatform interface { var impl tunnelPlatform var BINARY_PATH = "/opt/sst/tunnel" +var VERSION_PATH = "/opt/sst/tunnel.version" func NeedsInstall() bool { - if _, err := os.Stat(BINARY_PATH); err == nil { - return false + if _, err := os.Stat(BINARY_PATH); err != nil { + return true + } + + checkVersion := Version + if testVersion := os.Getenv("SST_TEST_TUNNEL_VERSION"); testVersion != "" { + checkVersion = testVersion } - return true + + versionBytes, err := os.ReadFile(VERSION_PATH) + if err != nil { + return true + } + + installedVersion := strings.TrimSpace(string(versionBytes)) + return installedVersion != checkVersion } func Install() error { diff --git a/pkg/tunnel/tunnel_darwin.go b/pkg/tunnel/tunnel_darwin.go index 2c513efbf0..14ff70ae09 100644 --- a/pkg/tunnel/tunnel_darwin.go +++ b/pkg/tunnel/tunnel_darwin.go @@ -41,6 +41,15 @@ func (p *darwinPlatform) install() error { return err } err = os.Chmod(BINARY_PATH, 0755) + if err != nil { + return err + } + + err = os.WriteFile(VERSION_PATH, []byte(Version), 0644) + if err != nil { + return err + } + user := os.Getenv("SUDO_USER") sudoersPath := "/etc/sudoers.d/sst-" + user slog.Info("creating sudoers file", "path", sudoersPath) diff --git a/pkg/tunnel/tunnel_linux.go b/pkg/tunnel/tunnel_linux.go index e9b77d4af5..80832ca5f6 100644 --- a/pkg/tunnel/tunnel_linux.go +++ b/pkg/tunnel/tunnel_linux.go @@ -44,6 +44,15 @@ func (p *linuxPlatform) install() error { return err } err = os.Chmod(BINARY_PATH, 0755) + if err != nil { + return err + } + + err = os.WriteFile(VERSION_PATH, []byte(Version), 0644) + if err != nil { + return err + } + user := os.Getenv("SUDO_USER") if isNixOS() { diff --git a/pkg/tunnel/tunnel_windows.go b/pkg/tunnel/tunnel_windows.go index 6b091c1044..c7815ce09f 100644 --- a/pkg/tunnel/tunnel_windows.go +++ b/pkg/tunnel/tunnel_windows.go @@ -10,6 +10,7 @@ type windowsPlatform struct{} func init() { // Use Windows-style path BINARY_PATH = filepath.Join(os.Getenv("PROGRAMFILES"), "SST", "tunnel.exe") + VERSION_PATH = filepath.Join(os.Getenv("PROGRAMFILES"), "SST", "tunnel.version") impl = &windowsPlatform{} } @@ -29,5 +30,6 @@ func (p *windowsPlatform) isRunning() bool { // Override Install for Windows func (p *windowsPlatform) install() error { // Windows-specific installation + // TODO: implement version file writing when Windows tunnel is implemented return nil } diff --git a/platform/src/components/aws/bastion.ts b/platform/src/components/aws/bastion.ts new file mode 100644 index 0000000000..ba9e1f85c0 --- /dev/null +++ b/platform/src/components/aws/bastion.ts @@ -0,0 +1,586 @@ +import { + all, + ComponentResourceOptions, + interpolate, + Output, + output, +} from "@pulumi/pulumi"; +import { Component, Transform, transform } from "../component.js"; +import { Input } from "../input.js"; +import { + ec2, + getPartitionOutput, + getRegionOutput, + iam, + ssm, +} from "@pulumi/aws"; +import { Vpc } from "./vpc.js"; +import { VisibleError } from "../error.js"; +import { PrivateKey } from "@pulumi/tls"; + +export interface BastionArgs { + /** + * The VPC to launch the bastion host in. + * + * @example + * Create a VPC component. + * + * ```js + * const myVpc = new sst.aws.Vpc("MyVpc"); + * ``` + * + * And pass it in. + * + * ```js + * { + * vpc: myVpc + * } + * ``` + * + * Or pass in a custom VPC configuration. + * + * ```js + * { + * vpc: { + * id: "vpc-0d19d2b8ca2b268a1", + * routeSubnets: ["subnet-0b6a2b73896dc8c4c", "subnet-021f7e8f975b2b9c2"], + * subnet: "subnet-0b6a2b73896dc8c4c" + * } + * } + * ``` + * + * When using SSH mode, you need to provide a public subnet for the `subnet` arg. For SSM, this can be public or nat enabled (required for ssm agent to function). + * + * :::note + * A security group ingress rule will be created which allows internet access over port 22. It is recommended you use the ssm mode which does not require opening port 22 or having a public instance + * ::: + * + */ + vpc: + | Vpc + | Input<{ + /** + * The ID of the VPC. + */ + id: Input; + /** + * A list of subnet IDs in the VPC. Traffic to these subnets will be routed + * through the bastion. + */ + routeSubnets: Input[]>; + /** + * The subnet to launch the bastion host in. When in SSH mode, this must be a public subnet. + * In SSM mode, this needs to be subnet with an egress route to the internet; either a public subnet or a private subnet with a nat gateway. + * This is required for the SSM agent to function. + */ + subnet: Input; + }>; + /** + * Enable SSM mode for the bastion host. + * + * Use SSM manager to create a tunneled ssm session. + * This is more secure than SSH as it doesn't require any open ports or a public instance. + * + * @default false + * @example + * ```ts + * { + * ssm: true + * } + * ``` + */ + ssm?: Input; + /** + * Provide an existing IAM instance profile to use for the bastion host. + * + * By default, the component creates a new instance profile with the + * `AmazonSSMManagedInstanceCore` managed policy attached. + * + * @example + * ```ts + * { + * instanceProfile: "my-instance-profile-name" + * } + * ``` + */ + instanceProfile?: Input; + /** + * [Transform](/docs/components#transform) how this component creates its underlying + * resources. + */ + transform?: { + /** + * Transform the EC2 security group for the bastion host. + */ + securityGroup?: Transform; + /** + * Transform the EC2 instance for the bastion host. + */ + instance?: Transform; + }; +} + +interface BastionRef { + ref: boolean; + instanceId: Input; +} + +/** + * The `Bastion` component lets you add a bastion host to a VPC for secure access to + * private resources. This standalone component is similar to the VPC component bastion, but allows you to tunnel to a non-sst vpc and also has a more secure SSM mode. + * + * By default, the bastion uses SSH for tunneling. Which is the same behaviour as the VPC component bastion. + * This new component however also has an opt-in SSM mode, this doesn't require opening + * port 22, having a public IP, or managing SSH keys. This mode is recommended for teams which have strict network security rules. You can read more about SSM port forwarding [here](https://aws.amazon.com/blogs/aws/new-port-forwarding-using-aws-system-manager-sessions-manager/) + * + * :::note + * SSM mode requires the [AWS Session Manager Plugin](https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-install-plugin.html) to be installed on your local machine. + * ::: + * + * SSH mode doesn't require the Session Manager Plugin, but it opens port 22 to the + * internet and requires a public subnet. + * + * @example + * + * #### Create a bastion with an SST VPC + * + * ```ts title="sst.config.ts" + * const vpc = new sst.aws.Vpc("MyVpc"); + * const bastion = new sst.aws.Bastion("MyBastion", { vpc }); + * ``` + * + * #### Create a bastion with SSM mode + * + * ```ts title="sst.config.ts" + * const vpc = new sst.aws.Vpc("MyVpc"); + * const bastion = new sst.aws.Bastion("MyBastion", { + * vpc, + * ssm: true + * }); + * ``` + * + * #### Create a bastion with a custom VPC + * + * If you have an existing VPC that was not created with SST, you can still use the Bastion + * component by providing the VPC ID and subnet IDs. + * + * ```ts title="sst.config.ts" + * new sst.aws.Bastion("MyBastion", { + * vpc: { + * id: "vpc-0d19d2b8ca2b268a1", + * routeSubnets: ["subnet-0b6a2b73896dc8c4c", "subnet-021f7e8f975b2b9c2"], + * subnet: "subnet-0b6a2b73896dc8c4c" // must have route to the internet + * }, + * ssm: true + * }); + * ``` + * + * For SSH mode with a custom VPC, you also need to specify a public subnet: + * + * ```ts title="sst.config.ts" + * new sst.aws.Bastion("MyBastion", { + * vpc: { + * id: "vpc-0d19d2b8ca2b268a1", + * routeSubnets: ["subnet-0b6a2b73896dc8c4c", "subnet-021f7e8f975b2b9c2"], + * subnet: "subnet-0b6a2b73896dc8c4c" // must be public, i.e. have route from/to the internet + * } + * }); + * ``` + * + * --- + * + * ### Cost + * + * The bastion host uses a `t4g.nano` instance which costs about $3/month. + * + */ +export class Bastion extends Component { + private _instance: Output; + private _mode: Output; + private _cidrRange: Output; + private _privateKeyValue: Output; + + constructor( + name: string, + args: BastionArgs, + opts?: ComponentResourceOptions, + ) { + super(__pulumiType, name, args, opts); + + const parent = this; + const self = this; + + if (args && "ref" in args) { + const ref = reference(); + this._instance = ref.instance; + this._mode = ref.mode; + this._privateKeyValue = ref.privateKeyValue; + this._cidrRange = ref.cidrRange; + registerOutputs(); + return; + } + + const partition = getPartitionOutput({}, { parent }).partition; + const mode = output(args.ssm ? "ssm" : "ssh"); + + const vpc = normalizeVpc(); + const keyPairResult = createKeyPair(); + const securityGroup = createSecurityGroup(); + const instanceProfile = createInstanceProfile(); + const instance = createInstance(); + + this._instance = instance; + this._mode = mode; + this._privateKeyValue = output(keyPairResult?.privateKeyValue); + this._cidrRange = vpc.apply((vpc) => + all( + vpc.routeSubnets.map((id) => + ec2.getSubnetOutput({ id }, { parent }).apply((s) => s.cidrBlock), + ), + ), + ); + + registerOutputs(); + + function reference() { + const ref = args as unknown as BastionRef; + const instanceId = output(ref.instanceId); + + const instance = ec2.Instance.get( + `${name}Instance`, + instanceId, + undefined, + { parent }, + ); + + const mode = instance.tags.apply((tags) => { + if (!tags) return "ssh"; + return tags["sst:bastion-mode"] ?? "ssh"; + }); + + return mode.apply((mode) => { + const subnetData = ec2.getSubnetOutput( + { id: instance.subnetId }, + { parent }, + ); + + const vpcId = subnetData.apply((s) => s.vpcId); + const vpcSubnets = vpcId.apply((vpcId) => + ec2.getSubnetsOutput( + { filters: [{ name: "vpc-id", values: [vpcId] }] }, + { parent }, + ), + ); + + const subnets = vpcSubnets.apply((s) => + all( + s.ids.map((id) => + ec2.getSubnetOutput({ id }, { parent }).apply((s) => s.cidrBlock), + ), + ), + ); + + if (mode === "ssm") { + return { + instance: output(instance), + cidrRange: subnets, + privateKeyValue: undefined, + mode, + }; + } + + const privateKeyValue = vpcId.apply((vpcId) => { + const param = ssm.Parameter.get( + `${name}PrivateKeyValue`, + `/sst/bastionv2/${vpcId}/private-key-value`, + undefined, + { parent }, + ); + return param.value; + }); + return { + instance: output(instance), + cidrRange: subnets, + privateKeyValue, + mode, + }; + }); + } + + function normalizeVpc() { + // "vpc" is a Vpc component + if (args.vpc instanceof Vpc) { + return all([ + args.vpc.id, + args.vpc.privateSubnets, + args.vpc.publicSubnets, + ]).apply(([id, privateSubnets, publicSubnets]) => ({ + id, + routeSubnets: [...publicSubnets, ...privateSubnets], + subnet: publicSubnets[0], + })); + } + + // "vpc" is object + return output(args.vpc).apply((vpc) => { + if (!vpc.id) { + throw new VisibleError( + `Missing "vpc.id" for the "${name}" Bastion component.`, + ); + } + if (!vpc.routeSubnets?.length) { + throw new VisibleError( + `Missing "vpc.subnets" for the "${name}" Bastion component. At least one subnet is required.`, + ); + } + return { + id: vpc.id, + routeSubnets: vpc.routeSubnets, + subnet: vpc.subnet, + }; + }); + } + + function createKeyPair(): + | { + keyPair: ec2.KeyPair; + privateKeyValue: Output; + } + | undefined { + if (args.ssm === true) { + return undefined; + } + + const tlsPrivateKey = new PrivateKey( + `${name}TlsPrivateKey`, + { + algorithm: "RSA", + rsaBits: 4096, + }, + { parent }, + ); + + new ssm.Parameter( + `${name}PrivateKeyValue`, + { + name: vpc.apply((v) => `/sst/bastionv2/${v.id}/private-key-value`), + description: "Bastion host private key", + type: ssm.ParameterType.SecureString, + value: tlsPrivateKey.privateKeyOpenssh, + }, + { parent }, + ); + + const keyPair = new ec2.KeyPair( + `${name}KeyPair`, + { + publicKey: tlsPrivateKey.publicKeyOpenssh, + }, + { parent }, + ); + + return { keyPair, privateKeyValue: tlsPrivateKey.privateKeyOpenssh }; + } + + function createSecurityGroup() { + const ingress = !args.ssm + ? [ + { + protocol: "tcp", + fromPort: 22, + toPort: 22, + cidrBlocks: ["0.0.0.0/0"], + }, + ] + : []; + + return new ec2.SecurityGroup( + ...transform( + args.transform?.securityGroup, + `${name}SecurityGroup`, + { + vpcId: vpc.apply((v) => v.id), + ingress, + egress: [ + { + protocol: "-1", + fromPort: 0, + toPort: 0, + cidrBlocks: ["0.0.0.0/0"], + }, + ], + }, + { parent }, + ), + ); + } + + function createInstanceProfile() { + return output(args.instanceProfile).apply((instanceProfileName) => { + if (instanceProfileName) { + if (instanceProfileName.startsWith("arn:")) { + throw new VisibleError( + "Bastion instance profile must be a name, not an ARN.", + ); + } + + return iam.InstanceProfile.get( + `${name}InstanceProfile`, + instanceProfileName, + {}, + { parent }, + ); + } + + const role = new iam.Role( + `${name}Role`, + { + assumeRolePolicy: iam.getPolicyDocumentOutput({ + statements: [ + { + actions: ["sts:AssumeRole"], + principals: [ + { + type: "Service", + identifiers: ["ec2.amazonaws.com"], + }, + ], + }, + ], + }).json, + managedPolicyArns: [ + interpolate`arn:${partition}:iam::aws:policy/AmazonSSMManagedInstanceCore`, + ], + }, + { parent }, + ); + + return new iam.InstanceProfile( + `${name}InstanceProfile`, + { role: role.name }, + { parent }, + ); + }); + } + + function createInstance() { + const ami = ec2.getAmiOutput( + { + owners: ["amazon"], + filters: [ + { + name: "name", + // The AMI has the SSM agent pre-installed + values: ["al2023-ami-20*"], + }, + { + name: "architecture", + values: ["arm64"], + }, + ], + mostRecent: true, + }, + { parent }, + ); + + return all([vpc, instanceProfile]).apply(([vpc, instanceProfile]) => { + return new ec2.Instance( + ...transform( + args.transform?.instance, + `${name}Instance`, + { + instanceType: "t4g.nano", + ami: ami.id, + subnetId: vpc.subnet, + vpcSecurityGroupIds: [securityGroup.id], + iamInstanceProfile: instanceProfile.name, + keyName: keyPairResult?.keyPair.keyName, + associatePublicIpAddress: true, + tags: { + "sst:bastion-mode": args.ssm ? "ssm" : "ssh", + }, + }, + { parent }, + ), + ); + }); + } + + function registerOutputs() { + const region = getRegionOutput({}, { parent }).region; + self.registerOutputs({ + _tunnel: all([ + self._instance, + self._cidrRange, + self._mode, + region, + self._privateKeyValue, + ]).apply(([instance, cidrRange, mode, region, privateKeyValue]) => ({ + ip: mode === "ssh" ? instance.publicIp : undefined, + username: "ec2-user", + privateKey: privateKeyValue, + instanceId: instance.id, + region: region, + subnets: cidrRange, + mode, + })), + }); + } + } + + /** + * The public IP address of the bastion host. Only available in SSH mode. + */ + public get publicIp() { + return this._instance.publicIp; + } + + /** + * The underlying [resources](/docs/components/#nodes) this component creates. + */ + public get nodes() { + return { + /** + * The Amazon EC2 instance for the bastion host. + */ + instance: this._instance, + }; + } + + /** + * Reference an existing Bastion component with the given instance ID. This is useful when you + * create a Bastion in one stage and want to share it in another stage. + * + * @param name The name of the component. + * @param instanceId The ID of the EC2 instance. + * @param opts Resource options. + * + * @example + * Imagine you create a bastion in the `dev` stage. And in your personal stage, `frank`, + * instead of creating a new bastion, you want to reuse the one from `dev`. + * + * ```ts title="sst.config.ts" + * const bastion = $app.stage === "frank" + * ? sst.aws.Bastion.get("MyBastion", "i-1234567890abcdef0") + * : new sst.aws.Bastion("MyBastion", { vpc }); + * ``` + */ + public static get( + name: string, + instanceId: Input, + opts?: ComponentResourceOptions, + ) { + return new Bastion( + name, + { + ref: true, + instanceId, + } as unknown as BastionArgs, + opts, + ); + } +} + +const __pulumiType = "sst:aws:Bastion"; +// @ts-expect-error +Bastion.__pulumiType = __pulumiType; diff --git a/platform/src/components/aws/index.ts b/platform/src/components/aws/index.ts index c1569e0ebf..33e394b42f 100644 --- a/platform/src/components/aws/index.ts +++ b/platform/src/components/aws/index.ts @@ -6,6 +6,7 @@ export * from "./app-sync.js"; export * from "./astro.js"; export * from "./aurora.js"; export * from "./auth.js"; +export * from "./bastion.js"; export * from "./bucket.js"; export * from "./bus.js"; export * from "./cluster.js"; diff --git a/www/astro.config.mjs b/www/astro.config.mjs index 0f0b1aaa23..c3e40281ac 100644 --- a/www/astro.config.mjs +++ b/www/astro.config.mjs @@ -98,6 +98,7 @@ const sidebar = [ "docs/component/aws/queue", "docs/component/aws/vector", "docs/component/aws/aurora", + "docs/component/aws/bastion", "docs/component/aws/router", "docs/component/aws/analog", "docs/component/aws/bucket", diff --git a/www/generate.ts b/www/generate.ts index 94c5fc41de..56e20d58c1 100644 --- a/www/generate.ts +++ b/www/generate.ts @@ -2156,6 +2156,7 @@ async function buildComponents() { "../platform/src/components/aws/app-sync-resolver.ts", "../platform/src/components/aws/auth.ts", "../platform/src/components/aws/aurora.ts", + "../platform/src/components/aws/bastion.ts", "../platform/src/components/aws/bucket.ts", "../platform/src/components/aws/bucket-notification.ts", "../platform/src/components/aws/bus.ts",