diff --git a/cmd/ssh.go b/cmd/ssh.go index 0bf3ea1ee..01027ab28 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -10,6 +10,7 @@ import ( "os/exec" "path" "strings" + "sync" "time" "al.essio.dev/pkg/shellescape" @@ -113,11 +114,11 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command { sshCmd.Flags(). StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) "+ - "host are to be forwarded to the given host and port, or Unix socket, on the remote side.") + "host are to be forwarded to the given host, service name, and port, or Unix socket, on the remote side.") sshCmd.Flags(). StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{}, - "Specifies that connections to the given TCP port or Unix socket on the local (client) "+ - "host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.") + "Specifies that connections to the given TCP port or Unix socket on the remote side "+ + "are to be reverse forwarded to the given local host, service name, and port, or Unix socket.") sshCmd.Flags(). StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{}, "Specifies which local env variables shall be sent to the container.") @@ -382,16 +383,17 @@ func (cmd *SSHCmd) jumpContainer( } func (cmd *SSHCmd) forwardTimeout(log log.Logger) (time.Duration, error) { - timeout := time.Duration(0) - if cmd.ForwardPortsTimeout != "" { - timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout) - if err != nil { - return timeout, fmt.Errorf("parse forward ports timeout: %w", err) - } + if cmd.ForwardPortsTimeout == "" { + return 0, nil + } - log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout) + timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout) + if err != nil { + return 0, fmt.Errorf("parse forward ports timeout: %w", err) } + log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout) + return timeout, nil } @@ -400,73 +402,98 @@ func (cmd *SSHCmd) reverseForwardPorts( containerClient *ssh.Client, log log.Logger, ) error { - timeout, err := cmd.forwardTimeout(log) - if err != nil { - return fmt.Errorf("parse forward ports timeout: %w", err) - } + return cmd.runPortForwards(ctx, containerClient, portForwardConfig{ + mappings: cmd.ReverseForwardPorts, + logTemplate: "Reverse forwarding remote %s/%s to local %s/%s", + forwardFn: devssh.ReversePortForward, + }, log) +} - errChan := make(chan error, len(cmd.ReverseForwardPorts)) - for _, portMapping := range cmd.ReverseForwardPorts { +func (cmd *SSHCmd) forwardPorts( + ctx context.Context, + containerClient *ssh.Client, + log log.Logger, +) error { + return cmd.runPortForwards(ctx, containerClient, portForwardConfig{ + mappings: cmd.ForwardPorts, + logTemplate: "Forwarding local %s/%s to remote %s/%s", + forwardFn: devssh.PortForward, + }, log) +} + +type portForwardFunc func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, +) error + +type portForwardConfig struct { + mappings []string + logTemplate string + forwardFn portForwardFunc +} + +type parsedPortForward struct { + spec string + mapping port.Mapping +} + +func parsePortForwards(mappings []string) ([]parsedPortForward, error) { + parsedMappings := make([]parsedPortForward, 0, len(mappings)) + for _, portMapping := range mappings { mapping, err := port.ParsePortSpec(portMapping) if err != nil { - return fmt.Errorf("parse port mapping: %w", err) + return nil, fmt.Errorf("parse port mapping: %w", err) } - // start the forwarding - log.Infof( - "Reverse forwarding local %s/%s to remote %s/%s", - mapping.Host.Protocol, - mapping.Host.Address, - mapping.Container.Protocol, - mapping.Container.Address, - ) - go func(portMapping string) { - err := devssh.ReversePortForward( - ctx, - containerClient, - mapping.Host.Protocol, - mapping.Host.Address, - mapping.Container.Protocol, - mapping.Container.Address, - timeout, - log, - ) - if !errors.Is(io.EOF, err) { - errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err) - } - }(portMapping) + parsedMappings = append(parsedMappings, parsedPortForward{ + spec: portMapping, + mapping: mapping, + }) } - return <-errChan + return parsedMappings, nil } -func (cmd *SSHCmd) forwardPorts( +func (cmd *SSHCmd) runPortForwards( ctx context.Context, containerClient *ssh.Client, - log log.Logger, + config portForwardConfig, + logger log.Logger, ) error { - timeout, err := cmd.forwardTimeout(log) + timeout, err := cmd.forwardTimeout(logger) if err != nil { return fmt.Errorf("parse forward ports timeout: %w", err) } - errChan := make(chan error, len(cmd.ForwardPorts)) - for _, portMapping := range cmd.ForwardPorts { - mapping, err := port.ParsePortSpec(portMapping) - if err != nil { - return fmt.Errorf("parse port mapping: %w", err) - } + parsedMappings, err := parsePortForwards(config.mappings) + if err != nil { + return err + } + + errChan := make(chan error, len(parsedMappings)) + var waitGroup sync.WaitGroup + for _, parsedMapping := range parsedMappings { + portMapping, mapping := parsedMapping.spec, parsedMapping.mapping // start the forwarding - log.Infof( - "Forwarding local %s/%s to remote %s/%s", + logger.Infof( + config.logTemplate, mapping.Host.Protocol, mapping.Host.Address, mapping.Container.Protocol, mapping.Container.Address, ) - go func(portMapping string) { - err := devssh.PortForward( + waitGroup.Add(1) + go func(portMapping string, mapping port.Mapping) { + defer waitGroup.Done() + + err := config.forwardFn( ctx, containerClient, mapping.Host.Protocol, @@ -474,15 +501,26 @@ func (cmd *SSHCmd) forwardPorts( mapping.Container.Protocol, mapping.Container.Address, timeout, - log, + logger, ) - if !errors.Is(io.EOF, err) { + if err != nil && !errors.Is(err, io.EOF) { errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err) } - }(portMapping) + }(portMapping, mapping) + } + + go func() { + waitGroup.Wait() + close(errChan) + }() + + for err := range errChan { + if err != nil { + return err + } } - return <-errChan + return nil } func (cmd *SSHCmd) startTunnel( diff --git a/cmd/ssh_test.go b/cmd/ssh_test.go index 0b332dec7..b30f0e005 100644 --- a/cmd/ssh_test.go +++ b/cmd/ssh_test.go @@ -1,12 +1,19 @@ package cmd import ( + "context" + "errors" + "io" "os" "path/filepath" + "sync/atomic" "testing" + "time" "github.com/skevetter/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" ) func writeGitConfig(t *testing.T, content string) { @@ -57,3 +64,182 @@ func TestGpgSigningKey_TildeKeyPath_Skipped(t *testing.T) { result := gpgSigningKey(log.Discard) assert.Empty(t, result) } + +func TestForwardTimeout_UsesParsedDuration(t *testing.T) { + cmd := &SSHCmd{ForwardPortsTimeout: "90s"} + + timeout, err := cmd.forwardTimeout(log.Discard) + require.NoError(t, err) + assert.Equal(t, 90*time.Second, timeout) +} + +func runPortForwardsForTest( + t *testing.T, + cmd *SSHCmd, + mappings []string, + forwardFn portForwardFunc, +) (int32, error) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var calls atomic.Int32 + err := cmd.runPortForwards(ctx, nil, portForwardConfig{ + mappings: mappings, + logTemplate: "test %s/%s %s/%s", + forwardFn: func( + ctx context.Context, + client *ssh.Client, + localNetwork string, + localAddr string, + remoteNetwork string, + remoteAddr string, + timeout time.Duration, + logger log.Logger, + ) error { + calls.Add(1) + return forwardFn( + ctx, + client, + localNetwork, + localAddr, + remoteNetwork, + remoteAddr, + timeout, + logger, + ) + }, + }, log.Discard) + + return calls.Load(), err +} + +func TestRunPortForwards_CleanExit(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80"}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }) + + require.NoError(t, err) + assert.Equal(t, int32(1), calls) +} + +func TestRunPortForwards_EOFExit(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80"}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return io.EOF + }) + + require.NoError(t, err) + assert.Equal(t, int32(1), calls) +} + +func TestRunPortForwards_ForwardError(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80"}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return errors.New("boom") + }) + + require.Error(t, err) + assert.ErrorContains(t, err, "error forwarding 8080:80: boom") + assert.Equal(t, int32(1), calls) +} + +func TestRunPortForwards_UsesConfiguredTimeout(t *testing.T) { + cmd := &SSHCmd{ForwardPortsTimeout: "90s"} + var gotTimeout time.Duration + + calls, err := runPortForwardsForTest(t, cmd, []string{"8080:80"}, func( + _ context.Context, + _ *ssh.Client, + _ string, + _ string, + _ string, + _ string, + timeout time.Duration, + _ log.Logger, + ) error { + gotTimeout = timeout + return nil + }) + + require.NoError(t, err) + assert.Equal(t, int32(1), calls) + assert.Equal(t, 90*time.Second, gotTimeout) +} + +func TestRunPortForwards_ParseErrorStopsBeforeLaunch(t *testing.T) { + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80", ""}, func( + context.Context, + *ssh.Client, + string, + string, + string, + string, + time.Duration, + log.Logger, + ) error { + return nil + }) + + require.Error(t, err) + assert.ErrorContains(t, err, "parse port mapping") + assert.Equal(t, int32(0), calls) +} + +func TestRunPortForwards_MultipleMappingsReturnError(t *testing.T) { + var started atomic.Int32 + ready := make(chan struct{}) + + calls, err := runPortForwardsForTest(t, &SSHCmd{}, []string{"8080:80", "8081:81"}, func( + _ context.Context, + _ *ssh.Client, + _ string, + localAddr string, + _ string, + _ string, + _ time.Duration, + _ log.Logger, + ) error { + if started.Add(1) == 2 { + close(ready) + } else { + <-ready + } + + if localAddr == "localhost:8081" { + return errors.New("boom") + } + + return nil + }) + + require.Error(t, err) + assert.ErrorContains(t, err, "error forwarding 8081:81: boom") + assert.Equal(t, int32(2), calls) +} diff --git a/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx b/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx index 043456432..d77581208 100644 --- a/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx +++ b/docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx @@ -92,6 +92,21 @@ Optionally you can also define a command to run: devpod ssh my-workspace --command "echo Hello World" ``` +You can also forward ports while using `devpod ssh`. +Target hosts can be service names or hostnames that are resolvable from the side that dials them. +For `--forward-ports`, that is the workspace. +For `--reverse-forward-ports`, that is the machine running `devpod ssh`. + +To forward a local port to a host or service reachable from the workspace: +``` +devpod ssh my-workspace --forward-ports 18080:nginx:8080 +``` + +To expose a port inside the workspace that forwards back to a host or service reachable from the machine running `devpod ssh`: +``` +devpod ssh my-workspace --reverse-forward-ports 15432:postgres.internal:5432 +``` + ## IDE Commands This section shows additional commands to configure DevPod's behavior when opening a workspace. diff --git a/e2e/tests/up-docker-compose/up_docker_compose.go b/e2e/tests/up-docker-compose/up_docker_compose.go index 704fab2b8..c81528c36 100644 --- a/e2e/tests/up-docker-compose/up_docker_compose.go +++ b/e2e/tests/up-docker-compose/up_docker_compose.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "os" "os/exec" @@ -161,6 +162,77 @@ var _ = ginkgo.Describe( )) }, ginkgo.SpecTimeout(framework.GetTimeout())) + ginkgo.It("ssh forward ports support remote service names", func(ctx context.Context) { + _, workspace, err := tc.setupAndStartWorkspace( + ctx, + "tests/up-docker-compose/testdata/docker-compose-forward-ports", + "--debug", + ) + framework.ExpectNoError(err) + + ids, err := findComposeContainer( + ctx, + tc.dockerHelper, + tc.composeHelper, + workspace.UID, + "app", + ) + framework.ExpectNoError(err) + gomega.Expect(ids).To(gomega.HaveLen(1), "1 compose container to be created") + + listener, err := net.Listen("tcp", "127.0.0.1:0") + framework.ExpectNoError(err) + localPort := listener.Addr().(*net.TCPAddr).Port + framework.ExpectNoError(listener.Close()) + + done := make(chan error) + sshContext, sshCancel := context.WithCancel(context.Background()) + go func() { + // #nosec G204 -- test command with controlled arguments + cmd := exec.CommandContext( + sshContext, + filepath.Join(tc.f.DevpodBinDir, tc.f.DevpodBinName), + "ssh", + "--forward-ports", + fmt.Sprintf("%d:nginx:8080", localPort), + workspace.ID, + ) + + if err := cmd.Start(); err != nil { + done <- err + return + } + + if err := cmd.Wait(); err != nil { + done <- err + return + } + + done <- nil + }() + + gomega.Eventually(func(g gomega.Gomega) { + response, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", localPort)) + g.Expect(err).NotTo(gomega.HaveOccurred()) + defer func() { _ = response.Body.Close() }() + + body, err := io.ReadAll(response.Body) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(body).To(gomega.ContainSubstring("Thank you for using nginx.")) + }). + WithPolling(1 * time.Second). + WithTimeout(20 * time.Second). + Should(gomega.Succeed()) + + sshCancel() + err = <-done + + gomega.Expect(err).To(gomega.Or( + gomega.MatchError("signal: killed"), + gomega.MatchError(context.Canceled), + )) + }, ginkgo.SpecTimeout(framework.GetTimeout())) + ginkgo.It("features", func(ctx context.Context) { tempDir, workspace, err := tc.setupAndStartWorkspace( ctx, diff --git a/pkg/port/parse.go b/pkg/port/parse.go index 3e44f84ae..45ac5d34a 100644 --- a/pkg/port/parse.go +++ b/pkg/port/parse.go @@ -18,17 +18,24 @@ type Mapping struct { } func ParsePortSpec(port string) (Mapping, error) { - hostIP, hostPort, containerIP, containerPort, err := splitParts(port) + parts, err := splitParts(port) if err != nil { return Mapping{}, err } - hostAddress, err := toAddress(hostIP, hostPort) + hostAddress, err := toAddress(parts.host, parts.hostPort, addressOptions{ + emptyHostLabel: "listen host", + requireHost: parts.explicitHost, + }) if err != nil { return Mapping{}, fmt.Errorf("parse host address: %w", err) } - containerAddress, err := toAddress(containerIP, containerPort) + containerAddress, err := toAddress(parts.container, parts.containerPort, addressOptions{ + emptyHostLabel: "target host", + requireHost: parts.explicitContainer, + allowHostnames: true, + }) if err != nil { return Mapping{}, fmt.Errorf("parse container address: %w", err) } @@ -39,22 +46,55 @@ func ParsePortSpec(port string) (Mapping, error) { }, nil } -func toAddress(ip, port string) (Address, error) { - // check if port is integer - _, err := strconv.Atoi(port) - if err == nil { - if ip == "" { - ip = "localhost" - } +type splitResult struct { + host string + hostPort string + container string + containerPort string + explicitHost bool + explicitContainer bool +} + +type addressOptions struct { + emptyHostLabel string + requireHost bool + allowHostnames bool +} + +func toAddress(ip, port string, opts addressOptions) (Address, error) { + if isPortNumber(port) { + return toTCPAddress(ip, port, opts) + } - if ip != "localhost" && net.ParseIP(ip) == nil { - return Address{}, fmt.Errorf("not an ip address %s", ip) + return toUnixAddress(ip, port, opts) +} + +func toTCPAddress(ip, port string, opts addressOptions) (Address, error) { + if ip == "" { + if opts.requireHost { + return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) } - return Address{ - Protocol: "tcp", - Address: ip + ":" + port, - }, nil + ip = "localhost" + } + + if !opts.allowHostnames && ip != "localhost" && net.ParseIP(ip) == nil { + return Address{}, fmt.Errorf("not an ip address %s", ip) + } + + return Address{ + Protocol: "tcp", + Address: ip + ":" + port, + }, nil +} + +func toUnixAddress(ip, port string, opts addressOptions) (Address, error) { + if port == "" { + return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) + } + + if opts.requireHost && ip == "" { + return Address{}, fmt.Errorf("%s is empty", opts.emptyHostLabel) } if ip != "" { @@ -67,25 +107,47 @@ func toAddress(ip, port string) (Address, error) { }, nil } -func splitParts(rawport string) (string, string, string, string, error) { +func isPortNumber(raw string) bool { + _, err := strconv.Atoi(raw) + return err == nil +} + +func splitParts(rawport string) (splitResult, error) { parts := strings.Split(rawport, ":") n := len(parts) containerport := parts[n-1] switch n { case 1: - return "", containerport, "", containerport, nil + return splitResult{hostPort: containerport, containerPort: containerport}, nil case 2: - return "", parts[0], "", containerport, nil + return splitResult{hostPort: parts[0], containerPort: containerport}, nil case 3: - if parts[1] == "localhost" || net.ParseIP(parts[1]) != nil { - return "", parts[0], parts[1], containerport, nil + if isPortNumber(parts[0]) { + return splitResult{ + hostPort: parts[0], + container: parts[1], + containerPort: containerport, + explicitContainer: true, + }, nil } - return parts[0], parts[1], "", containerport, nil + return splitResult{ + host: parts[0], + hostPort: parts[1], + containerPort: containerport, + explicitHost: true, + }, nil case 4: - return parts[0], parts[1], parts[2], parts[3], nil + return splitResult{ + host: parts[0], + hostPort: parts[1], + container: parts[2], + containerPort: parts[3], + explicitHost: true, + explicitContainer: true, + }, nil default: - return "", "", "", "", fmt.Errorf("unexpected port format: %s", rawport) + return splitResult{}, fmt.Errorf("unexpected port format: %s", rawport) } } diff --git a/pkg/port/parse_test.go b/pkg/port/parse_test.go new file mode 100644 index 000000000..40c9b84e1 --- /dev/null +++ b/pkg/port/parse_test.go @@ -0,0 +1,77 @@ +package port + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParsePortSpec_ServiceNameTarget(t *testing.T) { + mapping, err := ParsePortSpec("8080:nginx:80") + require.NoError(t, err) + assert.Equal(t, "tcp", mapping.Host.Protocol) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "tcp", mapping.Container.Protocol) + assert.Equal(t, "nginx:80", mapping.Container.Address) +} + +func TestParsePortSpec_ServiceNameTargetWithLocalBindHost(t *testing.T) { + mapping, err := ParsePortSpec("127.0.0.1:8080:nginx:80") + require.NoError(t, err) + assert.Equal(t, "127.0.0.1:8080", mapping.Host.Address) + assert.Equal(t, "nginx:80", mapping.Container.Address) +} + +func TestParsePortSpec_PreservesLocalBindDisambiguation(t *testing.T) { + mapping, err := ParsePortSpec("localhost:8080:80") + require.NoError(t, err) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "localhost:80", mapping.Container.Address) +} + +func TestParsePortSpec_AllowsTargetIPHosts(t *testing.T) { + mapping, err := ParsePortSpec("8080:10.0.0.2:80") + require.NoError(t, err) + assert.Equal(t, "localhost:8080", mapping.Host.Address) + assert.Equal(t, "10.0.0.2:80", mapping.Container.Address) +} + +func TestParsePortSpec_RejectsNonIPListenHost(t *testing.T) { + _, err := ParsePortSpec("app:8080:nginx:80") + require.Error(t, err) + assert.ErrorContains(t, err, "not an ip address app") +} + +func TestParsePortSpec_RejectsEmptyTargetHost(t *testing.T) { + _, err := ParsePortSpec("8080::80") + require.Error(t, err) + assert.ErrorContains(t, err, "target host is empty") +} + +func TestParsePortSpec_RejectsEmptyListenHost(t *testing.T) { + _, err := ParsePortSpec(":8080:nginx:80") + require.Error(t, err) + assert.ErrorContains(t, err, "listen host is empty") +} + +func TestParsePortSpec_DelegatesUnixSocketMappings(t *testing.T) { + mapping, err := ParsePortSpec("/tmp/local.sock:/tmp/remote.sock") + require.NoError(t, err) + assert.Equal(t, Mapping{ + Host: Address{Protocol: "unix", Address: "/tmp/local.sock"}, + Container: Address{Protocol: "unix", Address: "/tmp/remote.sock"}, + }, mapping) +} + +func TestParsePortSpec_RejectsEmptyListenUnixSocketPath(t *testing.T) { + _, err := ParsePortSpec(":/tmp/remote.sock") + require.Error(t, err) + assert.ErrorContains(t, err, "listen host is empty") +} + +func TestParsePortSpec_RejectsEmptyTargetUnixSocketPath(t *testing.T) { + _, err := ParsePortSpec("/tmp/local.sock:") + require.Error(t, err) + assert.ErrorContains(t, err, "target host is empty") +}