Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 65 additions & 52 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os/exec"
"path"
"strings"
"sync"
"time"

"al.essio.dev/pkg/shellescape"
Expand Down Expand Up @@ -113,11 +114,11 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command {
sshCmd.Flags().
StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{},
"Specifies that connections to the given TCP port or Unix socket on the local (client) "+
"host are to be forwarded to the given host and port, or Unix socket, on the remote side.")
"host are to be forwarded to the given host, service name, and port, or Unix socket, on the remote side.")
sshCmd.Flags().
StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{},
"Specifies that connections to the given TCP port or Unix socket on the local (client) "+
"host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.")
"Specifies that connections to the given TCP port or Unix socket on the remote side "+
"are to be reverse forwarded to the given local host, service name, and port, or Unix socket.")
Comment on lines 119 to +121
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Keep --reverse-forward-ports on the legacy parser/contract.

reverseForwardPorts now shares the expanded port.ParsePortSpec path, and the help text advertises service names for -R. That appears to broaden reverse forwarding even though the PR scope says hostname/service-name support is forward-only. Please route -R through a strict/direction-aware parser and keep the reverse help text aligned with the legacy behavior.

Possible structure
 type portForwardConfig struct {
 	mappings    []string
 	logTemplate string
 	forwardFn   portForwardFunc
+	parseFn     func(string) (port.Mapping, error)
 }

 func (cmd *SSHCmd) reverseForwardPorts(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	log log.Logger,
 ) error {
 	return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
 		mappings:    cmd.ReverseForwardPorts,
 		logTemplate: "Reverse forwarding remote %s/%s to local %s/%s",
 		forwardFn:   devssh.ReversePortForward,
+		parseFn:     port.ParseReversePortSpec, // strict parser preserving legacy -R behavior
 	}, log)
 }

 func (cmd *SSHCmd) forwardPorts(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	log log.Logger,
 ) error {
 	return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
 		mappings:    cmd.ForwardPorts,
 		logTemplate: "Forwarding local %s/%s to remote %s/%s",
 		forwardFn:   devssh.PortForward,
+		parseFn:     port.ParsePortSpec,
 	}, log)
 }

