Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 96 additions & 58 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os/exec"
"path"
"strings"
"sync"
"time"

"al.essio.dev/pkg/shellescape"
Expand Down Expand Up @@ -113,11 +114,11 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command {
sshCmd.Flags().
StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{},
"Specifies that connections to the given TCP port or Unix socket on the local (client) "+
"host are to be forwarded to the given host and port, or Unix socket, on the remote side.")
"host are to be forwarded to the given host, service name, and port, or Unix socket, on the remote side.")
sshCmd.Flags().
StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{},
"Specifies that connections to the given TCP port or Unix socket on the local (client) "+
"host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.")
"Specifies that connections to the given TCP port or Unix socket on the remote side "+
"are to be reverse forwarded to the given local host, service name, and port, or Unix socket.")
Comment on lines 119 to +121
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Keep --reverse-forward-ports on the legacy parser/contract.

reverseForwardPorts now shares the expanded port.ParsePortSpec path, and the help text advertises service names for -R. That appears to broaden reverse forwarding even though the PR scope says hostname/service-name support is forward-only. Please route -R through a strict/direction-aware parser and keep the reverse help text aligned with the legacy behavior.

Possible structure
 type portForwardConfig struct {
 	mappings    []string
 	logTemplate string
 	forwardFn   portForwardFunc
+	parseFn     func(string) (port.Mapping, error)
 }

 func (cmd *SSHCmd) reverseForwardPorts(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	log log.Logger,
 ) error {
 	return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
 		mappings:    cmd.ReverseForwardPorts,
 		logTemplate: "Reverse forwarding remote %s/%s to local %s/%s",
 		forwardFn:   devssh.ReversePortForward,
+		parseFn:     port.ParseReversePortSpec, // strict parser preserving legacy -R behavior
 	}, log)
 }

 func (cmd *SSHCmd) forwardPorts(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	log log.Logger,
 ) error {
 	return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
 		mappings:    cmd.ForwardPorts,
 		logTemplate: "Forwarding local %s/%s to remote %s/%s",
 		forwardFn:   devssh.PortForward,
+		parseFn:     port.ParsePortSpec,
 	}, log)
 }

