From e110964d6974cd603b5dd5e012a97f2f161c907a Mon Sep 17 00:00:00 2001 From: David Zucker Date: Fri, 17 Apr 2026 14:07:15 +0900 Subject: [PATCH 1/6] feat: allow remote hostnames in ssh forward ports This keeps reverse forwarding unchanged while letting forward-only port mappings target service names and other remote hostnames that resolve inside the workspace runtime. --- cmd/ssh.go | 4 +- cmd/ssh_forward_ports.go | 86 +++++++++++++++++++ cmd/ssh_forward_ports_test.go | 65 ++++++++++++++ .../up-docker-compose/up_docker_compose.go | 71 +++++++++++++++ 4 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 cmd/ssh_forward_ports.go create mode 100644 cmd/ssh_forward_ports_test.go diff --git a/cmd/ssh.go b/cmd/ssh.go index 0bf3ea1ee..8d1e82df7 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -113,7 +113,7 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command { sshCmd.Flags(). StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) "+ - "host are to be forwarded to the given host and port, or Unix socket, on the remote side.") + "host are to be forwarded to the given host, service name, and port, or Unix socket, on the remote side.") sshCmd.Flags(). StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) "+ @@ -452,7 +452,7 @@ func (cmd *SSHCmd) forwardPorts( errChan := make(chan error, len(cmd.ForwardPorts)) for _, portMapping := range cmd.ForwardPorts { - mapping, err := port.ParsePortSpec(portMapping) + mapping, err := parseForwardPortSpec(portMapping) if err != nil { return fmt.Errorf("parse port mapping: %w", err) } diff --git a/cmd/ssh_forward_ports.go b/cmd/ssh_forward_ports.go new file mode 100644 index 000000000..1a5e0e951 --- /dev/null +++ b/cmd/ssh_forward_ports.go @@ -0,0 +1,86 @@ +package cmd + +import ( + "fmt" + "net" + "strconv" + "strings" + + "github.com/skevetter/devpod/pkg/port" +) + +func parseForwardPortSpec(raw string) (port.Mapping, error) { + parts := strings.Split(raw, ":") + + switch len(parts) { + case 1, 2: + return port.ParsePortSpec(raw) + case 3: + if !isPortNumber(parts[0]) { + return port.ParsePortSpec(raw) + } + + return newForwardPortMapping("", parts[0], parts[1], parts[2]) + case 4: + if parts[0] == "" { + return port.Mapping{}, fmt.Errorf("local host is empty") + } + + return newForwardPortMapping(parts[0], parts[1], parts[2], parts[3]) + default: + return port.Mapping{}, fmt.Errorf("unexpected port format: %s", raw) + } +} + +func newForwardPortMapping(localHost, localPort, remoteHost, remotePort string) (port.Mapping, error) { + hostAddress, err := parseForwardLocalAddress(localHost, localPort) + if err != nil { + return port.Mapping{}, fmt.Errorf("parse host address: %w", err) + } + + containerAddress, err := parseForwardRemoteAddress(remoteHost, remotePort) + if err != nil { + return port.Mapping{}, fmt.Errorf("parse container address: %w", err) + } + + return port.Mapping{ + Host: hostAddress, + Container: containerAddress, + }, nil +} + +func parseForwardLocalAddress(host, rawPort string) (port.Address, error) { + return parseForwardTCPAddress(host, rawPort, false) +} + +func parseForwardRemoteAddress(host, rawPort string) (port.Address, error) { + if host == "" { + return port.Address{}, fmt.Errorf("remote host is empty") + } + + return parseForwardTCPAddress(host, rawPort, true) +} + +func parseForwardTCPAddress(host, rawPort string, allowHostnames bool) (port.Address, error) { + if !isPortNumber(rawPort) { + return port.Address{}, fmt.Errorf("invalid port %s", rawPort) + } + + if host == "" { + host = "localhost" + } + + if !allowHostnames && host != "localhost" && net.ParseIP(host) == nil { + return port.Address{}, fmt.Errorf("not an ip address %s", host) + } + + return port.Address{ + Protocol: "tcp", + Address: net.JoinHostPort(host, rawPort), + }, nil +} + +func isPortNumber(raw string) bool { + _, err := strconv.Atoi(raw) + return err == nil +} diff --git a/cmd/ssh_forward_ports_test.go b/cmd/ssh_forward_ports_test.go new file mode 100644 index 000000000..999076e14 --- /dev/null +++ b/cmd/ssh_forward_ports_test.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "testing" + + "github.com/skevetter/devpod/pkg/port" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseForwardPortSpec_ServiceNameTarget(t *testing.T) { + mapping, err := parseForwardPortSpec("8080:nginx:80") + require.NoError(t, err) + assert.Equal(t, "tcp", mapping.Host.Protocol) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "tcp", mapping.Container.Protocol) + assert.Equal(t, "nginx:80", mapping.Container.Address) +} + +func TestParseForwardPortSpec_ServiceNameTargetWithLocalBindHost(t *testing.T) { + mapping, err := parseForwardPortSpec("127.0.0.1:8080:nginx:80") + require.NoError(t, err) + assert.Equal(t, "127.0.0.1:8080", mapping.Host.Address) + assert.Equal(t, "nginx:80", mapping.Container.Address) +} + +func TestParseForwardPortSpec_PreservesLocalBindDisambiguation(t *testing.T) { + mapping, err := parseForwardPortSpec("localhost:8080:80") + require.NoError(t, err) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "localhost:80", mapping.Container.Address) +} + +func TestParseForwardPortSpec_AllowsRemoteIPTargets(t *testing.T) { + mapping, err := parseForwardPortSpec("8080:10.0.0.2:80") + require.NoError(t, err) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "10.0.0.2:80", mapping.Container.Address) +} + +func TestParseForwardPortSpec_RejectsNonIPLocalBindHost(t *testing.T) { + _, err := parseForwardPortSpec("app:8080:nginx:80") + require.Error(t, err) + assert.ErrorContains(t, err, "not an ip address app") +} + +func TestParseForwardPortSpec_RejectsEmptyRemoteHost(t *testing.T) { + _, err := parseForwardPortSpec("8080::80") + require.Error(t, err) + assert.ErrorContains(t, err, "remote host is empty") +} + +func TestParseForwardPortSpec_DelegatesUnixSocketMappings(t *testing.T) { + mapping, err := parseForwardPortSpec("/tmp/local.sock:/tmp/remote.sock") + require.NoError(t, err) + assert.Equal(t, port.Mapping{ + Host: port.Address{Protocol: "unix", Address: "/tmp/local.sock"}, + Container: port.Address{Protocol: "unix", Address: "/tmp/remote.sock"}, + }, mapping) +} + +func TestParseForwardPortSpec_DoesNotChangeSharedParser(t *testing.T) { + _, err := port.ParsePortSpec("8080:nginx:80") + require.Error(t, err) +} diff --git a/e2e/tests/up-docker-compose/up_docker_compose.go b/e2e/tests/up-docker-compose/up_docker_compose.go index 704fab2b8..6116abc4c 100644 --- a/e2e/tests/up-docker-compose/up_docker_compose.go +++ b/e2e/tests/up-docker-compose/up_docker_compose.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "os" "os/exec" @@ -161,6 +162,76 @@ var _ = ginkgo.Describe( )) }, ginkgo.SpecTimeout(framework.GetTimeout())) + ginkgo.It("ssh forward ports support remote service names", func(ctx context.Context) { + _, workspace, err := tc.setupAndStartWorkspace( + ctx, + "tests/up-docker-compose/testdata/docker-compose-forward-ports", + "--debug", + ) + framework.ExpectNoError(err) + + ids, err := findComposeContainer( + ctx, + tc.dockerHelper, + tc.composeHelper, + workspace.UID, + "app", + ) + framework.ExpectNoError(err) + gomega.Expect(ids).To(gomega.HaveLen(1), "1 compose container to be created") + + listener, err := net.Listen("tcp", "127.0.0.1:0") + framework.ExpectNoError(err) + localPort := listener.Addr().(*net.TCPAddr).Port + framework.ExpectNoError(listener.Close()) + + done := make(chan error) + sshContext, sshCancel := context.WithCancel(context.Background()) + go func() { + cmd := exec.CommandContext( + sshContext, + filepath.Join(tc.f.DevpodBinDir, tc.f.DevpodBinName), + "ssh", + "--forward-ports", + fmt.Sprintf("%d:nginx:8080", localPort), + workspace.ID, + ) + + if err := cmd.Start(); err != nil { + done <- err + return + } + + if err := cmd.Wait(); err != nil { + done <- err + return + } + + done <- nil + }() + + gomega.Eventually(func(g gomega.Gomega) { + response, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", localPort)) + g.Expect(err).NotTo(gomega.HaveOccurred()) + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(body).To(gomega.ContainSubstring("Thank you for using nginx.")) + }). + WithPolling(1 * time.Second). + WithTimeout(20 * time.Second). + Should(gomega.Succeed()) + + sshCancel() + err = <-done + + gomega.Expect(err).To(gomega.Or( + gomega.MatchError("signal: killed"), + gomega.MatchError(context.Canceled), + )) + }, ginkgo.SpecTimeout(framework.GetTimeout())) + ginkgo.It("features", func(ctx context.Context) { tempDir, workspace, err := tc.setupAndStartWorkspace( ctx, From 293955a058e131a9a63cfc38cf739827eb0fdd36 Mon Sep 17 00:00:00 2001 From: David Zucker Date: Fri, 17 Apr 2026 15:08:37 +0900 Subject: [PATCH 2/6] fix: address ssh forward ports CI issues Format the new forward-port parser as expected by golangci-lint-fmt and satisfy lint in the compose e2e by handling the HTTP response close and documenting the controlled exec arguments. --- cmd/ssh_forward_ports.go | 4 +++- e2e/tests/up-docker-compose/up_docker_compose.go | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cmd/ssh_forward_ports.go b/cmd/ssh_forward_ports.go index 1a5e0e951..5038be4d5 100644 --- a/cmd/ssh_forward_ports.go +++ b/cmd/ssh_forward_ports.go @@ -32,7 +32,9 @@ func parseForwardPortSpec(raw string) (port.Mapping, error) { } } -func newForwardPortMapping(localHost, localPort, remoteHost, remotePort string) (port.Mapping, error) { +func newForwardPortMapping( + localHost, localPort, remoteHost, remotePort string, +) (port.Mapping, error) { hostAddress, err := parseForwardLocalAddress(localHost, localPort) if err != nil { return port.Mapping{}, fmt.Errorf("parse host address: %w", err) diff --git a/e2e/tests/up-docker-compose/up_docker_compose.go b/e2e/tests/up-docker-compose/up_docker_compose.go index 6116abc4c..c81528c36 100644 --- a/e2e/tests/up-docker-compose/up_docker_compose.go +++ b/e2e/tests/up-docker-compose/up_docker_compose.go @@ -188,6 +188,7 @@ var _ = ginkgo.Describe( done := make(chan error) sshContext, sshCancel := context.WithCancel(context.Background()) go func() { + // #nosec G204 -- test command with controlled arguments cmd := exec.CommandContext( sshContext, filepath.Join(tc.f.DevpodBinDir, tc.f.DevpodBinName), @@ -213,7 +214,7 @@ var _ = ginkgo.Describe( gomega.Eventually(func(g gomega.Gomega) { response, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", localPort)) g.Expect(err).NotTo(gomega.HaveOccurred()) - defer response.Body.Close() + defer func() { _ = response.Body.Close() }() body, err := io.ReadAll(response.Body) g.Expect(err).NotTo(gomega.HaveOccurred()) From 7897d987066a3efcb4dc1182b4467c8357c4807c Mon Sep 17 00:00:00 2001 From: David Zucker Date: Mon, 20 Apr 2026 14:48:23 +0900 Subject: [PATCH 3/6] feat: align ssh port parsing for forward and reverse Move the parser into pkg/port so both ssh forwarding directions share the same listen-versus-target rules while allowing service names on the dial side and documenting the new behavior. --- cmd/ssh.go | 99 ++++++++-------- cmd/ssh_forward_ports.go | 88 --------------- cmd/ssh_forward_ports_test.go | 65 ----------- .../connect-to-a-workspace.mdx | 15 +++ pkg/port/parse.go | 106 ++++++++++++++---- pkg/port/parse_test.go | 65 +++++++++++ 6 files changed, 210 insertions(+), 228 deletions(-) delete mode 100644 cmd/ssh_forward_ports.go delete mode 100644 cmd/ssh_forward_ports_test.go create mode 100644 pkg/port/parse_test.go diff --git a/cmd/ssh.go b/cmd/ssh.go index 8d1e82df7..6dff22104 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -116,8 +116,8 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command { "host are to be forwarded to the given host, service name, and port, or Unix socket, on the remote side.") sshCmd.Flags(). StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{}, - "Specifies that connections to the given TCP port or Unix socket on the local (client) "+ - "host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.") + "Specifies that connections to the given TCP port or Unix socket on the remote side "+ + "are to be reverse forwarded to the given local host, service name, and port, or Unix socket.") sshCmd.Flags(). StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{}, "Specifies which local env variables shall be sent to the container.") @@ -400,44 +400,11 @@ func (cmd *SSHCmd) reverseForwardPorts( containerClient *ssh.Client, log log.Logger, ) error { - timeout, err := cmd.forwardTimeout(log) - if err != nil { - return fmt.Errorf("parse forward ports timeout: %w", err) - } - - errChan := make(chan error, len(cmd.ReverseForwardPorts)) - for _, portMapping := range cmd.ReverseForwardPorts { - mapping, err := port.ParsePortSpec(portMapping) - if err != nil { - return fmt.Errorf("parse port mapping: %w", err) - } - - // start the forwarding - log.Infof( - "Reverse forwarding local %s/%s to remote %s/%s", - mapping.Host.Protocol, - mapping.Host.Address, - mapping.Container.Protocol, - mapping.Container.Address, - ) - go func(portMapping string) { - err := devssh.ReversePortForward( - ctx, - containerClient, - mapping.Host.Protocol, - mapping.Host.Address, - mapping.Container.Protocol, - mapping.Container.Address, - timeout, - log, - ) - if !errors.Is(io.EOF, err) { - errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err) - } - }(portMapping) - } - - return <-errChan + return cmd.runPortForwards(ctx, containerClient, portForwardConfig{ + mappings: cmd.ReverseForwardPorts, + logTemplate: "Reverse forwarding remote %s/%s to local %s/%s", + forwardFn: devssh.ReversePortForward, + }, log) } func (cmd *SSHCmd) forwardPorts( @@ -445,28 +412,58 @@ func (cmd *SSHCmd) forwardPorts( containerClient *ssh.Client, log log.Logger, ) error { - timeout, err := cmd.forwardTimeout(log) + return cmd.runPortForwards(ctx, containerClient, portForwardConfig{ + mappings: cmd.ForwardPorts, + logTemplate: "Forwarding local %s/%s to remote %s/%s", + forwardFn: devssh.PortForward, + }, log) +} + +type portForwardFunc func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, +) error + +type portForwardConfig struct { + mappings []string + logTemplate string + forwardFn portForwardFunc +} + +func (cmd *SSHCmd) runPortForwards( + ctx context.Context, + containerClient *ssh.Client, + config portForwardConfig, + logger log.Logger, +) error { + timeout, err := cmd.forwardTimeout(logger) if err != nil { return fmt.Errorf("parse forward ports timeout: %w", err) } - errChan := make(chan error, len(cmd.ForwardPorts)) - for _, portMapping := range cmd.ForwardPorts { - mapping, err := parseForwardPortSpec(portMapping) + errChan := make(chan error, len(config.mappings)) + for _, portMapping := range config.mappings { + mapping, err := port.ParsePortSpec(portMapping) if err != nil { return fmt.Errorf("parse port mapping: %w", err) } // start the forwarding - log.Infof( - "Forwarding local %s/%s to remote %s/%s", + logger.Infof( + config.logTemplate, mapping.Host.Protocol, mapping.Host.Address, mapping.Container.Protocol, mapping.Container.Address, ) - go func(portMapping string) { - err := devssh.PortForward( + go func(portMapping string, mapping port.Mapping) { + err := config.forwardFn( ctx, containerClient, mapping.Host.Protocol, @@ -474,12 +471,12 @@ func (cmd *SSHCmd) forwardPorts( mapping.Container.Protocol, mapping.Container.Address, timeout, - log, + logger, ) - if !errors.Is(io.EOF, err) { + if err != nil && !errors.Is(err, io.EOF) { errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err) } - }(portMapping) + }(portMapping, mapping) } return <-errChan diff --git a/cmd/ssh_forward_ports.go b/cmd/ssh_forward_ports.go deleted file mode 100644 index 5038be4d5..000000000 --- a/cmd/ssh_forward_ports.go +++ /dev/null @@ -1,88 +0,0 @@ -package cmd - -import ( - "fmt" - "net" - "strconv" - "strings" - - "github.com/skevetter/devpod/pkg/port" -) - -func parseForwardPortSpec(raw string) (port.Mapping, error) { - parts := strings.Split(raw, ":") - - switch len(parts) { - case 1, 2: - return port.ParsePortSpec(raw) - case 3: - if !isPortNumber(parts[0]) { - return port.ParsePortSpec(raw) - } - - return newForwardPortMapping("", parts[0], parts[1], parts[2]) - case 4: - if parts[0] == "" { - return port.Mapping{}, fmt.Errorf("local host is empty") - } - - return newForwardPortMapping(parts[0], parts[1], parts[2], parts[3]) - default: - return port.Mapping{}, fmt.Errorf("unexpected port format: %s", raw) - } -} - -func newForwardPortMapping( - localHost, localPort, remoteHost, remotePort string, -) (port.Mapping, error) { - hostAddress, err := parseForwardLocalAddress(localHost, localPort) - if err != nil { - return port.Mapping{}, fmt.Errorf("parse host address: %w", err) - } - - containerAddress, err := parseForwardRemoteAddress(remoteHost, remotePort) - if err != nil { - return port.Mapping{}, fmt.Errorf("parse container address: %w", err) - } - - return port.Mapping{ - Host: hostAddress, - Container: containerAddress, - }, nil -} - -func parseForwardLocalAddress(host, rawPort string) (port.Address, error) { - return parseForwardTCPAddress(host, rawPort, false) -} - -func parseForwardRemoteAddress(host, rawPort string) (port.Address, error) { - if host == "" { - return port.Address{}, fmt.Errorf("remote host is empty") - } - - return parseForwardTCPAddress(host, rawPort, true) -} - -func parseForwardTCPAddress(host, rawPort string, allowHostnames bool) (port.Address, error) { - if !isPortNumber(rawPort) { - return port.Address{}, fmt.Errorf("invalid port %s", rawPort) - } - - if host == "" { - host = "localhost" - } - - if !allowHostnames && host != "localhost" && net.ParseIP(host) == nil { - return port.Address{}, fmt.Errorf("not an ip address %s", host) - } - - return port.Address{ - Protocol: "tcp", - Address: net.JoinHostPort(host, rawPort), - }, nil -} - -func isPortNumber(raw string) bool { - _, err := strconv.Atoi(raw) - return err == nil -} diff --git a/cmd/ssh_forward_ports_test.go b/cmd/ssh_forward_ports_test.go deleted file mode 100644 index 999076e14..000000000 --- a/cmd/ssh_forward_ports_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package cmd - -import ( - "testing" - - "github.com/skevetter/devpod/pkg/port" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseForwardPortSpec_ServiceNameTarget(t *testing.T) { - mapping, err := parseForwardPortSpec("8080:nginx:80") - require.NoError(t, err) - assert.Equal(t, "tcp", mapping.Host.Protocol) - assert.Equal(t, "localhost:8080", mapping.Host.Address) - assert.Equal(t, "tcp", mapping.Container.Protocol) - assert.Equal(t, "nginx:80", mapping.Container.Address) -} - -func TestParseForwardPortSpec_ServiceNameTargetWithLocalBindHost(t *testing.T) { - mapping, err := parseForwardPortSpec("127.0.0.1:8080:nginx:80") - require.NoError(t, err) - assert.Equal(t, "127.0.0.1:8080", mapping.Host.Address) - assert.Equal(t, "nginx:80", mapping.Container.Address) -} - -func TestParseForwardPortSpec_PreservesLocalBindDisambiguation(t *testing.T) { - mapping, err := parseForwardPortSpec("localhost:8080:80") - require.NoError(t, err) - assert.Equal(t, "localhost:8080", mapping.Host.Address) - assert.Equal(t, "localhost:80", mapping.Container.Address) -} - -func TestParseForwardPortSpec_AllowsRemoteIPTargets(t *testing.T) { - mapping, err := parseForwardPortSpec("8080:10.0.0.2:80") - require.NoError(t, err) - assert.Equal(t, "localhost:8080", mapping.Host.Address) - assert.Equal(t, "10.0.0.2:80", mapping.Container.Address) -} - -func TestParseForwardPortSpec_RejectsNonIPLocalBindHost(t *testing.T) { - _, err := parseForwardPortSpec("app:8080:nginx:80") - require.Error(t, err) - assert.ErrorContains(t, err, "not an ip address app") -} - -func TestParseForwardPortSpec_RejectsEmptyRemoteHost(t *testing.T) { - _, err := parseForwardPortSpec("8080::80") - require.Error(t, err) - assert.ErrorContains(t, err, "remote host is empty") -} - -func TestParseForwardPortSpec_DelegatesUnixSocketMappings(t *testing.T) { - mapping, err := parseForwardPortSpec("/tmp/local.sock:/tmp/remote.sock") - require.NoError(t, err) - assert.Equal(t, port.Mapping{ - Host: port.Address{Protocol: "unix", Address: "/tmp/local.sock"}, - Container: port.Address{Protocol: "unix", Address: "/tmp/remote.sock"}, - }, mapping) -} - -func TestParseForwardPortSpec_DoesNotChangeSharedParser(t *testing.T) { - _, err := port.ParsePortSpec("8080:nginx:80") - require.Error(t, err) -} diff --git a/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx b/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx index 043456432..d77581208 100644 --- a/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx +++ b/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx @@ -92,6 +92,21 @@ Optionally you can also define a command to run: devpod ssh my-workspace --command "echo Hello World" ``` +You can also forward ports while using `devpod ssh`. +Target hosts can be service names or hostnames that are resolvable from the side that dials them. +For `--forward-ports`, that is the workspace. +For `--reverse-forward-ports`, that is the machine running `devpod ssh`. + +To forward a local port to a host or service reachable from the workspace: +``` +devpod ssh my-workspace --forward-ports 18080:nginx:8080 +``` + +To expose a port inside the workspace that forwards back to a host or service reachable from the machine running `devpod ssh`: +``` +devpod ssh my-workspace --reverse-forward-ports 15432:postgres.internal:5432 +``` + ## IDE Commands This section shows additional commands to configure DevPod's behavior when opening a workspace. diff --git a/pkg/port/parse.go b/pkg/port/parse.go index 3e44f84ae..4b65536d2 100644 --- a/pkg/port/parse.go +++ b/pkg/port/parse.go @@ -18,17 +18,24 @@ type Mapping struct { } func ParsePortSpec(port string) (Mapping, error) { - hostIP, hostPort, containerIP, containerPort, err := splitParts(port) + parts, err := splitParts(port) if err != nil { return Mapping{}, err } - hostAddress, err := toAddress(hostIP, hostPort) + hostAddress, err := toAddress(parts.host, parts.hostPort, addressOptions{ + emptyHostLabel: "listen host", + requireHost: parts.explicitHost, + }) if err != nil { return Mapping{}, fmt.Errorf("parse host address: %w", err) } - containerAddress, err := toAddress(containerIP, containerPort) + containerAddress, err := toAddress(parts.container, parts.containerPort, addressOptions{ + emptyHostLabel: "target host", + requireHost: parts.explicitContainer, + allowHostnames: true, + }) if err != nil { return Mapping{}, fmt.Errorf("parse container address: %w", err) } @@ -39,22 +46,51 @@ func ParsePortSpec(port string) (Mapping, error) { }, nil } -func toAddress(ip, port string) (Address, error) { - // check if port is integer - _, err := strconv.Atoi(port) - if err == nil { - if ip == "" { - ip = "localhost" - } +type splitResult struct { + host string + hostPort string + container string + containerPort string + explicitHost bool + explicitContainer bool +} + +type addressOptions struct { + emptyHostLabel string + requireHost bool + allowHostnames bool +} + +func toAddress(ip, port string, opts addressOptions) (Address, error) { + if isPortNumber(port) { + return toTCPAddress(ip, port, opts) + } - if ip != "localhost" && net.ParseIP(ip) == nil { - return Address{}, fmt.Errorf("not an ip address %s", ip) + return toUnixAddress(ip, port, opts) +} + +func toTCPAddress(ip, port string, opts addressOptions) (Address, error) { + if ip == "" { + if opts.requireHost { + return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) } - return Address{ - Protocol: "tcp", - Address: ip + ":" + port, - }, nil + ip = "localhost" + } + + if !opts.allowHostnames && ip != "localhost" && net.ParseIP(ip) == nil { + return Address{}, fmt.Errorf("not an ip address %s", ip) + } + + return Address{ + Protocol: "tcp", + Address: ip + ":" + port, + }, nil +} + +func toUnixAddress(ip, port string, opts addressOptions) (Address, error) { + if opts.requireHost && ip == "" { + return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) } if ip != "" { @@ -67,25 +103,47 @@ func toAddress(ip, port string) (Address, error) { }, nil } -func splitParts(rawport string) (string, string, string, string, error) { +func isPortNumber(raw string) bool { + _, err := strconv.Atoi(raw) + return err == nil +} + +func splitParts(rawport string) (splitResult, error) { parts := strings.Split(rawport, ":") n := len(parts) containerport := parts[n-1] switch n { case 1: - return "", containerport, "", containerport, nil + return splitResult{hostPort: containerport, containerPort: containerport}, nil case 2: - return "", parts[0], "", containerport, nil + return splitResult{hostPort: parts[0], containerPort: containerport}, nil case 3: - if parts[1] == "localhost" || net.ParseIP(parts[1]) != nil { - return "", parts[0], parts[1], containerport, nil + if isPortNumber(parts[0]) { + return splitResult{ + hostPort: parts[0], + container: parts[1], + containerPort: containerport, + explicitContainer: true, + }, nil } - return parts[0], parts[1], "", containerport, nil + return splitResult{ + host: parts[0], + hostPort: parts[1], + containerPort: containerport, + explicitHost: true, + }, nil case 4: - return parts[0], parts[1], parts[2], parts[3], nil + return splitResult{ + host: parts[0], + hostPort: parts[1], + container: parts[2], + containerPort: parts[3], + explicitHost: true, + explicitContainer: true, + }, nil default: - return "", "", "", "", fmt.Errorf("unexpected port format: %s", rawport) + return splitResult{}, fmt.Errorf("unexpected port format: %s", rawport) } } diff --git a/pkg/port/parse_test.go b/pkg/port/parse_test.go new file mode 100644 index 000000000..bf00149d5 --- /dev/null +++ b/pkg/port/parse_test.go @@ -0,0 +1,65 @@ +package port + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParsePortSpec_ServiceNameTarget(t *testing.T) { + mapping, err := ParsePortSpec("8080:nginx:80") + require.NoError(t, err) + assert.Equal(t, "tcp", mapping.Host.Protocol) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "tcp", mapping.Container.Protocol) + assert.Equal(t, "nginx:80", mapping.Container.Address) +} + +func TestParsePortSpec_ServiceNameTargetWithLocalBindHost(t *testing.T) { + mapping, err := ParsePortSpec("127.0.0.1:8080:nginx:80") + require.NoError(t, err) + assert.Equal(t, "127.0.0.1:8080", mapping.Host.Address) + assert.Equal(t, "nginx:80", mapping.Container.Address) +} + +func TestParsePortSpec_PreservesLocalBindDisambiguation(t *testing.T) { + mapping, err := ParsePortSpec("localhost:8080:80") + require.NoError(t, err) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "localhost:80", mapping.Container.Address) +} + +func TestParsePortSpec_AllowsTargetIPHosts(t *testing.T) { + mapping, err := ParsePortSpec("8080:10.0.0.2:80") + require.NoError(t, err) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "10.0.0.2:80", mapping.Container.Address) +} + +func TestParsePortSpec_RejectsNonIPListenHost(t *testing.T) { + _, err := ParsePortSpec("app:8080:nginx:80") + require.Error(t, err) + assert.ErrorContains(t, err, "not an ip address app") +} + +func TestParsePortSpec_RejectsEmptyTargetHost(t *testing.T) { + _, err := ParsePortSpec("8080::80") + require.Error(t, err) + assert.ErrorContains(t, err, "target host is empty") +} + +func TestParsePortSpec_RejectsEmptyListenHost(t *testing.T) { + _, err := ParsePortSpec(":8080:nginx:80") + require.Error(t, err) + assert.ErrorContains(t, err, "listen host is empty") +} + +func TestParsePortSpec_DelegatesUnixSocketMappings(t *testing.T) { + mapping, err := ParsePortSpec("/tmp/local.sock:/tmp/remote.sock") + require.NoError(t, err) + assert.Equal(t, Mapping{ + Host: Address{Protocol: "unix", Address: "/tmp/local.sock"}, + Container: Address{Protocol: "unix", Address: "/tmp/remote.sock"}, + }, mapping) +} From 6b88d4a652d623978a3d2795918bdeb2725b2c7b Mon Sep 17 00:00:00 2001 From: David Zucker Date: Tue, 21 Apr 2026 09:07:15 +0900 Subject: [PATCH 4/6] fix: address ssh port forwarding review comments Wait for all forwarding goroutines to finish so clean exits cannot hang the ssh command, and reject empty unix socket paths during port spec parsing with regression tests for both cases. --- cmd/ssh.go | 18 +++++++++++++++++- cmd/ssh_test.go | 29 +++++++++++++++++++++++++++++ pkg/port/parse.go | 4 ++++ pkg/port/parse_test.go | 12 ++++++++++++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/cmd/ssh.go b/cmd/ssh.go index 6dff22104..953fa2f5b 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -10,6 +10,7 @@ import ( "os/exec" "path" "strings" + "sync" "time" "al.essio.dev/pkg/shellescape" @@ -448,6 +449,7 @@ func (cmd *SSHCmd) runPortForwards( } errChan := make(chan error, len(config.mappings)) + var waitGroup sync.WaitGroup for _, portMapping := range config.mappings { mapping, err := port.ParsePortSpec(portMapping) if err != nil { @@ -462,7 +464,10 @@ func (cmd *SSHCmd) runPortForwards( mapping.Container.Protocol, mapping.Container.Address, ) + waitGroup.Add(1) go func(portMapping string, mapping port.Mapping) { + defer waitGroup.Done() + err := config.forwardFn( ctx, containerClient, @@ -479,7 +484,18 @@ func (cmd *SSHCmd) runPortForwards( }(portMapping, mapping) } - return <-errChan + go func() { + waitGroup.Wait() + close(errChan) + }() + + for err := range errChan { + if err != nil { + return err + } + } + + return nil } func (cmd *SSHCmd) startTunnel( diff --git a/cmd/ssh_test.go b/cmd/ssh_test.go index 0b332dec7..688333991 100644 --- a/cmd/ssh_test.go +++ b/cmd/ssh_test.go @@ -1,12 +1,16 @@ package cmd import ( + "context" "os" "path/filepath" "testing" + "time" "github.com/skevetter/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" ) func writeGitConfig(t *testing.T, content string) { @@ -57,3 +61,28 @@ func TestGpgSigningKey_TildeKeyPath_Skipped(t *testing.T) { result := gpgSigningKey(log.Discard) assert.Empty(t, result) } + +func TestRunPortForwards_ReturnsOnCleanExit(t *testing.T) { + cmd := &SSHCmd{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := cmd.runPortForwards(ctx, nil, portForwardConfig{ + mappings: []string{"8080:80"}, + logTemplate: "test %s/%s %s/%s", + forwardFn: func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }, + }, log.Discard) + + require.NoError(t, err) +} diff --git a/pkg/port/parse.go b/pkg/port/parse.go index 4b65536d2..45ac5d34a 100644 --- a/pkg/port/parse.go +++ b/pkg/port/parse.go @@ -89,6 +89,10 @@ func toTCPAddress(ip, port string, opts addressOptions) (Address, error) { } func toUnixAddress(ip, port string, opts addressOptions) (Address, error) { + if port == "" { + return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) + } + if opts.requireHost && ip == "" { return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) } diff --git a/pkg/port/parse_test.go b/pkg/port/parse_test.go index bf00149d5..40c9b84e1 100644 --- a/pkg/port/parse_test.go +++ b/pkg/port/parse_test.go @@ -63,3 +63,15 @@ func TestParsePortSpec_DelegatesUnixSocketMappings(t *testing.T) { Container: Address{Protocol: "unix", Address: "/tmp/remote.sock"}, }, mapping) } + +func TestParsePortSpec_RejectsEmptyListenUnixSocketPath(t *testing.T) { + _, err := ParsePortSpec(":/tmp/remote.sock") + require.Error(t, err) + assert.ErrorContains(t, err, "listen host is empty") +} + +func TestParsePortSpec_RejectsEmptyTargetUnixSocketPath(t *testing.T) { + _, err := ParsePortSpec("/tmp/local.sock:") + require.Error(t, err) + assert.ErrorContains(t, err, "target host is empty") +} From 392a96274b080a19a28f0841219fb559e2eb3721 Mon Sep 17 00:00:00 2001 From: David Zucker Date: Tue, 21 Apr 2026 10:59:31 +0900 Subject: [PATCH 5/6] fix: address remaining ssh forwarding review issues Return the parsed forward timeout value, validate all port mappings before any forwarders start, and add focused regression coverage for timeout parsing, EOF handling, parse failures, and mixed multi-mapping outcomes. --- cmd/ssh.go | 31 ++++++--- cmd/ssh_test.go | 176 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 177 insertions(+), 30 deletions(-) diff --git a/cmd/ssh.go b/cmd/ssh.go index 953fa2f5b..59894345a 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -383,16 +383,17 @@ func (cmd *SSHCmd) jumpContainer( } func (cmd *SSHCmd) forwardTimeout(log log.Logger) (time.Duration, error) { - timeout := time.Duration(0) - if cmd.ForwardPortsTimeout != "" { - timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout) - if err != nil { - return timeout, fmt.Errorf("parse forward ports timeout: %w", err) - } + if cmd.ForwardPortsTimeout == "" { + return 0, nil + } - log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout) + timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout) + if err != nil { + return 0, fmt.Errorf("parse forward ports timeout: %w", err) } + log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout) + return timeout, nil } @@ -437,6 +438,11 @@ type portForwardConfig struct { forwardFn portForwardFunc } +type parsedPortForward struct { + spec string + mapping port.Mapping +} + func (cmd *SSHCmd) runPortForwards( ctx context.Context, containerClient *ssh.Client, @@ -448,14 +454,21 @@ func (cmd *SSHCmd) runPortForwards( return fmt.Errorf("parse forward ports timeout: %w", err) } - errChan := make(chan error, len(config.mappings)) - var waitGroup sync.WaitGroup + parsedMappings := make([]parsedPortForward, 0, len(config.mappings)) for _, portMapping := range config.mappings { mapping, err := port.ParsePortSpec(portMapping) if err != nil { return fmt.Errorf("parse port mapping: %w", err) } + parsedMappings = append(parsedMappings, parsedPortForward{spec: portMapping, mapping: mapping}) + } + + errChan := make(chan error, len(parsedMappings)) + var waitGroup sync.WaitGroup + for _, parsedMapping := range parsedMappings { + portMapping, mapping := parsedMapping.spec, parsedMapping.mapping + // start the forwarding logger.Infof( config.logTemplate, diff --git a/cmd/ssh_test.go b/cmd/ssh_test.go index 688333991..c24d94e19 100644 --- a/cmd/ssh_test.go +++ b/cmd/ssh_test.go @@ -2,8 +2,11 @@ package cmd import ( "context" + "errors" + "io" "os" "path/filepath" + "sync/atomic" "testing" "time" @@ -62,27 +65,158 @@ func TestGpgSigningKey_TildeKeyPath_Skipped(t *testing.T) { assert.Empty(t, result) } -func TestRunPortForwards_ReturnsOnCleanExit(t *testing.T) { - cmd := &SSHCmd{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - err := cmd.runPortForwards(ctx, nil, portForwardConfig{ - mappings: []string{"8080:80"}, - logTemplate: "test %s/%s %s/%s", - forwardFn: func( - context.Context, - *ssh.Client, - string, - string, - string, - string, - time.Duration, - log.Logger, - ) error { - return nil - }, - }, log.Discard) +func TestForwardTimeout_UsesParsedDuration(t *testing.T) { + cmd := &SSHCmd{ForwardPortsTimeout: "90s"} + timeout, err := cmd.forwardTimeout(log.Discard) require.NoError(t, err) + assert.Equal(t, 90*time.Second, timeout) +} + +func TestRunPortForwards(t *testing.T) { + tests := []struct { + name string + mappings []string + forwardFn portForwardFunc + wantErr string + wantCalls int32 + }{ + { + name: "clean exit", + mappings: []string{"8080:80"}, + forwardFn: func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }, + wantCalls: 1, + }, + { + name: "eof exit", + mappings: []string{"8080:80"}, + forwardFn: func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return io.EOF + }, + wantCalls: 1, + }, + { + name: "forward error", + mappings: []string{"8080:80"}, + forwardFn: func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return errors.New("boom") + }, + wantErr: "error forwarding 8080:80: boom", + wantCalls: 1, + }, + { + name: "parse error", + mappings: []string{""}, + forwardFn: func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }, + wantErr: "parse port mapping", + wantCalls: 0, + }, + { + name: "multiple mappings with error", + mappings: []string{"8080:80", "8081:81"}, + forwardFn: func( + _ context.Context, + _ *ssh.Client, + _ string, + localAddr string, + _ string, + _ string, + _ time.Duration, + _ log.Logger, + ) error { + if localAddr == "localhost:8081" { + return errors.New("boom") + } + + return nil + }, + wantErr: "error forwarding 8081:81: boom", + wantCalls: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &SSHCmd{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var calls atomic.Int32 + err := cmd.runPortForwards(ctx, nil, portForwardConfig{ + mappings: tt.mappings, + logTemplate: "test %s/%s %s/%s", + forwardFn: func( + ctx context.Context, + client *ssh.Client, + localNetwork string, + localAddr string, + remoteNetwork string, + remoteAddr string, + timeout time.Duration, + logger log.Logger, + ) error { + calls.Add(1) + return tt.forwardFn( + ctx, + client, + localNetwork, + localAddr, + remoteNetwork, + remoteAddr, + timeout, + logger, + ) + }, + }, log.Discard) + + if tt.wantErr == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, tt.wantErr) + } + + assert.Equal(t, tt.wantCalls, calls.Load()) + }) + } } From 62d12e9b8e50acd4ab7aa1a6b098730e03f2a2c0 Mon Sep 17 00:00:00 2001 From: David Zucker Date: Tue, 21 Apr 2026 15:29:24 +0900 Subject: [PATCH 6/6] fix: address ssh forwarding lint regressions Extract port mapping parsing to reduce command complexity, split the forwarding regression coverage into smaller focused tests, and make the multi-mapping error case deterministic under concurrency. --- cmd/ssh.go | 28 +++-- cmd/ssh_test.go | 313 ++++++++++++++++++++++++++---------------------- 2 files changed, 188 insertions(+), 153 deletions(-) diff --git a/cmd/ssh.go b/cmd/ssh.go index 59894345a..01027ab28 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -443,6 +443,23 @@ type parsedPortForward struct { mapping port.Mapping } +func parsePortForwards(mappings []string) ([]parsedPortForward, error) { + parsedMappings := make([]parsedPortForward, 0, len(mappings)) + for _, portMapping := range mappings { + mapping, err := port.ParsePortSpec(portMapping) + if err != nil { + return nil, fmt.Errorf("parse port mapping: %w", err) + } + + parsedMappings = append(parsedMappings, parsedPortForward{ + spec: portMapping, + mapping: mapping, + }) + } + + return parsedMappings, nil +} + func (cmd *SSHCmd) runPortForwards( ctx context.Context, containerClient *ssh.Client, @@ -454,14 +471,9 @@ func (cmd *SSHCmd) runPortForwards( return fmt.Errorf("parse forward ports timeout: %w", err) } - parsedMappings := make([]parsedPortForward, 0, len(config.mappings)) - for _, portMapping := range config.mappings { - mapping, err := port.ParsePortSpec(portMapping) - if err != nil { - return fmt.Errorf("parse port mapping: %w", err) - } - - parsedMappings = append(parsedMappings, parsedPortForward{spec: portMapping, mapping: mapping}) + parsedMappings, err := parsePortForwards(config.mappings) + if err != nil { + return err } errChan := make(chan error, len(parsedMappings)) diff --git a/cmd/ssh_test.go b/cmd/ssh_test.go index c24d94e19..b30f0e005 100644 --- a/cmd/ssh_test.go +++ b/cmd/ssh_test.go @@ -73,150 +73,173 @@ func TestForwardTimeout_UsesParsedDuration(t *testing.T) { assert.Equal(t, 90*time.Second, timeout) } -func TestRunPortForwards(t *testing.T) { - tests := []struct { - name string - mappings []string - forwardFn portForwardFunc - wantErr string - wantCalls int32 - }{ - { - name: "clean exit", - mappings: []string{"8080:80"}, - forwardFn: func( - context.Context, - *ssh.Client, - string, - string, - string, - string, - time.Duration, - log.Logger, - ) error { - return nil - }, - wantCalls: 1, - }, - { - name: "eof exit", - mappings: []string{"8080:80"}, - forwardFn: func( - context.Context, - *ssh.Client, - string, - string, - string, - string, - time.Duration, - log.Logger, - ) error { - return io.EOF - }, - wantCalls: 1, - }, - { - name: "forward error", - mappings: []string{"8080:80"}, - forwardFn: func( - context.Context, - *ssh.Client, - string, - string, - string, - string, - time.Duration, - log.Logger, - ) error { - return errors.New("boom") - }, - wantErr: "error forwarding 8080:80: boom", - wantCalls: 1, - }, - { - name: "parse error", - mappings: []string{""}, - forwardFn: func( - context.Context, - *ssh.Client, - string, - string, - string, - string, - time.Duration, - log.Logger, - ) error { - return nil - }, - wantErr: "parse port mapping", - wantCalls: 0, - }, - { - name: "multiple mappings with error", - mappings: []string{"8080:80", "8081:81"}, - forwardFn: func( - _ context.Context, - _ *ssh.Client, - _ string, - localAddr string, - _ string, - _ string, - _ time.Duration, - _ log.Logger, - ) error { - if localAddr == "localhost:8081" { - return errors.New("boom") - } - - return nil - }, - wantErr: "error forwarding 8081:81: boom", - wantCalls: 2, +func runPortForwardsForTest( + t *testing.T, + cmd *SSHCmd, + mappings []string, + forwardFn portForwardFunc, +) (int32, error) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var calls atomic.Int32 + err := cmd.runPortForwards(ctx, nil, portForwardConfig{ + mappings: mappings, + logTemplate: "test %s/%s %s/%s", + forwardFn: func( + ctx context.Context, + client *ssh.Client, + localNetwork string, + localAddr string, + remoteNetwork string, + remoteAddr string, + timeout time.Duration, + logger log.Logger, + ) error { + calls.Add(1) + return forwardFn( + ctx, + client, + localNetwork, + localAddr, + remoteNetwork, + remoteAddr, + timeout, + logger, + ) }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := &SSHCmd{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - var calls atomic.Int32 - err := cmd.runPortForwards(ctx, nil, portForwardConfig{ - mappings: tt.mappings, - logTemplate: "test %s/%s %s/%s", - forwardFn: func( - ctx context.Context, - client *ssh.Client, - localNetwork string, - localAddr string, - remoteNetwork string, - remoteAddr string, - timeout time.Duration, - logger log.Logger, - ) error { - calls.Add(1) - return tt.forwardFn( - ctx, - client, - localNetwork, - localAddr, - remoteNetwork, - remoteAddr, - timeout, - logger, - ) - }, - }, log.Discard) - - if tt.wantErr == "" { - require.NoError(t, err) - } else { - require.Error(t, err) - assert.ErrorContains(t, err, tt.wantErr) - } - - assert.Equal(t, tt.wantCalls, calls.Load()) - }) - } + }, log.Discard) + + return calls.Load(), err +} + +func TestRunPortForwards_CleanExit(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80"}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }) + + require.NoError(t, err) + assert.Equal(t, int32(1), calls) +} + +func TestRunPortForwards_EOFExit(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80"}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return io.EOF + }) + + require.NoError(t, err) + assert.Equal(t, int32(1), calls) +} + +func TestRunPortForwards_ForwardError(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80"}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return errors.New("boom") + }) + + require.Error(t, err) + assert.ErrorContains(t, err, "error forwarding 8080:80: boom") + assert.Equal(t, int32(1), calls) +} + +func TestRunPortForwards_UsesConfiguredTimeout(t *testing.T) { + cmd := &SSHCmd{ForwardPortsTimeout: "90s"} + var gotTimeout time.Duration + + calls, err := runPortForwardsForTest(t, cmd, []string{"8080:80"}, func( + _ context.Context, + _ *ssh.Client, + _ string, + _ string, + _ string, + _ string, + timeout time.Duration, + _ log.Logger, + ) error { + gotTimeout = timeout + return nil + }) + + require.NoError(t, err) + assert.Equal(t, int32(1), calls) + assert.Equal(t, 90*time.Second, gotTimeout) +} + +func TestRunPortForwards_ParseErrorStopsBeforeLaunch(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80", ""}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }) + + require.Error(t, err) + assert.ErrorContains(t, err, "parse port mapping") + assert.Equal(t, int32(0), calls) +} + +func TestRunPortForwards_MultipleMappingsReturnError(t *testing.T) { + var started atomic.Int32 + ready := make(chan struct{}) + + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80", "8081:81"}, func( + _ context.Context, + _ *ssh.Client, + _ string, + localAddr string, + _ string, + _ string, + _ time.Duration, + _ log.Logger, + ) error { + if started.Add(1) == 2 { + close(ready) + } else { + <-ready + } + + if localAddr == "localhost:8081" { + return errors.New("boom") + } + + return nil + }) + + require.Error(t, err) + assert.ErrorContains(t, err, "error forwarding 8081:81: boom") + assert.Equal(t, int32(2), calls) }