Also applies to: 405-409

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cmd/ssh.go` around lines 119 - 121, The reverse-forward flag registration
currently exposes expanded port.ParsePortSpec behavior and mentions service
names/hostnames; change cmd.ReverseForwardPorts handling to use the
legacy/direction-aware parser used for reverse bindings (route through the
strict reverse parser instead of port.ParsePortSpec) and update the flag
registration (StringArrayVarP for cmd.ReverseForwardPorts) help text to match
the original legacy semantics (remove service-name/hostname wording and limit to
the legacy reverse-forward syntax). Ensure the code that parses
cmd.ReverseForwardPorts calls the reverse-specific parsing function (e.g., the
legacy ParseReversePortSpec or equivalent direction-aware parser) and not the
generic port.ParsePortSpec so reverse forwarding remains restricted to the
original behavior.

sshCmd.Flags().
StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{},
"Specifies which local env variables shall be sent to the container.")
Expand Down Expand Up @@ -400,89 +401,101 @@ func (cmd *SSHCmd) reverseForwardPorts(
containerClient *ssh.Client,
log log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}

errChan := make(chan error, len(cmd.ReverseForwardPorts))
for _, portMapping := range cmd.ReverseForwardPorts {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}

// start the forwarding
log.Infof(
"Reverse forwarding local %s/%s to remote %s/%s",
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.ReversePortForward(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
)
if !errors.Is(io.EOF, err) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}

return <-errChan
return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
mappings: cmd.ReverseForwardPorts,
logTemplate: "Reverse forwarding remote %s/%s to local %s/%s",
forwardFn: devssh.ReversePortForward,
}, log)
}

func (cmd *SSHCmd) forwardPorts(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
mappings: cmd.ForwardPorts,
logTemplate: "Forwarding local %s/%s to remote %s/%s",
forwardFn: devssh.PortForward,
}, log)
}

type portForwardFunc func(
context.Context,
*ssh.Client,
string,
string,
string,
string,
time.Duration,
log.Logger,
) error

type portForwardConfig struct {
mappings []string
logTemplate string
forwardFn portForwardFunc
}

func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

errChan := make(chan error, len(cmd.ForwardPorts))
for _, portMapping := range cmd.ForwardPorts {
errChan := make(chan error, len(config.mappings))
var waitGroup sync.WaitGroup
for _, portMapping := range config.mappings {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}

// start the forwarding
log.Infof(
"Forwarding local %s/%s to remote %s/%s",
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.PortForward(
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()

err := config.forwardFn(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
logger,
)
if !errors.Is(io.EOF, err) {
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}(portMapping, mapping)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

return <-errChan
go func() {
waitGroup.Wait()
close(errChan)
}()

for err := range errChan {
if err != nil {
return err
}
}

return nil
}
Comment on lines +463 to 524
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Cancel sibling forwarders before returning the first runtime error.

If one forwarding goroutine fails, Line 519 returns immediately while the other forwarders keep using the caller’s ctx and may keep listeners open. Use a child context, cancel it on the first real error, then drain the channel until the WaitGroup closes it.

Proposed cleanup on first error
 func (cmd *SSHCmd) runPortForwards(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	config portForwardConfig,
 	logger log.Logger,
 ) error {
 	timeout, err := cmd.forwardTimeout(logger)
 	if err != nil {
 		return fmt.Errorf("parse forward ports timeout: %w", err)
 	}

 	parsedMappings, err := parsePortForwards(config.mappings)
 	if err != nil {
 		return err
 	}
+
+	forwardCtx, cancel := context.WithCancel(ctx)
+	defer cancel()

 	errChan := make(chan error, len(parsedMappings))
 	var waitGroup sync.WaitGroup
 	for _, parsedMapping := range parsedMappings {
 		portMapping, mapping := parsedMapping.spec, parsedMapping.mapping

@@
 			defer waitGroup.Done()

 			err := config.forwardFn(
-				ctx,
+				forwardCtx,
 				containerClient,
 				mapping.Host.Protocol,
 				mapping.Host.Address,
 				mapping.Container.Protocol,
@@
 		}(portMapping, mapping)
 	}

 	go func() {
 		waitGroup.Wait()
 		close(errChan)
 	}()

+	var firstErr error
 	for err := range errChan {
-		if err != nil {
-			return err
+		if err != nil && firstErr == nil {
+			firstErr = err
+			cancel()
 		}
 	}

-	return nil
+	return firstErr
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
errChan := make(chan error, len(cmd.ForwardPorts))
for _, portMapping := range cmd.ForwardPorts {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}
errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping
// start the forwarding
log.Infof(
"Forwarding local %s/%s to remote %s/%s",
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.PortForward(
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()
err := config.forwardFn(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
logger,
)
if !errors.Is(io.EOF, err) {
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}(portMapping, mapping)
}
go func() {
waitGroup.Wait()
close(errChan)
}()
for err := range errChan {
if err != nil {
return err
}
}
return <-errChan
return nil
}
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}
forwardCtx, cancel := context.WithCancel(ctx)
defer cancel()
errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping
// start the forwarding
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()
err := config.forwardFn(
forwardCtx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
logger,
)
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping, mapping)
}
go func() {
waitGroup.Wait()
close(errChan)
}()
var firstErr error
for err := range errChan {
if err != nil && firstErr == nil {
firstErr = err
cancel()
}
}
return firstErr
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cmd/ssh.go` around lines 463 - 524, Wrap the incoming ctx in a cancellable
child (ctx, cancel := context.WithCancel(ctx)) at the top of runPortForwards and
pass that child ctx into config.forwardFn so all forwarder goroutines see
cancellation; in the per-goroutine error handling, when you detect a real error
(err != nil && !errors.Is(err, io.EOF)) send the formatted error to errChan and
call cancel() to cancel sibling forwarders; in the receiving loop, capture the
first error but continue draining errChan until it is closed by the waitGroup
goroutine, then return the first non-nil error (ensuring cancel is
deferred/called to release resources when leaving).