Also applies to: 405-409

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cmd/ssh.go` around lines 119 - 121, The reverse-forward flag registration
currently exposes expanded port.ParsePortSpec behavior and mentions service
names/hostnames; change cmd.ReverseForwardPorts handling to use the
legacy/direction-aware parser used for reverse bindings (route through the
strict reverse parser instead of port.ParsePortSpec) and update the flag
registration (StringArrayVarP for cmd.ReverseForwardPorts) help text to match
the original legacy semantics (remove service-name/hostname wording and limit to
the legacy reverse-forward syntax). Ensure the code that parses
cmd.ReverseForwardPorts calls the reverse-specific parsing function (e.g., the
legacy ParseReversePortSpec or equivalent direction-aware parser) and not the
generic port.ParsePortSpec so reverse forwarding remains restricted to the
original behavior.

sshCmd.Flags().
StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{},
"Specifies which local env variables shall be sent to the container.")
Expand Down Expand Up @@ -382,16 +383,17 @@ func (cmd *SSHCmd) jumpContainer(
}

func (cmd *SSHCmd) forwardTimeout(log log.Logger) (time.Duration, error) {
timeout := time.Duration(0)
if cmd.ForwardPortsTimeout != "" {
timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout)
if err != nil {
return timeout, fmt.Errorf("parse forward ports timeout: %w", err)
}
if cmd.ForwardPortsTimeout == "" {
return 0, nil
}

log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout)
timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout)
if err != nil {
return 0, fmt.Errorf("parse forward ports timeout: %w", err)
}

log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout)

return timeout, nil
}

Expand All @@ -400,89 +402,125 @@ func (cmd *SSHCmd) reverseForwardPorts(
containerClient *ssh.Client,
log log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
mappings: cmd.ReverseForwardPorts,
logTemplate: "Reverse forwarding remote %s/%s to local %s/%s",
forwardFn: devssh.ReversePortForward,
}, log)
}

errChan := make(chan error, len(cmd.ReverseForwardPorts))
for _, portMapping := range cmd.ReverseForwardPorts {
func (cmd *SSHCmd) forwardPorts(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
) error {
return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
mappings: cmd.ForwardPorts,
logTemplate: "Forwarding local %s/%s to remote %s/%s",
forwardFn: devssh.PortForward,
}, log)
}

type portForwardFunc func(
context.Context,
*ssh.Client,
string,
string,
string,
string,
time.Duration,
log.Logger,
) error

type portForwardConfig struct {
mappings []string
logTemplate string
forwardFn portForwardFunc
}

type parsedPortForward struct {
spec string
mapping port.Mapping
}

func parsePortForwards(mappings []string) ([]parsedPortForward, error) {
parsedMappings := make([]parsedPortForward, 0, len(mappings))
for _, portMapping := range mappings {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
return nil, fmt.Errorf("parse port mapping: %w", err)
}

// start the forwarding
log.Infof(
"Reverse forwarding local %s/%s to remote %s/%s",
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.ReversePortForward(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
)
if !errors.Is(io.EOF, err) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
parsedMappings = append(parsedMappings, parsedPortForward{
spec: portMapping,
mapping: mapping,
})
}

return <-errChan
return parsedMappings, nil
}

func (cmd *SSHCmd) forwardPorts(
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

errChan := make(chan error, len(cmd.ForwardPorts))
for _, portMapping := range cmd.ForwardPorts {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}

errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping

// start the forwarding
log.Infof(
"Forwarding local %s/%s to remote %s/%s",
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.PortForward(
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()

err := config.forwardFn(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
logger,
)
if !errors.Is(io.EOF, err) {
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}(portMapping, mapping)
}

go func() {
waitGroup.Wait()
close(errChan)
}()

for err := range errChan {
if err != nil {
return err
}
}

return <-errChan
return nil
}
Comment on lines +463 to 524
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Cancel sibling forwarders before returning the first runtime error.

If one forwarding goroutine fails, Line 519 returns immediately while the other forwarders keep using the caller’s ctx and may keep listeners open. Use a child context, cancel it on the first real error, then drain the channel until the WaitGroup closes it.

Proposed cleanup on first error
 func (cmd *SSHCmd) runPortForwards(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	config portForwardConfig,
 	logger log.Logger,
 ) error {
 	timeout, err := cmd.forwardTimeout(logger)
 	if err != nil {
 		return fmt.Errorf("parse forward ports timeout: %w", err)
 	}

 	parsedMappings, err := parsePortForwards(config.mappings)
 	if err != nil {
 		return err
 	}
+
+	forwardCtx, cancel := context.WithCancel(ctx)
+	defer cancel()

 	errChan := make(chan error, len(parsedMappings))
 	var waitGroup sync.WaitGroup
 	for _, parsedMapping := range parsedMappings {
 		portMapping, mapping := parsedMapping.spec, parsedMapping.mapping

@@
 			defer waitGroup.Done()

 			err := config.forwardFn(
-				ctx,
+				forwardCtx,
 				containerClient,
 				mapping.Host.Protocol,
 				mapping.Host.Address,
 				mapping.Container.Protocol,
@@
 		}(portMapping, mapping)
 	}

 	go func() {
 		waitGroup.Wait()
 		close(errChan)
 	}()

+	var firstErr error
 	for err := range errChan {
-		if err != nil {
-			return err
+		if err != nil && firstErr == nil {
+			firstErr = err
+			cancel()
 		}
 	}

-	return nil
+	return firstErr
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
errChan := make(chan error, len(cmd.ForwardPorts))
for _, portMapping := range cmd.ForwardPorts {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}
errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping
// start the forwarding
log.Infof(
"Forwarding local %s/%s to remote %s/%s",
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.PortForward(
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()
err := config.forwardFn(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
logger,
)
if !errors.Is(io.EOF, err) {
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}(portMapping, mapping)
}
go func() {
waitGroup.Wait()
close(errChan)
}()
for err := range errChan {
if err != nil {
return err
}
}
return <-errChan
return nil
}
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}
forwardCtx, cancel := context.WithCancel(ctx)
defer cancel()
errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping
// start the forwarding
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()
err := config.forwardFn(
forwardCtx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
logger,
)
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping, mapping)
}
go func() {
waitGroup.Wait()
close(errChan)
}()
var firstErr error
for err := range errChan {
if err != nil && firstErr == nil {
firstErr = err
cancel()
}
}
return firstErr
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cmd/ssh.go` around lines 463 - 524, Wrap the incoming ctx in a cancellable
child (ctx, cancel := context.WithCancel(ctx)) at the top of runPortForwards and
pass that child ctx into config.forwardFn so all forwarder goroutines see
cancellation; in the per-goroutine error handling, when you detect a real error
(err != nil && !errors.Is(err, io.EOF)) send the formatted error to errChan and
call cancel() to cancel sibling forwarders; in the receiving loop, capture the
first error but continue draining errChan until it is closed by the waitGroup
goroutine, then return the first non-nil error (ensuring cancel is
deferred/called to release resources when leaving).


func (cmd *SSHCmd) startTunnel(
Expand Down
Loading