diff --git a/cmd/up.go b/cmd/up.go index 1ca38fcdf..1aed3c510 100644 --- a/cmd/up.go +++ b/cmd/up.go @@ -1009,6 +1009,7 @@ func configureSSH(client client2.BaseWorkspaceClient, params configureSSHParams) Workdir: params.workdir, GPGAgent: params.gpgagent, DevPodHome: params.devPodHome, + Provider: client.Provider(), Log: log.Default, }) if err != nil { diff --git a/pkg/ssh/config.go b/pkg/ssh/config.go index fe3029ba6..c727229e0 100644 --- a/pkg/ssh/config.go +++ b/pkg/ssh/config.go @@ -33,6 +33,7 @@ type SSHConfigParams struct { Command string GPGAgent bool DevPodHome string + Provider string Log log.Logger } @@ -45,7 +46,18 @@ func ConfigureSSHConfig(params SSHConfigParams) error { targetPath = params.SSHConfigIncludePath } - newFile, err := addHost(targetPath, params.Workspace+"."+"devpod", params.User, params.Context, params.Workspace, params.Workdir, params.Command, params.GPGAgent, params.DevPodHome) + newFile, err := addHost(addHostParams{ + path: targetPath, + host: params.Workspace + "." + "devpod", + user: params.User, + context: params.Context, + workspace: params.Workspace, + workdir: params.Workdir, + command: params.Command, + gpgagent: params.GPGAgent, + devPodHome: params.DevPodHome, + provider: params.Provider, + }) if err != nil { return fmt.Errorf("parse ssh config %w", err) } @@ -59,8 +71,21 @@ type DevPodSSHEntry struct { Workspace string } -func addHost(path, host, user, context, workspace, workdir, command string, gpgagent bool, devPodHome string) (string, error) { - newConfig, err := removeFromConfig(path, host) +type addHostParams struct { + path string + host string + user string + context string + workspace string + workdir string + command string + gpgagent bool + devPodHome string + provider string +} + +func addHost(params addHostParams) (string, error) { + newConfig, err := removeFromConfig(params.path, params.host) if err != nil { return "", err } @@ -71,58 +96,138 @@ func addHost(path, host, user, context, workspace, workdir, command string, gpga return "", err } - return addHostSection(newConfig, execPath, host, user, context, workspace, workdir, command, gpgagent, devPodHome) + return addHostSection(newConfig, execPath, params) } -func addHostSection(config, execPath, host, user, context, workspace, workdir, command string, gpgagent bool, devPodHome string) (string, error) { - newLines := []string{} - // add new section - startMarker := MarkerStartPrefix + host - endMarker := MarkerEndPrefix + host - newLines = append(newLines, startMarker) - newLines = append(newLines, "Host "+host) - newLines = append(newLines, " ForwardAgent yes") - newLines = append(newLines, " LogLevel error") - newLines = append(newLines, " StrictHostKeyChecking no") - newLines = append(newLines, " UserKnownHostsFile /dev/null") - newLines = append(newLines, " HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa") - - proxyCommand := "" - if command != "" { - proxyCommand = fmt.Sprintf(" ProxyCommand \"%s\"", command) - } else { - proxyCommand = fmt.Sprintf(" ProxyCommand \"%s\" ssh --stdio --context %s --user %s %s", execPath, context, user, workspace) +// proxyCommandBuilder builds SSH ProxyCommand strings +type proxyCommandBuilder struct { + baseCommand string + options []string +} + +func newProxyCommandBuilder(execPath, context, user, workspace string) *proxyCommandBuilder { + return &proxyCommandBuilder{ + baseCommand: fmt.Sprintf("\"%s\" ssh --stdio --context %s --user %s %s", execPath, context, user, workspace), } +} - if devPodHome != "" { - proxyCommand = fmt.Sprintf("%s --devpod-home \"%s\"", proxyCommand, devPodHome) +func (b *proxyCommandBuilder) withDevPodHome(home string) *proxyCommandBuilder { + if home != "" { + b.options = append(b.options, fmt.Sprintf("--devpod-home \"%s\"", home)) } + return b +} + +func (b *proxyCommandBuilder) withWorkdir(workdir string) *proxyCommandBuilder { if workdir != "" { - proxyCommand = fmt.Sprintf("%s --workdir \"%s\"", proxyCommand, workdir) + b.options = append(b.options, fmt.Sprintf("--workdir \"%s\"", workdir)) + } + return b +} + +func (b *proxyCommandBuilder) withGPGAgent(enabled bool) *proxyCommandBuilder { + if enabled { + b.options = append(b.options, "--gpg-agent-forwarding") + } + return b +} + +func (b *proxyCommandBuilder) build() string { + if len(b.options) == 0 { + return " ProxyCommand " + b.baseCommand + } + return fmt.Sprintf(" ProxyCommand %s %s", b.baseCommand, strings.Join(b.options, " ")) +} + +// sshConfigBuilder builds SSH config entries +type sshConfigBuilder struct { + lines []string +} + +func newSSHConfigBuilder(host string) *sshConfigBuilder { + return &sshConfigBuilder{ + lines: []string{ + MarkerStartPrefix + host, + "Host " + host, + }, + } +} + +func (b *sshConfigBuilder) addSSHOptions(provider string) *sshConfigBuilder { + b.lines = append(b.lines, + " ForwardAgent yes", + " LogLevel error", + " StrictHostKeyChecking no", + " UserKnownHostsFile /dev/null", + " HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa", + ) + + // TODO: Make SSH timeout configurable per provider via provider options + // The ms-vscode-remote.remote-ssh extension times out after 15s by default + // This is insufficient for the aws AWS provider as it needs additional time to + // connect to the instance + // + // The SSH config ConnectTimeout overrides the VSCode Remote-SSH remote.SSH.connectTimeout setting + // https://github.com/microsoft/vscode-remote-release/issues/8519 + if strings.Contains(provider, "aws") { + b.lines = append(b.lines, " ConnectTimeout 60") } - if gpgagent { - proxyCommand = fmt.Sprintf("%s --gpg-agent-forwarding", proxyCommand) + + return b +} + +func (b *sshConfigBuilder) addProxyCommand(proxyCmd string) *sshConfigBuilder { + b.lines = append(b.lines, proxyCmd) + return b +} + +func (b *sshConfigBuilder) addUser(user, host string) *sshConfigBuilder { + b.lines = append(b.lines, " User "+user, MarkerEndPrefix+host) + return b +} + +func (b *sshConfigBuilder) build() []string { + return b.lines +} + +// buildProxyCommand creates the ProxyCommand string +func buildProxyCommand(execPath string, params addHostParams) string { + if params.command != "" { + return fmt.Sprintf(" ProxyCommand \"%s\"", params.command) } - newLines = append(newLines, proxyCommand) - newLines = append(newLines, " User "+user) - newLines = append(newLines, endMarker) - // now we append the original config - // keep our blocks on top of the hosts for priority reasons, but below any includes + return newProxyCommandBuilder(execPath, params.context, params.user, params.workspace). + withDevPodHome(params.devPodHome). + withWorkdir(params.workdir). + withGPGAgent(params.gpgagent). + build() +} + +// buildSSHConfigLines creates the SSH config entry lines +func buildSSHConfigLines(params addHostParams, proxyCmd string) []string { + return newSSHConfigBuilder(params.host). + addSSHOptions(params.provider). + addProxyCommand(proxyCmd). + addUser(params.user, params.host). + build() +} + +// findInsertPosition finds where to insert new SSH config entry +func findInsertPosition(config string) (int, []string, error) { lineNumber := 0 found := false lines := []string{} commentLines := 0 + scanner := bufio.NewScanner(strings.NewReader(config)) for scanner.Scan() { line := scanner.Text() - // Check `Host` keyword + if strings.HasPrefix(strings.TrimSpace(line), "Host") && !found { found = true lineNumber = max(lineNumber-commentLines, 0) } - // Preserve comments if strings.HasPrefix(strings.TrimSpace(line), "#") { commentLines++ } else { @@ -135,18 +240,36 @@ func addHostSection(config, execPath, host, user, context, workspace, workdir, c lines = append(lines, line) } + if err := scanner.Err(); err != nil { - return config, err + return 0, nil, err } - lines = slices.Insert(lines, lineNumber, newLines...) + return lineNumber, lines, nil +} + +// mergeSSHConfig inserts new lines into existing config +func mergeSSHConfig(lines, newLines []string, position int) string { + merged := slices.Insert(lines, position, newLines...) newLineSep := "\n" if runtime.GOOS == "windows" { newLineSep = "\r\n" } - return strings.Join(lines, newLineSep), nil + return strings.Join(merged, newLineSep) +} + +func addHostSection(config, execPath string, params addHostParams) (string, error) { + proxyCmd := buildProxyCommand(execPath, params) + newLines := buildSSHConfigLines(params, proxyCmd) + + position, lines, err := findInsertPosition(config) + if err != nil { + return config, err + } + + return mergeSSHConfig(lines, newLines, position), nil } func GetUser(workspaceID string, sshConfigPath string, sshConfigIncludePath string) (string, error) { diff --git a/pkg/ssh/config_test.go b/pkg/ssh/config_test.go index 0ffbde23e..f7752cc4d 100644 --- a/pkg/ssh/config_test.go +++ b/pkg/ssh/config_test.go @@ -1,14 +1,21 @@ package ssh import ( - "fmt" - "strings" "testing" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestAddHostSection(t *testing.T) { +type SSHConfigTestSuite struct { + suite.Suite +} + +func TestSSHConfigSuite(t *testing.T) { + suite.Run(t, new(SSHConfigTestSuite)) +} + +func (s *SSHConfigTestSuite) TestAddHostSection() { tests := []struct { name string config string @@ -21,6 +28,7 @@ func TestAddHostSection(t *testing.T) { command string gpgagent bool devPodHome string + provider string expected string }{ { @@ -35,6 +43,31 @@ func TestAddHostSection(t *testing.T) { command: "", gpgagent: false, devPodHome: "", + provider: "", + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost`, + }, + { + name: "AWS provider with ConnectTimeout", + config: "", + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "", + gpgagent: false, + devPodHome: "", + provider: "aws", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -42,6 +75,7 @@ Host testhost StrictHostKeyChecking no UserKnownHostsFile /dev/null HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ConnectTimeout 60 ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace User testuser # DevPod End testhost`, @@ -58,6 +92,7 @@ Host testhost command: "", gpgagent: false, devPodHome: "C:\\\\White Space\\devpod\\test", + provider: "", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -81,6 +116,7 @@ Host testhost command: "", gpgagent: false, devPodHome: "", + provider: "", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -104,6 +140,7 @@ Host testhost command: "", gpgagent: true, devPodHome: "", + provider: "", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -127,6 +164,7 @@ Host testhost command: "ssh -W %h:%p bastion", gpgagent: false, devPodHome: "", + provider: "", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -151,6 +189,7 @@ Host testhost command: "", gpgagent: false, devPodHome: "", + provider: "", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -188,6 +227,7 @@ Host existinghost command: "", gpgagent: false, devPodHome: "", + provider: "", expected: `# DevPod Start testhost Host testhost ForwardAgent yes @@ -230,6 +270,7 @@ Include ~/config3`, command: "", gpgagent: false, devPodHome: "", + provider: "", expected: `Include ~/config1 Include ~/config2 @@ -251,43 +292,40 @@ Host testhost } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := addHostSection(tt.config, tt.execPath, tt.host, tt.user, tt.context, tt.workspace, tt.workdir, tt.command, tt.gpgagent, tt.devPodHome) - if err != nil { - t.Errorf("Failed with err: %v", err) - } - - if result != tt.expected { - t.Errorf("addHostSection result does not match expected.\nGot:\n%s\nExpected:\n%s", result, tt.expected) - t.Errorf("addHostSection result does not match expected:\n%s", cmp.Diff(result, tt.expected)) - } - - if !strings.Contains(result, MarkerEndPrefix+tt.host) { - t.Errorf("Result does not contain the end marker: %s", MarkerEndPrefix+tt.host) - } - - if !strings.Contains(result, "Host "+tt.host) { - t.Errorf("Result does not contain the Host line: Host %s", tt.host) - } + s.Run(tt.name, func() { + result, err := addHostSection(tt.config, tt.execPath, addHostParams{ + path: "", + host: tt.host, + user: tt.user, + context: tt.context, + workspace: tt.workspace, + workdir: tt.workdir, + command: tt.command, + gpgagent: tt.gpgagent, + devPodHome: tt.devPodHome, + provider: tt.provider, + }) - if !strings.Contains(result, "User "+tt.user) { - t.Errorf("Result does not contain the User line: User %s", tt.user) - } + assert.NoError(s.T(), err) + assert.Equal(s.T(), tt.expected, result) + assert.Contains(s.T(), result, MarkerEndPrefix+tt.host) + assert.Contains(s.T(), result, "Host "+tt.host) + assert.Contains(s.T(), result, "User "+tt.user) - if tt.command != "" && !strings.Contains(result, fmt.Sprintf("ProxyCommand \"%s\"", tt.command)) { - t.Errorf("Result does not contain the custom ProxyCommand: %s", tt.command) + if tt.command != "" { + assert.Contains(s.T(), result, "ProxyCommand \""+tt.command+"\"") } - if tt.workdir != "" && !strings.Contains(result, fmt.Sprintf("--workdir \"%s\"", tt.workdir)) { - t.Errorf("Result does not contain the workdir: %s", tt.workdir) + if tt.workdir != "" { + assert.Contains(s.T(), result, "--workdir \""+tt.workdir+"\"") } - if tt.gpgagent && !strings.Contains(result, "--gpg-agent-forwarding") { - t.Errorf("Result does not contain gpg-agent-forwarding flag") + if tt.gpgagent { + assert.Contains(s.T(), result, "--gpg-agent-forwarding") } - if tt.config != "" && !strings.Contains(result, tt.config) { - t.Errorf("Result does not contain the original config") + if tt.config != "" { + assert.Contains(s.T(), result, tt.config) } }) }