func (cmd *SSHCmd) startTunnel(
Expand Down
29 changes: 29 additions & 0 deletions cmd/ssh_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package cmd

import (
"context"
"os"
"path/filepath"
"testing"
"time"

"github.com/skevetter/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

func writeGitConfig(t *testing.T, content string) {
Expand Down Expand Up @@ -57,3 +61,28 @@ func TestGpgSigningKey_TildeKeyPath_Skipped(t *testing.T) {
result := gpgSigningKey(log.Discard)
assert.Empty(t, result)
}

func TestRunPortForwards_ReturnsOnCleanExit(t *testing.T) {
cmd := &SSHCmd{}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

err := cmd.runPortForwards(ctx, nil, portForwardConfig{
mappings: []string{"8080:80"},
logTemplate: "test %s/%s %s/%s",
forwardFn: func(
context.Context,
*ssh.Client,
string,
string,
string,
string,
time.Duration,
log.Logger,
) error {
return nil
},
}, log.Discard)

require.NoError(t, err)
}
15 changes: 15 additions & 0 deletions docs/pages/developing-in-workspaces/connect-to-a-workspace.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ Optionally you can also define a command to run:
devpod ssh my-workspace --command "echo Hello World"
```

You can also forward ports while using `devpod ssh`.
Target hosts can be service names or hostnames that are resolvable from the side that dials them.
For `--forward-ports`, that is the workspace.
For `--reverse-forward-ports`, that is the machine running `devpod ssh`.

To forward a local port to a host or service reachable from the workspace:
```
devpod ssh my-workspace --forward-ports 18080:nginx:8080
```

To expose a port inside the workspace that forwards back to a host or service reachable from the machine running `devpod ssh`:
```
devpod ssh my-workspace --reverse-forward-ports 15432:postgres.internal:5432
```

## IDE Commands

This section shows additional commands to configure DevPod's behavior when opening a workspace.
Expand Down
72 changes: 72 additions & 0 deletions e2e/tests/up-docker-compose/up_docker_compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
Expand Down Expand Up @@ -161,6 +162,77 @@ var _ = ginkgo.Describe(
))
}, ginkgo.SpecTimeout(framework.GetTimeout()))

ginkgo.It("ssh forward ports support remote service names", func(ctx context.Context) {
_, workspace, err := tc.setupAndStartWorkspace(
ctx,
"tests/up-docker-compose/testdata/docker-compose-forward-ports",
"--debug",
)
framework.ExpectNoError(err)

ids, err := findComposeContainer(
ctx,
tc.dockerHelper,
tc.composeHelper,
workspace.UID,
"app",
)
framework.ExpectNoError(err)
gomega.Expect(ids).To(gomega.HaveLen(1), "1 compose container to be created")

listener, err := net.Listen("tcp", "127.0.0.1:0")
framework.ExpectNoError(err)
localPort := listener.Addr().(*net.TCPAddr).Port
framework.ExpectNoError(listener.Close())
Comment on lines +183 to +186
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Dynamic port allocation has a minor TOCTOU window, but acceptable for e2e.

Closing the listener and reusing the port has a brief race where another process could grab it before devpod ssh binds. In practice this is fine for e2e and matches common Go test patterns, so no change required — just flagging it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@e2e/tests/up-docker-compose/up_docker_compose.go` around lines 183 - 186, The
current pattern uses net.Listen to get a free port then closes the listener
(listener.Close()) and reuses localPort, which introduces a TOCTOU race; to fix
it eliminate the race by keeping the listener open and handing the active
listener (or its file descriptor) directly to the code that will serve SSH
rather than closing and re-binding — modify the code around net.Listen,
listener, localPort and the devpod ssh invocation so the produced listener is
passed through to the server routine (or use the listener's FD transfer) instead
of closing and re-opening the port.


done := make(chan error)
sshContext, sshCancel := context.WithCancel(context.Background())
go func() {
// #nosec G204 -- test command with controlled arguments
cmd := exec.CommandContext(
sshContext,
filepath.Join(tc.f.DevpodBinDir, tc.f.DevpodBinName),
"ssh",
"--forward-ports",
fmt.Sprintf("%d:nginx:8080", localPort),
workspace.ID,
)

if err := cmd.Start(); err != nil {
done <- err
return
}

if err := cmd.Wait(); err != nil {
done <- err
return
}

done <- nil
}()

gomega.Eventually(func(g gomega.Gomega) {
response, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", localPort))
g.Expect(err).NotTo(gomega.HaveOccurred())
defer func() { _ = response.Body.Close() }()

body, err := io.ReadAll(response.Body)
g.Expect(err).NotTo(gomega.HaveOccurred())
g.Expect(body).To(gomega.ContainSubstring("Thank you for using nginx."))
}).
WithPolling(1 * time.Second).
WithTimeout(20 * time.Second).
Should(gomega.Succeed())

sshCancel()
err = <-done

gomega.Expect(err).To(gomega.Or(
gomega.MatchError("signal: killed"),
gomega.MatchError(context.Canceled),
))
}, ginkgo.SpecTimeout(framework.GetTimeout()))

ginkgo.It("features", func(ctx context.Context) {
tempDir, workspace, err := tc.setupAndStartWorkspace(
ctx,
Expand Down
Loading
Loading