diff --git a/.github/workflows/pr-ci.yml b/.github/workflows/pr-ci.yml index 4a6abd7f7..44b13ee27 100644 --- a/.github/workflows/pr-ci.yml +++ b/.github/workflows/pr-ci.yml @@ -366,7 +366,7 @@ jobs: KUBECONFIG="${KUBECONFIG:-$HOME/.kube/config}" \ PATH="${PATH}" \ GOROOT="${GOROOT}" \ - go test -v -ginkgo.v -timeout 3600s --ginkgo.label-filter="${{ matrix.label }}" + go test -v -ginkgo.v -timeout 900s --ginkgo.label-filter="${{ matrix.label }}" else GH_USERNAME="${GH_USERNAME}" \ GH_ACCESS_TOKEN="${GH_ACCESS_TOKEN}" \ @@ -374,7 +374,7 @@ jobs: PATH="${PATH}" \ GOROOT="${GOROOT}" \ DOCKER_HOST="npipe:////./pipe/podman-machine-default" \ - go test -v -ginkgo.v -timeout 3600s --ginkgo.label-filter="${{ matrix.label }}" + go test -v -ginkgo.v -timeout 900s --ginkgo.label-filter="${{ matrix.label }}" fi - name: verify docker is installed diff --git a/cmd/agent/workspace/build.go b/cmd/agent/workspace/build.go index db408be86..1f29baf32 100644 --- a/cmd/agent/workspace/build.go +++ b/cmd/agent/workspace/build.go @@ -5,7 +5,6 @@ import ( "fmt" "os" - "github.com/sirupsen/logrus" "github.com/skevetter/devpod/cmd/flags" "github.com/skevetter/devpod/pkg/agent" provider2 "github.com/skevetter/devpod/pkg/provider" @@ -57,7 +56,12 @@ func (cmd *BuildCmd) Run(ctx context.Context) error { // initialize the workspace cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - _, logger, credentialsDir, err := initWorkspace(cancelCtx, cancel, workspaceInfo, cmd.Debug, false) + _, logger, credentialsDir, err := initWorkspace(initWorkspaceParams{ + ctx: cancelCtx, + workspaceInfo: workspaceInfo, + debug: cmd.Debug, + shouldInstallDaemon: false, + }) if err != nil { return err } else if credentialsDir != "" { @@ -93,15 +97,9 @@ func (cmd *BuildCmd) Run(ctx context.Context) error { } if workspaceInfo.CLIOptions.SkipPush { - logger.WithFields(logrus.Fields{ - "imageName": imageName, - }) - logger.Donef("done building image") + logger.Donef("done building image %s", imageName) } else { - logger.WithFields(logrus.Fields{ - "imageName": imageName, - }) - logger.Donef("done building and pushing image") + logger.Donef("done building and pushing image %s", imageName) } } diff --git a/cmd/agent/workspace/up.go b/cmd/agent/workspace/up.go index 233c474ba..bd46f50ca 100644 --- a/cmd/agent/workspace/up.go +++ b/cmd/agent/workspace/up.go @@ -16,7 +16,6 @@ import ( "github.com/skevetter/devpod/pkg/agent" "github.com/skevetter/devpod/pkg/agent/tunnel" "github.com/skevetter/devpod/pkg/agent/tunnelserver" - "github.com/skevetter/devpod/pkg/binaries" "github.com/skevetter/devpod/pkg/client/clientimplementation" "github.com/skevetter/devpod/pkg/command" "github.com/skevetter/devpod/pkg/credentials" @@ -27,7 +26,7 @@ import ( "github.com/skevetter/devpod/pkg/dockercredentials" "github.com/skevetter/devpod/pkg/dockerinstall" "github.com/skevetter/devpod/pkg/extract" - provider2 "github.com/skevetter/devpod/pkg/provider" + "github.com/skevetter/devpod/pkg/provider" "github.com/skevetter/devpod/pkg/util" "github.com/skevetter/log" "github.com/spf13/cobra" @@ -75,13 +74,12 @@ func (cmd *UpCmd) Run(ctx context.Context) error { cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - tunnelClient, logger, credentialsDir, err := initWorkspace( - cancelCtx, - cancel, - workspaceInfo, - cmd.Debug, - cmd.shouldInstallDaemon(workspaceInfo), - ) + tunnelClient, logger, credentialsDir, err := initWorkspace(initWorkspaceParams{ + ctx: cancelCtx, + workspaceInfo: workspaceInfo, + debug: cmd.Debug, + shouldInstallDaemon: cmd.shouldInstallDaemon(workspaceInfo), + }) defer cmd.cleanupCredentials(credentialsDir) if err != nil { return cmd.handleInitError(err, workspaceInfo, logger) @@ -94,10 +92,10 @@ func (cmd *UpCmd) Run(ctx context.Context) error { return nil } -func (cmd *UpCmd) loadWorkspaceInfo(ctx context.Context) (*provider2.AgentWorkspaceInfo, error) { +func (cmd *UpCmd) loadWorkspaceInfo(ctx context.Context) (*provider.AgentWorkspaceInfo, error) { shouldExit, workspaceInfo, err := agent.WriteWorkspaceInfoAndDeleteOld( cmd.WorkspaceInfo, - func(workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) error { + func(workspaceInfo *provider.AgentWorkspaceInfo, log log.Logger) error { return deleteWorkspace(ctx, workspaceInfo, log) }, log.Default.ErrorStreamOnly(), @@ -111,15 +109,15 @@ func (cmd *UpCmd) loadWorkspaceInfo(ctx context.Context) (*provider2.AgentWorksp return workspaceInfo, nil } -func (cmd *UpCmd) shouldPreventDaemonShutdown(workspaceInfo *provider2.AgentWorkspaceInfo) bool { +func (cmd *UpCmd) shouldPreventDaemonShutdown(workspaceInfo *provider.AgentWorkspaceInfo) bool { return !workspaceInfo.CLIOptions.Platform.Enabled } -func (cmd *UpCmd) shouldInstallDaemon(workspaceInfo *provider2.AgentWorkspaceInfo) bool { +func (cmd *UpCmd) shouldInstallDaemon(workspaceInfo *provider.AgentWorkspaceInfo) bool { return !workspaceInfo.CLIOptions.Platform.Enabled && !workspaceInfo.CLIOptions.DisableDaemon } -func (cmd *UpCmd) handleInitError(err error, workspaceInfo *provider2.AgentWorkspaceInfo, logger log.Logger) error { +func (cmd *UpCmd) handleInitError(err error, workspaceInfo *provider.AgentWorkspaceInfo, logger log.Logger) error { if logger == nil { logger = log.Discard } @@ -141,7 +139,12 @@ func (cmd *UpCmd) cleanupCredentials(credentialsDir string) { } } -func (cmd *UpCmd) up(ctx context.Context, workspaceInfo *provider2.AgentWorkspaceInfo, tunnelClient tunnel.TunnelClient, logger log.Logger) error { +func (cmd *UpCmd) up( + ctx context.Context, + workspaceInfo *provider.AgentWorkspaceInfo, + tunnelClient tunnel.TunnelClient, + logger log.Logger, +) error { result, err := cmd.devPodUp(ctx, workspaceInfo, logger) if err != nil { return err @@ -164,7 +167,11 @@ func (cmd *UpCmd) sendResult(ctx context.Context, result *config2.Result, tunnel return nil } -func (cmd *UpCmd) devPodUp(ctx context.Context, workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) (*config2.Result, error) { +func (cmd *UpCmd) devPodUp( + ctx context.Context, + workspaceInfo *provider.AgentWorkspaceInfo, + log log.Logger, +) (*config2.Result, error) { runner, err := CreateRunner(workspaceInfo, log) if err != nil { return nil, err @@ -176,11 +183,11 @@ func (cmd *UpCmd) devPodUp(ctx context.Context, workspaceInfo *provider2.AgentWo }, workspaceInfo.InjectTimeout) } -func CreateRunner(workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) (devcontainer.Runner, error) { +func CreateRunner(workspaceInfo *provider.AgentWorkspaceInfo, log log.Logger) (devcontainer.Runner, error) { return devcontainer.NewRunner(agent.ContainerDevPodHelperLocation, agent.DefaultAgentDownloadURL(), workspaceInfo, log) } -func InitContentFolder(workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) (bool, error) { +func InitContentFolder(workspaceInfo *provider.AgentWorkspaceInfo, log log.Logger) (bool, error) { exists, err := contentFolderExists(workspaceInfo.ContentFolder, log) if err != nil { return false, err @@ -222,13 +229,13 @@ func contentFolderExists(path string, log log.Logger) (bool, error) { func createContentFolder(path string, log log.Logger) error { log.WithFields(logrus.Fields{"path": path}).Debug("create content folder") - if err := os.MkdirAll(path, 0o777); err != nil { + if err := os.MkdirAll(path, 0o750); err != nil { return fmt.Errorf("make workspace folder: %w", err) } return nil } -func downloadWorkspaceBinaries(workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) error { +func downloadWorkspaceBinaries(workspaceInfo *provider.AgentWorkspaceInfo, log log.Logger) error { binariesDir, err := agent.GetAgentBinariesDir( workspaceInfo.Agent.DataPath, workspaceInfo.Workspace.Context, @@ -238,7 +245,7 @@ func downloadWorkspaceBinaries(workspaceInfo *provider2.AgentWorkspaceInfo, log return fmt.Errorf("error getting workspace %s binaries dir: %w", workspaceInfo.Workspace.ID, err) } - _, err = binaries.DownloadBinaries(workspaceInfo.Agent.Binaries, binariesDir, log) + _, err = provider.DownloadBinaries(workspaceInfo.Agent.Binaries, binariesDir, log) if err != nil { return fmt.Errorf("error downloading workspace %s binaries: %w", workspaceInfo.Workspace.ID, err) } @@ -248,8 +255,7 @@ func downloadWorkspaceBinaries(workspaceInfo *provider2.AgentWorkspaceInfo, log type workspaceInitializer struct { ctx context.Context - cancel context.CancelFunc - workspaceInfo *provider2.AgentWorkspaceInfo + workspaceInfo *provider.AgentWorkspaceInfo debug bool shouldInstallDaemon bool tunnelClient tunnel.TunnelClient @@ -258,49 +264,73 @@ type workspaceInitializer struct { gitCredentialsHelper string } -func initWorkspace(ctx context.Context, cancel context.CancelFunc, workspaceInfo *provider2.AgentWorkspaceInfo, debug, shouldInstallDaemon bool) (tunnel.TunnelClient, log.Logger, string, error) { +type initWorkspaceParams struct { + ctx context.Context + workspaceInfo *provider.AgentWorkspaceInfo + debug bool + shouldInstallDaemon bool +} + +func initWorkspace(params initWorkspaceParams) (tunnel.TunnelClient, log.Logger, string, error) { init := &workspaceInitializer{ - ctx: ctx, - cancel: cancel, - workspaceInfo: workspaceInfo, - debug: debug, - shouldInstallDaemon: shouldInstallDaemon, + ctx: params.ctx, + workspaceInfo: params.workspaceInfo, + debug: params.debug, + shouldInstallDaemon: params.shouldInstallDaemon, } - if err := init.initializeTunnel(); err != nil { - return nil, nil, "", err + if err := init.initialize(); err != nil { + return nil, init.logger, init.dockerCredentialsDir, err } - if err := init.setupCredentials(); err != nil { - init.logger.Errorf("error retrieving docker / git credentials: %v", err) + return init.tunnelClient, init.logger, init.dockerCredentialsDir, nil +} + +func (w *workspaceInitializer) initialize() error { + if err := w.initializeTunnel(); err != nil { + return err } - dockerErrChan := init.installDockerAsync() + if err := w.setupCredentials(); err != nil { + w.logger.Warnf("failed to set up docker/git credentials (continuing without them): %v", err) + } - if err := init.prepareWorkspaceContent(); err != nil { - return nil, init.logger, init.dockerCredentialsDir, err + dockerErrChan := w.installDockerAsync() + + if err := w.prepareWorkspaceContent(); err != nil { + return err } - if init.shouldInstallDaemon { - if err := installDaemon(init.workspaceInfo, init.logger); err != nil { - init.logger.Errorf("install DevPod daemon: %v", err) - } + w.setupDaemonIfNeeded() + + if err := w.waitForDocker(dockerErrChan); err != nil { + return err } - if err := init.waitForDocker(dockerErrChan); err != nil { - return nil, nil, init.dockerCredentialsDir, err + w.tryConfigureDockerDaemon() + return nil +} + +func (w *workspaceInitializer) setupDaemonIfNeeded() { + if w.shouldInstallDaemon { + if err := installDaemon(w.workspaceInfo, w.logger); err != nil { + w.logger.Errorf("install DevPod daemon: %v", err) + } } +} - daemonErrChan := init.configureDockerDaemonAsync() - if err := <-daemonErrChan; err != nil { - init.logger.Warn( +func (w *workspaceInitializer) tryConfigureDockerDaemon() { + if !w.shouldConfigureDockerDaemon() { + w.logger.Debug("skipping configuring docker daemon") + return + } + if err := configureDockerDaemon(w.ctx, w.logger); err != nil { + w.logger.Warn( "could not find docker daemon config file, if using the registry cache, " + "please ensure the daemon is configured with containerd-snapshotter=true, " + "more info at https://docs.docker.com/engine/storage/containerd/", ) } - - return init.tunnelClient, init.logger, init.dockerCredentialsDir, nil } func (w *workspaceInitializer) initializeTunnel() error { @@ -362,8 +392,11 @@ func (w *workspaceInitializer) ensureDockerInstalled() (string, error) { } if dockerCmd != "docker" { - _, err := exec.LookPath(dockerCmd) - return "", err + path, err := exec.LookPath(dockerCmd) + if err != nil { + return "", fmt.Errorf("custom docker path %q not found: %w", dockerCmd, err) + } + return path, nil } if w.isDockerInstallDisabled() { @@ -399,6 +432,8 @@ func (w *workspaceInitializer) prepareWorkspaceContent() error { }) } +// waitForDocker waits for the Docker installation to complete. +// Note: This function modifies workspaceInfo.Agent.Docker.Path if Docker was installed. func (w *workspaceInitializer) waitForDocker(resultChan <-chan dockerInstallResult) error { result := <-resultChan @@ -414,22 +449,6 @@ func (w *workspaceInitializer) waitForDocker(resultChan <-chan dockerInstallResu return nil } -func (w *workspaceInitializer) configureDockerDaemonAsync() <-chan error { - errChan := make(chan error, 1) - - if !w.shouldConfigureDockerDaemon() { - w.logger.Debug("skipping configuring docker daemon") - errChan <- nil - return errChan - } - - go func() { - errChan <- configureDockerDaemon(w.ctx, w.logger) - }() - - return errChan -} - func (w *workspaceInitializer) shouldConfigureDockerDaemon() bool { if !w.workspaceInfo.Agent.IsDockerDriver() { return false @@ -445,12 +464,14 @@ func (w *workspaceInitializer) shouldConfigureDockerDaemon() bool { type prepareWorkspaceParams struct { ctx context.Context - workspaceInfo *provider2.AgentWorkspaceInfo + workspaceInfo *provider.AgentWorkspaceInfo client tunnel.TunnelClient gitHelper string log log.Logger } +// prepareWorkspace initializes the workspace content folder and downloads/prepares the workspace source. +// Note: This function modifies params.workspaceInfo.ContentFolder when platform is enabled with a local folder. func prepareWorkspace(params prepareWorkspaceParams) error { if params.workspaceInfo.CLIOptions.Platform.Enabled && params.workspaceInfo.Workspace.Source.LocalFolder != "" { params.workspaceInfo.ContentFolder = agent.GetAgentWorkspaceContentDir(params.workspaceInfo.Origin) @@ -494,7 +515,7 @@ func prepareWorkspace(params prepareWorkspaceParams) error { type prepareGitWorkspaceParams struct { ctx context.Context - workspaceInfo *provider2.AgentWorkspaceInfo + workspaceInfo *provider.AgentWorkspaceInfo gitHelper string exists bool log log.Logger @@ -530,7 +551,12 @@ func prepareGitWorkspace(params prepareGitWorkspaceParams) error { ) } -func prepareLocalWorkspace(ctx context.Context, workspaceInfo *provider2.AgentWorkspaceInfo, client tunnel.TunnelClient, log log.Logger) error { +func prepareLocalWorkspace( + ctx context.Context, + workspaceInfo *provider.AgentWorkspaceInfo, + client tunnel.TunnelClient, + log log.Logger, +) error { if workspaceInfo.ContentFolder == workspaceInfo.Workspace.Source.LocalFolder { log.Debugf("local folder %s with local provider; skip downloading", workspaceInfo.ContentFolder) return nil @@ -540,7 +566,7 @@ func prepareLocalWorkspace(ctx context.Context, workspaceInfo *provider2.AgentWo return downloadLocalFolder(ctx, workspaceInfo.ContentFolder, client, log) } -func ensureLastDevContainerJson(workspaceInfo *provider2.AgentWorkspaceInfo) error { +func ensureLastDevContainerJson(workspaceInfo *provider.AgentWorkspaceInfo) error { filePath := filepath.Join(workspaceInfo.ContentFolder, filepath.FromSlash(workspaceInfo.LastDevContainerConfig.Path)) if _, err := os.Stat(filePath); err == nil { @@ -549,7 +575,7 @@ func ensureLastDevContainerJson(workspaceInfo *provider2.AgentWorkspaceInfo) err return fmt.Errorf("error stating %s: %w", filePath, err) } - if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(filePath), 0o750); err != nil { return fmt.Errorf("create %s: %w", filepath.Dir(filePath), err) } @@ -567,7 +593,7 @@ func ensureLastDevContainerJson(workspaceInfo *provider2.AgentWorkspaceInfo) err type credentialsConfig struct { ctx context.Context - workspaceInfo *provider2.AgentWorkspaceInfo + workspaceInfo *provider.AgentWorkspaceInfo client tunnel.TunnelClient log log.Logger } @@ -608,7 +634,7 @@ func configureCredentials(cfg credentialsConfig) (string, string, error) { return dockerCredentials, gitCredentials, nil } -func installDaemon(workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) error { +func installDaemon(workspaceInfo *provider.AgentWorkspaceInfo, log log.Logger) error { if len(workspaceInfo.Agent.Exec.Shutdown) == 0 { return nil } @@ -634,45 +660,40 @@ func prepareImage(workspaceDir, image string) error { return os.WriteFile(filepath.Join(workspaceDir, ".devcontainer.json"), devcontainerConfig, 0o600) } +// installDocker installs Docker and returns the path to the docker binary. +// This function assumes docker does not already exist - the caller should check first. func installDocker(log log.Logger) (dockerPath string, err error) { - if !command.Exists("docker") { - writer := log.Writer(logrus.InfoLevel, false) - defer func() { _ = writer.Close() }() - - log.Debug("Installing Docker") - - dockerPath, err = dockerinstall.Install(writer, writer) - } else { - dockerPath = "docker" - } - return dockerPath, err + writer := log.Writer(logrus.InfoLevel, false) + defer func() { _ = writer.Close() }() + log.Debug("installing Docker") + return dockerinstall.Install(writer, writer) } func configureDockerDaemon(ctx context.Context, log log.Logger) error { log.Info("configuring docker daemon") - daemonConfig := []byte(`{ - "features": { - "containerd-snapshotter": true - } - }`) - - if err := writeDockerDaemonConfig(daemonConfig); err != nil { + if err := mergeDockerDaemonConfig(); err != nil { return err } return reloadDockerDaemon(ctx) } -func writeDockerDaemonConfig(config []byte) error { - if err := tryWriteRootlessDockerConfig(config); err == nil { +func mergeDockerDaemonConfig() error { + rootlessErr := tryMergeRootlessDockerConfig() + if rootlessErr == nil { return nil } - return os.WriteFile("/etc/docker/daemon.json", config, 0644) + rootErr := tryMergeRootDockerConfig() + if rootErr == nil { + return nil + } + + return fmt.Errorf("failed to write docker daemon config (rootless: %v, root: %v)", rootlessErr, rootErr) } -func tryWriteRootlessDockerConfig(config []byte) error { +func tryMergeRootlessDockerConfig() error { homeDir, err := util.UserHomeDir() if err != nil { return err @@ -683,9 +704,79 @@ func tryWriteRootlessDockerConfig(config []byte) error { return err } - return os.WriteFile(filepath.Join(dockerConfigDir, "daemon.json"), config, 0644) + configPath := filepath.Join(dockerConfigDir, "daemon.json") + return mergeContainerdSnapshotterConfig(configPath) +} + +func tryMergeRootDockerConfig() error { + return mergeContainerdSnapshotterConfig("/etc/docker/daemon.json") +} + +func mergeContainerdSnapshotterConfig(configPath string) error { + existingConfig, err := readExistingConfig(configPath) + if err != nil { + return err + } + + features := ensureFeaturesMap(existingConfig) + features["containerd-snapshotter"] = true + + return writeConfig(configPath, existingConfig) +} + +func readExistingConfig(configPath string) (map[string]any, error) { + existingConfig := make(map[string]any) + // #nosec G304 -- configPath is controlled by the application + data, err := os.ReadFile(configPath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("read existing config: %w", err) + } + + if len(data) > 0 { + if err := json.Unmarshal(data, &existingConfig); err != nil { + return nil, fmt.Errorf("parse existing config: %w", err) + } + } + return existingConfig, nil +} + +func ensureFeaturesMap(config map[string]any) map[string]any { + features, ok := config["features"].(map[string]any) + if !ok { + features = make(map[string]any) + config["features"] = features + } + return features +} + +func writeConfig(configPath string, config map[string]any) error { + mergedData, err := json.MarshalIndent(config, "", " ") + if err != nil { + return fmt.Errorf("marshal config: %w", err) + } + + // #nosec G301 -- directory needs to be accessible by docker daemon + if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { + return fmt.Errorf("create config directory: %w", err) + } + + // #nosec G306 -- daemon.json needs to be readable by docker daemon + if err := os.WriteFile(configPath, mergedData, 0644); err != nil { + return fmt.Errorf("write config: %w", err) + } + + return nil } func reloadDockerDaemon(ctx context.Context) error { - return exec.CommandContext(ctx, "pkill", "-HUP", "dockerd").Run() + err := exec.CommandContext(ctx, "pkill", "-HUP", "dockerd").Run() + if err != nil { + // pkill returns exit code 1 if no processes matched + var exitErr *exec.ExitError + if errors.As(err, &exitErr) && exitErr.ExitCode() == 1 { + return nil // No dockerd process found, nothing to reload + } + return err + } + return nil } diff --git a/e2e/tests/integration/integration.go b/e2e/tests/integration/integration.go index b33d26bbc..cc5f9b633 100644 --- a/e2e/tests/integration/integration.go +++ b/e2e/tests/integration/integration.go @@ -5,16 +5,15 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "github.com/onsi/ginkgo/v2" - "github.com/onsi/gomega" "github.com/skevetter/devpod/e2e/framework" ) var _ = ginkgo.Describe("[integration]: devpod provider ssh test suite", ginkgo.Ordered, func() { ginkgo.Context("testing provider integration", ginkgo.Label("integration"), ginkgo.Ordered, func() { var initialDir string - ctx := context.Background() ginkgo.BeforeEach(func() { var err error @@ -22,35 +21,43 @@ var _ = ginkgo.Describe("[integration]: devpod provider ssh test suite", ginkgo. framework.ExpectNoError(err) }) - ginkgo.It("should generate ssh keypairs", func() { + ginkgo.It("should generate ssh keypairs", func(ctx context.Context) { sshDir := os.Getenv("HOME") + "/.ssh" if _, err := os.Stat(sshDir); os.IsNotExist(err) { err = os.MkdirAll(sshDir, 0700) framework.ExpectNoError(err) } - _, err := os.Stat(os.Getenv("HOME") + "/.ssh/id_rsa") + homeDir := os.Getenv("HOME") + sshKeyPath := filepath.Join(homeDir, ".ssh", "id_rsa") + sshPubKeyPath := filepath.Join(homeDir, ".ssh", "id_rsa.pub") + + _, err := os.Stat(sshKeyPath) if err != nil { fmt.Println("generating ssh keys") - cmd := exec.Command("ssh-keygen", "-q", "-t", "rsa", "-N", "", "-f", os.Getenv("HOME")+"/.ssh/id_rsa") + // #nosec G204 -- ssh-keygen with fixed arguments for test setup + cmd := exec.CommandContext(ctx, "ssh-keygen", "-q", "-t", "rsa", "-N", "", "-f", sshKeyPath) err = cmd.Run() framework.ExpectNoError(err) - cmd = exec.Command("ssh-keygen", "-y", "-f", os.Getenv("HOME")+"/.ssh/id_rsa") + // #nosec G204 -- ssh-keygen with fixed arguments for test setup + cmd = exec.CommandContext(ctx, "ssh-keygen", "-y", "-f", sshKeyPath) output, err := cmd.Output() framework.ExpectNoError(err) - err = os.WriteFile(os.Getenv("HOME")+"/.ssh/id_rsa.pub", output, 0600) + err = os.WriteFile(sshPubKeyPath, output, 0600) framework.ExpectNoError(err) } - cmd := exec.Command("ssh-keygen", "-y", "-f", os.Getenv("HOME")+"/.ssh/id_rsa") + // #nosec G204 -- ssh-keygen with fixed arguments for test setup + cmd := exec.CommandContext(ctx, "ssh-keygen", "-y", "-f", sshKeyPath) publicKey, err := cmd.Output() framework.ExpectNoError(err) - _, err = os.Stat(os.Getenv("HOME") + "/.ssh/authorized_keys") + authorizedKeysPath := filepath.Join(homeDir, ".ssh", "authorized_keys") + _, err = os.Stat(authorizedKeysPath) if err != nil { - err = os.WriteFile(os.Getenv("HOME")+"/.ssh/authorized_keys", publicKey, 0600) + err = os.WriteFile(authorizedKeysPath, publicKey, 0600) framework.ExpectNoError(err) } else { f, err := os.OpenFile(os.Getenv("HOME")+"/.ssh/authorized_keys", @@ -63,7 +70,7 @@ var _ = ginkgo.Describe("[integration]: devpod provider ssh test suite", ginkgo. } }) - ginkgo.It("should add provider to devpod", func() { + ginkgo.It("should add provider to devpod", func(ctx context.Context) { f := framework.NewDefaultFramework(initialDir + "/bin") // ensure we don't have the ssh provider present err := f.DevPodProviderDelete(ctx, "ssh") @@ -81,12 +88,11 @@ var _ = ginkgo.Describe("[integration]: devpod provider ssh test suite", ginkgo. framework.ExpectNoError(err) }) - ginkgo.It("should run commands to workspace via ssh", func() { - cmd := exec.Command("ssh", "testdata.devpod", "echo", "test") - output, err := cmd.Output() + ginkgo.It("should run commands to workspace via ssh", func(ctx context.Context) { + f := framework.NewDefaultFramework(initialDir + "/bin") + out, err := f.DevPodSSH(ctx, "testdata", "echo test") framework.ExpectNoError(err) - - gomega.Expect(output).To(gomega.Equal([]byte("test\n"))) + framework.ExpectEqual(out, "test\n") }) ginkgo.It("should cleanup devpod workspace", func(ctx context.Context) { diff --git a/pkg/client/clientimplementation/daemonclient/client.go b/pkg/client/clientimplementation/daemonclient/client.go index cfab821a3..ff500df3e 100644 --- a/pkg/client/clientimplementation/daemonclient/client.go +++ b/pkg/client/clientimplementation/daemonclient/client.go @@ -17,6 +17,7 @@ import ( "github.com/skevetter/devpod/pkg/config" daemon "github.com/skevetter/devpod/pkg/daemon/platform" "github.com/skevetter/devpod/pkg/options" + "github.com/skevetter/devpod/pkg/options/resolver" "github.com/skevetter/devpod/pkg/platform" platformclient "github.com/skevetter/devpod/pkg/platform/client" "github.com/skevetter/devpod/pkg/provider" @@ -105,7 +106,15 @@ func (c *client) RefreshOptions(ctx context.Context, userOptionsRaw []string, re return fmt.Errorf("parse options: %w", err) } - workspace, err := options.ResolveAndSaveOptionsProxy(ctx, c.devPodConfig, c.config, c.workspace, userOptions, c.log) + workspace, err := options.ResolveAndSaveOptionsWorkspace( + ctx, + c.devPodConfig, + c.config, + c.workspace, + userOptions, + c.log, + resolver.WithResolveSubOptions(), + ) if err != nil { return err } @@ -248,8 +257,8 @@ func (c *client) Ping(ctx context.Context, writer io.Writer) error { for range 10 { timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() result, err := c.tsClient.Ping(timeoutCtx, *ip, tailcfg.PingDisco) + cancel() if err != nil { return err } diff --git a/pkg/client/clientimplementation/proxy_client.go b/pkg/client/clientimplementation/proxy_client.go index e1f99fb4e..bbde80d56 100644 --- a/pkg/client/clientimplementation/proxy_client.go +++ b/pkg/client/clientimplementation/proxy_client.go @@ -19,6 +19,7 @@ import ( "github.com/skevetter/devpod/pkg/config" devpodlog "github.com/skevetter/devpod/pkg/log" "github.com/skevetter/devpod/pkg/options" + "github.com/skevetter/devpod/pkg/options/resolver" platformclient "github.com/skevetter/devpod/pkg/platform/client" "github.com/skevetter/devpod/pkg/provider" "github.com/skevetter/devpod/pkg/types" @@ -204,7 +205,15 @@ func (s *proxyClient) RefreshOptions(ctx context.Context, userOptionsRaw []strin return fmt.Errorf("parse options: %w", err) } - workspace, err := options.ResolveAndSaveOptionsProxy(ctx, s.devPodConfig, s.config, s.workspace, userOptions, s.log) + workspace, err := options.ResolveAndSaveOptionsWorkspace( + ctx, + s.devPodConfig, + s.config, + s.workspace, + userOptions, + s.log, + resolver.WithResolveSubOptions(), + ) if err != nil { return err } diff --git a/pkg/client/clientimplementation/workspace_client.go b/pkg/client/clientimplementation/workspace_client.go index 0942c62fa..2df0a4980 100644 --- a/pkg/client/clientimplementation/workspace_client.go +++ b/pkg/client/clientimplementation/workspace_client.go @@ -17,7 +17,6 @@ import ( "github.com/sirupsen/logrus" "github.com/skevetter/devpod/pkg/agent" "github.com/skevetter/devpod/pkg/agent/tunnelserver" - "github.com/skevetter/devpod/pkg/binaries" "github.com/skevetter/devpod/pkg/client" "github.com/skevetter/devpod/pkg/compress" "github.com/skevetter/devpod/pkg/config" @@ -506,23 +505,10 @@ func (s *workspaceClient) Stop(ctx context.Context, opt client.StopOptions) erro } func (s *workspaceClient) Command(ctx context.Context, commandOptions client.CommandOptions) (err error) { - // get environment variables - s.m.Lock() - environ, err := binaries.ToEnvironmentWithBinaries(binaries.EnvironmentOptions{ - Context: s.workspace.Context, - Workspace: s.workspace, - Machine: s.machine, - Options: s.devPodConfig.ProviderOptions(s.config.Name), - Config: s.config, - ExtraEnv: map[string]string{ - provider.CommandEnv: commandOptions.Command, - }, - Log: s.log, - }) + environ, err := s.buildEnvironment(commandOptions.Command) if err != nil { return err } - s.m.Unlock() return RunCommand(RunCommandOptions{ Ctx: ctx, @@ -647,8 +633,25 @@ type CommandOptions struct { Log log.Logger } +func (s *workspaceClient) buildEnvironment(command string) ([]string, error) { + s.m.Lock() + defer s.m.Unlock() + + return provider.ToEnvironmentWithBinaries(provider.EnvironmentOptions{ + Context: s.workspace.Context, + Workspace: s.workspace, + Machine: s.machine, + Options: s.devPodConfig.ProviderOptions(s.config.Name), + Config: s.config, + ExtraEnv: map[string]string{ + provider.CommandEnv: command, + }, + Log: s.log, + }) +} + func RunCommandWithBinaries(opts CommandOptions) error { - environ, err := binaries.ToEnvironmentWithBinaries(binaries.EnvironmentOptions{ + environ, err := provider.ToEnvironmentWithBinaries(provider.EnvironmentOptions{ Context: opts.Context, Workspace: opts.Workspace, Machine: opts.Machine, diff --git a/pkg/download/download.go b/pkg/download/download.go index 4f64d8225..acca5d9d2 100644 --- a/pkg/download/download.go +++ b/pkg/download/download.go @@ -13,8 +13,31 @@ import ( "github.com/skevetter/log" ) +// HTTPStatusError wraps HTTP status code errors for better error handling. +type HTTPStatusError struct { + StatusCode int + URL string + Body string +} + +func (e *HTTPStatusError) Error() string { + if e.Body != "" { + return fmt.Sprintf( + "received status code %d when trying to download %s: %s", + e.StatusCode, + e.URL, + e.Body, + ) + } + return fmt.Sprintf( + "received status code %d when trying to download %s", + e.StatusCode, + e.URL, + ) +} + func Head(rawURL string) (int, error) { - req, err := http.NewRequest("HEAD", rawURL, nil) + req, err := http.NewRequest(http.MethodHead, rawURL, nil) if err != nil { return 0, err } @@ -23,6 +46,7 @@ func Head(rawURL string) (int, error) { if err != nil { return 0, fmt.Errorf("download file: %w", err) } + defer func() { _ = resp.Body.Close() }() return resp.StatusCode, nil } @@ -33,7 +57,7 @@ func File(rawURL string, log log.Logger) (io.ReadCloser, error) { return nil, err } - req, err := http.NewRequest("GET", rawURL, nil) + req, err := http.NewRequest(http.MethodGet, rawURL, nil) if err != nil { return nil, err } @@ -67,8 +91,9 @@ func File(rawURL string, log log.Logger) (io.ReadCloser, error) { if err != nil { return nil, fmt.Errorf("download file: %w", err) } else if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) _ = resp.Body.Close() - return nil, fmt.Errorf("received status code %d when trying to download %s", resp.StatusCode, rawURL) + return nil, &HTTPStatusError{StatusCode: resp.StatusCode, URL: rawURL, Body: string(body)} } return resp.Body, nil @@ -84,14 +109,29 @@ type GithubReleaseAsset struct { } func downloadGithubRelease(org, repo, release, file, token string) (io.ReadCloser, error) { - releaseURL := "" + var releasePath string if release == "" { - releaseURL = fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", org, repo) + releasePath = fmt.Sprintf( + "/repos/%s/%s/releases/latest", + url.PathEscape(org), + url.PathEscape(repo), + ) } else { - releaseURL = fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/tags/%s", org, repo, release) + releasePath = fmt.Sprintf( + "/repos/%s/%s/releases/tags/%s", + url.PathEscape(org), + url.PathEscape(repo), + url.PathEscape(release), + ) } - req, err := http.NewRequest("GET", releaseURL, nil) + releaseURL := (&url.URL{ + Scheme: "https", + Host: "api.github.com", + Path: releasePath, + }).String() + + req, err := http.NewRequest(http.MethodGet, releaseURL, nil) if err != nil { return nil, err } @@ -100,9 +140,16 @@ func downloadGithubRelease(org, repo, release, file, token string) (io.ReadClose resp, err := devpodhttp.GetHTTPClient().Do(req) if err != nil { return nil, err - } else if resp.StatusCode >= 400 { - _ = resp.Body.Close() - return nil, fmt.Errorf("received status code %d when trying to reach %s", resp.StatusCode, releaseURL) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, &HTTPStatusError{ + StatusCode: resp.StatusCode, + URL: releaseURL, + Body: string(body), + } } raw, err := io.ReadAll(resp.Body) @@ -127,7 +174,14 @@ func downloadGithubRelease(org, repo, release, file, token string) (io.ReadClose return nil, fmt.Errorf("couldn't find asset %s in github release (%s)", file, releaseURL) } - req, err = http.NewRequest("GET", fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/assets/%d", org, repo, releaseAsset.ID), nil) + assetPath := fmt.Sprintf("/repos/%s/%s/releases/assets/%d", url.PathEscape(org), url.PathEscape(repo), releaseAsset.ID) + assetURL := (&url.URL{ + Scheme: "https", + Host: "api.github.com", + Path: assetPath, + }).String() + + req, err = http.NewRequest(http.MethodGet, assetURL, nil) if err != nil { return nil, err } @@ -137,8 +191,13 @@ func downloadGithubRelease(org, repo, release, file, token string) (io.ReadClose if err != nil { return nil, err } else if downloadResp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(downloadResp.Body, 1024)) _ = downloadResp.Body.Close() - return nil, fmt.Errorf("received status code %d when trying to reach %s", downloadResp.StatusCode, releaseURL) + return nil, &HTTPStatusError{ + StatusCode: downloadResp.StatusCode, + URL: assetURL, + Body: string(body), + } } return downloadResp.Body, nil diff --git a/pkg/driver/custom/custom.go b/pkg/driver/custom/custom.go index af26929f5..0aada37d1 100644 --- a/pkg/driver/custom/custom.go +++ b/pkg/driver/custom/custom.go @@ -11,17 +11,16 @@ import ( "github.com/sirupsen/logrus" "github.com/skevetter/devpod/pkg/agent" - "github.com/skevetter/devpod/pkg/binaries" "github.com/skevetter/devpod/pkg/client/clientimplementation" "github.com/skevetter/devpod/pkg/devcontainer/config" "github.com/skevetter/devpod/pkg/driver" - provider2 "github.com/skevetter/devpod/pkg/provider" + "github.com/skevetter/devpod/pkg/provider" "github.com/skevetter/devpod/pkg/types" "github.com/skevetter/log" "github.com/skevetter/log/scanner" ) -func NewCustomDriver(workspaceInfo *provider2.AgentWorkspaceInfo, log log.Logger) driver.Driver { +func NewCustomDriver(workspaceInfo *provider.AgentWorkspaceInfo, log log.Logger) driver.Driver { return &customDriver{ log: log, workspaceInfo: workspaceInfo, @@ -33,7 +32,7 @@ var _ driver.Driver = (*customDriver)(nil) type customDriver struct { log log.Logger - workspaceInfo *provider2.AgentWorkspaceInfo + workspaceInfo *provider.AgentWorkspaceInfo } // FindDevContainer returns a running devcontainer details @@ -292,7 +291,7 @@ func (c *customDriver) runCommand( if err != nil { return err } - environ = append(environ, provider2.DEVCONTAINER_ID+"="+workspaceId) + environ = append(environ, provider.DEVCONTAINER_ID+"="+workspaceId) environ = append(environ, extraEnv...) // set debug level @@ -311,7 +310,7 @@ func (c *customDriver) runCommand( }) } -func ToEnvironWithBinaries(workspace *provider2.AgentWorkspaceInfo, log log.Logger) ([]string, error) { +func ToEnvironWithBinaries(workspace *provider.AgentWorkspaceInfo, log log.Logger) ([]string, error) { // get binaries dir binariesDir, err := agent.GetAgentBinariesDirFromWorkspaceDir(workspace.Origin) if err != nil { @@ -322,13 +321,13 @@ func ToEnvironWithBinaries(workspace *provider2.AgentWorkspaceInfo, log log.Logg } // download binaries - agentBinaries, err := binaries.DownloadBinaries(workspace.Agent.Binaries, binariesDir, log) + agentBinaries, err := provider.DownloadBinaries(workspace.Agent.Binaries, binariesDir, log) if err != nil { return nil, fmt.Errorf("error downloading workspace %s binaries: %w", workspace.Workspace.ID, err) } // get environ - environ := provider2.ToEnvironment(workspace.Workspace, workspace.Machine, workspace.Options, nil) + environ := provider.ToEnvironment(workspace.Workspace, workspace.Machine, workspace.Options, nil) for k, v := range agentBinaries { environ = append(environ, k+"="+v) } diff --git a/pkg/options/resolve.go b/pkg/options/resolve.go index 83b649fe1..687e825ce 100644 --- a/pkg/options/resolve.go +++ b/pkg/options/resolve.go @@ -2,17 +2,17 @@ package options import ( "context" + "fmt" "maps" "os" "reflect" "strings" "github.com/skevetter/devpod/pkg/agent" - "github.com/skevetter/devpod/pkg/binaries" "github.com/skevetter/devpod/pkg/config" "github.com/skevetter/devpod/pkg/options/resolver" - provider2 "github.com/skevetter/devpod/pkg/provider" + "github.com/skevetter/devpod/pkg/provider" "github.com/skevetter/devpod/pkg/types" "github.com/skevetter/log" ) @@ -20,13 +20,16 @@ import ( func ResolveAndSaveOptionsMachine( ctx context.Context, devConfig *config.Config, - provider *provider2.ProviderConfig, - originalMachine *provider2.Machine, + providerConfig *provider.ProviderConfig, + originalMachine *provider.Machine, userOptions map[string]string, log log.Logger, -) (*provider2.Machine, error) { +) (*provider.Machine, error) { + if originalMachine == nil { + return nil, fmt.Errorf("originalMachine cannot be nil") + } // reload config - machine, err := provider2.LoadMachineConfig(originalMachine.Context, originalMachine.ID) + machine, err := provider.LoadMachineConfig(originalMachine.Context, originalMachine.ID) if err != nil { return originalMachine, err } @@ -38,7 +41,7 @@ func ResolveAndSaveOptionsMachine( } // get binary paths - binaryPaths, err := binaries.GetBinaries(devConfig.DefaultContext, provider) + binaryPaths, err := provider.GetBinaries(devConfig.DefaultContext, providerConfig) if err != nil { return nil, err } @@ -46,28 +49,34 @@ func ResolveAndSaveOptionsMachine( // resolve options resolvedOptions, _, err := resolver.New( userOptions, - provider2.Merge(provider2.ToOptionsMachine(machine), binaryPaths), + provider.Merge(provider.ToOptionsMachine(machine), binaryPaths), log, resolver.WithResolveLocal(), ).Resolve( ctx, - devConfig.DynamicProviderOptionDefinitions(provider.Name), - provider.Options, - provider2.CombineOptions(nil, machine, devConfig.ProviderOptions(provider.Name)), + devConfig.DynamicProviderOptionDefinitions(providerConfig.Name), + providerConfig.Options, + provider.CombineOptions(nil, machine, devConfig.ProviderOptions(providerConfig.Name)), ) if err != nil { return nil, err } // remove global options - filterResolvedOptions(resolvedOptions, beforeConfigOptions, devConfig.ProviderOptions(provider.Name), provider.Options, userOptions) + filterResolvedOptions( + resolvedOptions, + beforeConfigOptions, + devConfig.ProviderOptions(providerConfig.Name), + providerConfig.Options, + userOptions, + ) // save machine config if machine != nil { machine.Provider.Options = resolvedOptions if !reflect.DeepEqual(beforeConfigOptions, machine.Provider.Options) { - err = provider2.SaveMachineConfig(machine) + err = provider.SaveMachineConfig(machine) if err != nil { return machine, err } @@ -80,26 +89,29 @@ func ResolveAndSaveOptionsMachine( func ResolveAndSaveOptionsWorkspace( ctx context.Context, devConfig *config.Config, - provider *provider2.ProviderConfig, - originalWorkspace *provider2.Workspace, + providerConfig *provider.ProviderConfig, + originalWorkspace *provider.Workspace, userOptions map[string]string, log log.Logger, options ...resolver.Option, -) (*provider2.Workspace, error) { +) (*provider.Workspace, error) { + if originalWorkspace == nil { + return nil, fmt.Errorf("originalWorkspace cannot be nil") + } // reload config - workspace, err := provider2.LoadWorkspaceConfig(originalWorkspace.Context, originalWorkspace.ID) + workspace, err := provider.LoadWorkspaceConfig(originalWorkspace.Context, originalWorkspace.ID) if err != nil { return originalWorkspace, err } + if workspace == nil { + return nil, fmt.Errorf("failed to load workspace config: workspace not found") + } // resolve devconfig options - var beforeConfigOptions map[string]config.OptionValue - if workspace != nil { - beforeConfigOptions = workspace.Provider.Options - } + beforeConfigOptions := workspace.Provider.Options // get binary paths - binaryPaths, err := binaries.GetBinaries(devConfig.DefaultContext, provider) + binaryPaths, err := provider.GetBinaries(devConfig.DefaultContext, providerConfig) if err != nil { return nil, err } @@ -108,52 +120,44 @@ func ResolveAndSaveOptionsWorkspace( // resolve options resolvedOptions, _, err := resolver.New( userOptions, - provider2.Merge(provider2.ToOptionsWorkspace(workspace), binaryPaths), + provider.Merge(provider.ToOptionsWorkspace(workspace), binaryPaths), log, options..., ).Resolve( ctx, - devConfig.DynamicProviderOptionDefinitions(provider.Name), - provider.Options, - provider2.CombineOptions(workspace, nil, devConfig.ProviderOptions(provider.Name)), + devConfig.DynamicProviderOptionDefinitions(providerConfig.Name), + providerConfig.Options, + provider.CombineOptions(workspace, nil, devConfig.ProviderOptions(providerConfig.Name)), ) if err != nil { return nil, err } // remove global options - filterResolvedOptions(resolvedOptions, beforeConfigOptions, devConfig.ProviderOptions(provider.Name), provider.Options, userOptions) + filterResolvedOptions( + resolvedOptions, + beforeConfigOptions, + devConfig.ProviderOptions(providerConfig.Name), + providerConfig.Options, + userOptions, + ) // save workspace config - if workspace != nil { - workspace.Provider.Options = resolvedOptions - - if !reflect.DeepEqual(beforeConfigOptions, workspace.Provider.Options) { - err = provider2.SaveWorkspaceConfig(workspace) - if err != nil { - return workspace, err - } + workspace.Provider.Options = resolvedOptions + if !reflect.DeepEqual(beforeConfigOptions, workspace.Provider.Options) { + err = provider.SaveWorkspaceConfig(workspace) + if err != nil { + return workspace, err } } return workspace, nil } -func ResolveAndSaveOptionsProxy( - ctx context.Context, - devConfig *config.Config, - provider *provider2.ProviderConfig, - originalWorkspace *provider2.Workspace, - userOptions map[string]string, - log log.Logger, -) (*provider2.Workspace, error) { - return ResolveAndSaveOptionsWorkspace(ctx, devConfig, provider, originalWorkspace, userOptions, log, resolver.WithResolveSubOptions()) -} - func ResolveOptions( ctx context.Context, devConfig *config.Config, - provider *provider2.ProviderConfig, + providerConfig *provider.ProviderConfig, userOptions map[string]string, skipRequired bool, skipSubOptions bool, @@ -161,7 +165,7 @@ func ResolveOptions( log log.Logger, ) (*config.Config, error) { // get binary paths - binaryPaths, err := binaries.GetBinaries(devConfig.DefaultContext, provider) + binaryPaths, err := provider.GetBinaries(devConfig.DefaultContext, providerConfig) if err != nil { return nil, err } @@ -177,7 +181,7 @@ func ResolveOptions( // create new resolver resolve := resolver.New( userOptions, - provider2.Merge(provider2.GetBaseEnvironment(devConfig.DefaultContext, provider.Name), binaryPaths), + provider.Merge(provider.GetBaseEnvironment(devConfig.DefaultContext, providerConfig.Name), binaryPaths), log, resolverOpts..., ) @@ -186,8 +190,8 @@ func ResolveOptions( resolvedOptionValues, dynamicOptionDefinitions, err := resolve.Resolve( ctx, nil, - provider.Options, - devConfig.ProviderOptions(provider.Name), + providerConfig.Options, + devConfig.ProviderOptions(providerConfig.Name), ) if err != nil { return nil, err @@ -199,66 +203,120 @@ func ResolveOptions( if devConfig.Current().Providers == nil { devConfig.Current().Providers = map[string]*config.ProviderConfig{} } - if devConfig.Current().Providers[provider.Name] == nil { - devConfig.Current().Providers[provider.Name] = &config.ProviderConfig{} + if devConfig.Current().Providers[providerConfig.Name] == nil { + devConfig.Current().Providers[providerConfig.Name] = &config.ProviderConfig{} } - devConfig.Current().Providers[provider.Name].Options = map[string]config.OptionValue{} - maps.Copy(devConfig.Current().Providers[provider.Name].Options, resolvedOptionValues) - devConfig.Current().Providers[provider.Name].DynamicOptions = config.OptionDefinitions{} - maps.Copy(devConfig.Current().Providers[provider.Name].DynamicOptions, dynamicOptionDefinitions) + providerCfg := devConfig.Current().Providers[providerConfig.Name] + providerCfg.Options = map[string]config.OptionValue{} + maps.Copy(providerCfg.Options, resolvedOptionValues) + + providerCfg.DynamicOptions = config.OptionDefinitions{} + maps.Copy(providerCfg.DynamicOptions, dynamicOptionDefinitions) if singleMachine != nil { - devConfig.Current().Providers[provider.Name].SingleMachine = *singleMachine + providerCfg.SingleMachine = *singleMachine } } return devConfig, nil } -func ResolveAgentConfig(devConfig *config.Config, provider *provider2.ProviderConfig, workspace *provider2.Workspace, machine *provider2.Machine) provider2.ProviderAgentConfig { - // fill in agent config - options := provider2.ToOptions(workspace, machine, devConfig.ProviderOptions(provider.Name)) - agentConfig := provider.Agent +// ResolveAgentConfig resolves and returns the complete agent configuration for a provider. +// It merges configuration from the provider, workspace, machine, and devConfig, resolving +// all dynamic values and setting appropriate defaults for agent paths, Docker settings, +// Kubernetes settings, and credentials. +// +// Parameters: +// - devConfig: The DevPod configuration containing global settings +// - providerConfig: The provider's configuration +// - workspace: The workspace configuration (can be nil for machine-only operations) +// - machine: The machine configuration (can be nil for workspace-only operations) +// +// Returns a fully resolved ProviderAgentConfig ready for use by the agent. +func ResolveAgentConfig( + devConfig *config.Config, + providerConfig *provider.ProviderConfig, + workspace *provider.Workspace, + machine *provider.Machine, +) provider.ProviderAgentConfig { + if providerConfig == nil || devConfig == nil { + return provider.ProviderAgentConfig{} + } + options := provider.ToOptions(workspace, machine, devConfig.ProviderOptions(providerConfig.Name)) + agentConfig := providerConfig.Agent + + resolveAgentBaseConfig(&agentConfig, options, devConfig) + resolveAgentDockerConfig(&agentConfig, options) + resolveAgentKubernetesConfig(&agentConfig, options) + resolveAgentPathAndURL(&agentConfig, options, devConfig) + resolveAgentCredentials(&agentConfig, options, devConfig) + + return agentConfig +} + +func resolveAgentBaseConfig( + agentConfig *provider.ProviderAgentConfig, + options map[string]string, + devConfig *config.Config, +) { agentConfig.Dockerless.Image = resolver.ResolveDefaultValue(agentConfig.Dockerless.Image, options) - agentConfig.Dockerless.Disabled = types.StrBool(resolver.ResolveDefaultValue(string(agentConfig.Dockerless.Disabled), options)) + agentConfig.Dockerless.Disabled = types.StrBool( + resolver.ResolveDefaultValue(string(agentConfig.Dockerless.Disabled), options), + ) agentConfig.Dockerless.IgnorePaths = resolver.ResolveDefaultValue(agentConfig.Dockerless.IgnorePaths, options) agentConfig.Dockerless.RegistryCache = devConfig.ContextOption(config.ContextOptionRegistryCache) agentConfig.Driver = resolver.ResolveDefaultValue(agentConfig.Driver, options) agentConfig.Local = types.StrBool(resolver.ResolveDefaultValue(string(agentConfig.Local), options)) +} - // docker driver +func resolveAgentDockerConfig(agentConfig *provider.ProviderAgentConfig, options map[string]string) { agentConfig.Docker.Path = resolver.ResolveDefaultValue(agentConfig.Docker.Path, options) agentConfig.Docker.Builder = resolver.ResolveDefaultValue(agentConfig.Docker.Builder, options) - agentConfig.Docker.Install = types.StrBool(resolver.ResolveDefaultValue(string(agentConfig.Docker.Install), options)) + agentConfig.Docker.Install = types.StrBool( + resolver.ResolveDefaultValue(string(agentConfig.Docker.Install), options), + ) agentConfig.Docker.Env = resolver.ResolveDefaultValues(agentConfig.Docker.Env, options) +} - // kubernetes driver - agentConfig.Kubernetes.KubernetesContext = resolver.ResolveDefaultValue(agentConfig.Kubernetes.KubernetesContext, options) - agentConfig.Kubernetes.KubernetesConfig = resolver.ResolveDefaultValue(agentConfig.Kubernetes.KubernetesConfig, options) - agentConfig.Kubernetes.KubernetesNamespace = resolver.ResolveDefaultValue(agentConfig.Kubernetes.KubernetesNamespace, options) - agentConfig.Kubernetes.Architecture = resolver.ResolveDefaultValue(agentConfig.Kubernetes.Architecture, options) - agentConfig.Kubernetes.InactivityTimeout = resolver.ResolveDefaultValue(agentConfig.Kubernetes.InactivityTimeout, options) - agentConfig.Kubernetes.StorageClass = resolver.ResolveDefaultValue(agentConfig.Kubernetes.StorageClass, options) - agentConfig.Kubernetes.PvcAccessMode = resolver.ResolveDefaultValue(agentConfig.Kubernetes.PvcAccessMode, options) - agentConfig.Kubernetes.PvcAnnotations = resolver.ResolveDefaultValue(agentConfig.Kubernetes.PvcAnnotations, options) - agentConfig.Kubernetes.NodeSelector = resolver.ResolveDefaultValue(agentConfig.Kubernetes.NodeSelector, options) - agentConfig.Kubernetes.Resources = resolver.ResolveDefaultValue(agentConfig.Kubernetes.Resources, options) - agentConfig.Kubernetes.WorkspaceVolumeMount = resolver.ResolveDefaultValue(agentConfig.Kubernetes.WorkspaceVolumeMount, options) - agentConfig.Kubernetes.PodManifestTemplate = resolver.ResolveDefaultValue(agentConfig.Kubernetes.PodManifestTemplate, options) - agentConfig.Kubernetes.Labels = resolver.ResolveDefaultValue(agentConfig.Kubernetes.Labels, options) - agentConfig.Kubernetes.StrictSecurity = resolver.ResolveDefaultValue(agentConfig.Kubernetes.StrictSecurity, options) - agentConfig.Kubernetes.CreateNamespace = resolver.ResolveDefaultValue(agentConfig.Kubernetes.CreateNamespace, options) - agentConfig.Kubernetes.ClusterRole = resolver.ResolveDefaultValue(agentConfig.Kubernetes.ClusterRole, options) - agentConfig.Kubernetes.ServiceAccount = resolver.ResolveDefaultValue(agentConfig.Kubernetes.ServiceAccount, options) - agentConfig.Kubernetes.PodTimeout = resolver.ResolveDefaultValue(agentConfig.Kubernetes.PodTimeout, options) - agentConfig.Kubernetes.KubernetesPullSecretsEnabled = resolver.ResolveDefaultValue(agentConfig.Kubernetes.KubernetesPullSecretsEnabled, options) - agentConfig.Kubernetes.DiskSize = resolver.ResolveDefaultValue(agentConfig.Kubernetes.DiskSize, options) +func resolveAgentKubernetesConfig(agentConfig *provider.ProviderAgentConfig, options map[string]string) { + k8s := &agentConfig.Kubernetes + k8s.KubernetesContext = resolver.ResolveDefaultValue(k8s.KubernetesContext, options) + k8s.KubernetesConfig = resolver.ResolveDefaultValue(k8s.KubernetesConfig, options) + k8s.KubernetesNamespace = resolver.ResolveDefaultValue(k8s.KubernetesNamespace, options) + k8s.Architecture = resolver.ResolveDefaultValue(k8s.Architecture, options) + k8s.InactivityTimeout = resolver.ResolveDefaultValue(k8s.InactivityTimeout, options) + k8s.StorageClass = resolver.ResolveDefaultValue(k8s.StorageClass, options) + k8s.PvcAccessMode = resolver.ResolveDefaultValue(k8s.PvcAccessMode, options) + k8s.PvcAnnotations = resolver.ResolveDefaultValue(k8s.PvcAnnotations, options) + k8s.NodeSelector = resolver.ResolveDefaultValue(k8s.NodeSelector, options) + k8s.Resources = resolver.ResolveDefaultValue(k8s.Resources, options) + k8s.WorkspaceVolumeMount = resolver.ResolveDefaultValue(k8s.WorkspaceVolumeMount, options) + k8s.PodManifestTemplate = resolver.ResolveDefaultValue(k8s.PodManifestTemplate, options) + k8s.Labels = resolver.ResolveDefaultValue(k8s.Labels, options) + k8s.StrictSecurity = resolver.ResolveDefaultValue(k8s.StrictSecurity, options) + k8s.CreateNamespace = resolver.ResolveDefaultValue(k8s.CreateNamespace, options) + k8s.ClusterRole = resolver.ResolveDefaultValue(k8s.ClusterRole, options) + k8s.ServiceAccount = resolver.ResolveDefaultValue(k8s.ServiceAccount, options) + k8s.PodTimeout = resolver.ResolveDefaultValue(k8s.PodTimeout, options) + k8s.KubernetesPullSecretsEnabled = resolver.ResolveDefaultValue(k8s.KubernetesPullSecretsEnabled, options) + k8s.DiskSize = resolver.ResolveDefaultValue(k8s.DiskSize, options) +} +func resolveAgentPathAndURL( + agentConfig *provider.ProviderAgentConfig, + options map[string]string, + devConfig *config.Config, +) { agentConfig.DataPath = resolver.ResolveDefaultValue(agentConfig.DataPath, options) agentConfig.Path = resolver.ResolveDefaultValue(agentConfig.Path, options) - if agentConfig.Path == "" && agentConfig.Local == "true" { - agentConfig.Path, _ = os.Executable() - } else if agentConfig.Path == "" { + if agentConfig.Path == "" && strings.EqualFold(string(agentConfig.Local), "true") { + // Try to use the current executable path for local agent + // Error is silently handled as we have a fallback to RemoteDevPodHelperLocation + if execPath, err := os.Executable(); err == nil { + agentConfig.Path = execPath + } + } + if agentConfig.Path == "" { agentConfig.Path = agent.RemoteDevPodHelperLocation } agentConfig.DownloadURL = resolver.ResolveDefaultValue(agentConfig.DownloadURL, options) @@ -267,18 +325,30 @@ func ResolveAgentConfig(devConfig *config.Config, provider *provider2.ProviderCo } agentConfig.Timeout = resolver.ResolveDefaultValue(agentConfig.Timeout, options) agentConfig.ContainerTimeout = resolver.ResolveDefaultValue(agentConfig.ContainerTimeout, options) - agentConfig.InjectGitCredentials = types.StrBool(resolver.ResolveDefaultValue(string(agentConfig.InjectGitCredentials), options)) +} + +func resolveAgentCredentials( + agentConfig *provider.ProviderAgentConfig, + options map[string]string, + devConfig *config.Config, +) { + agentConfig.InjectGitCredentials = types.StrBool( + resolver.ResolveDefaultValue(string(agentConfig.InjectGitCredentials), options), + ) if devConfig.ContextOption(config.ContextOptionSSHInjectGitCredentials) != "" { - agentConfig.InjectGitCredentials = types.StrBool(devConfig.ContextOption(config.ContextOptionSSHInjectGitCredentials)) + agentConfig.InjectGitCredentials = types.StrBool( + devConfig.ContextOption(config.ContextOptionSSHInjectGitCredentials), + ) } - agentConfig.InjectDockerCredentials = types.StrBool(resolver.ResolveDefaultValue(string(agentConfig.InjectDockerCredentials), options)) + agentConfig.InjectDockerCredentials = types.StrBool( + resolver.ResolveDefaultValue(string(agentConfig.InjectDockerCredentials), options), + ) if dockerCredOpt := devConfig.ContextOption(config.ContextOptionSSHInjectDockerCredentials); dockerCredOpt != "" { agentConfig.InjectDockerCredentials = types.StrBool(dockerCredOpt) } - return agentConfig } -// resolveAgentDownloadURL resolves the agent download URL (env -> context -> default) +// resolveAgentDownloadURL resolves the agent download URL (env -> context -> default). func resolveAgentDownloadURL(devConfig *config.Config) string { devPodAgentURL := os.Getenv(agent.EnvDevPodAgentURL) if devPodAgentURL != "" { diff --git a/pkg/binaries/download.go b/pkg/provider/download.go similarity index 62% rename from pkg/binaries/download.go rename to pkg/provider/download.go index 422ff8f41..65d3a055c 100644 --- a/pkg/binaries/download.go +++ b/pkg/provider/download.go @@ -1,8 +1,11 @@ -package binaries +package provider import ( + "errors" "fmt" "io" + "net" + "net/http" "os" "path" "path/filepath" @@ -13,13 +16,12 @@ import ( "github.com/skevetter/devpod/pkg/copy" "github.com/skevetter/devpod/pkg/download" "github.com/skevetter/devpod/pkg/extract" - provider2 "github.com/skevetter/devpod/pkg/provider" "github.com/skevetter/log" "github.com/skevetter/log/hash" + "k8s.io/client-go/util/retry" ) const ( - retryCount = 3 dirPerms = 0750 filePerms = 0755 windowsOS = "windows" @@ -33,18 +35,23 @@ const ( cacheDir = "devpod-binaries" ) +var ( + downloadBackoff = retry.DefaultBackoff + errChecksumVerificationFailed = errors.New("checksum verification failed") +) + type EnvironmentOptions struct { Context string - Workspace *provider2.Workspace - Machine *provider2.Machine + Workspace *Workspace + Machine *Machine Options map[string]config.OptionValue - Config *provider2.ProviderConfig + Config *ProviderConfig ExtraEnv map[string]string Log log.Logger } func ToEnvironmentWithBinaries(opts EnvironmentOptions) ([]string, error) { - environ := provider2.ToEnvironment(opts.Workspace, opts.Machine, opts.Options, opts.ExtraEnv) + environ := ToEnvironment(opts.Workspace, opts.Machine, opts.Options, opts.ExtraEnv) binariesMap, err := GetBinaries(opts.Context, opts.Config) if err != nil { return nil, err @@ -56,7 +63,7 @@ func ToEnvironmentWithBinaries(opts EnvironmentOptions) ([]string, error) { return environ, nil } -func GetBinariesFrom(config *provider2.ProviderConfig, binariesDir string) (map[string]string, error) { +func GetBinariesFrom(config *ProviderConfig, binariesDir string) (map[string]string, error) { retBinaries := map[string]string{} for binaryName, binaryLocations := range config.Binaries { found := false @@ -85,8 +92,8 @@ func GetBinariesFrom(config *provider2.ProviderConfig, binariesDir string) (map[ return retBinaries, nil } -func GetBinaries(context string, config *provider2.ProviderConfig) (map[string]string, error) { - binariesDir, err := provider2.GetProviderBinariesDir(context, config.Name) +func GetBinaries(context string, config *ProviderConfig) (map[string]string, error) { + binariesDir, err := GetProviderBinariesDir(context, config.Name) if err != nil { return nil, err } @@ -95,7 +102,7 @@ func GetBinaries(context string, config *provider2.ProviderConfig) (map[string]s } func DownloadBinaries( - binaries map[string][]*provider2.ProviderBinary, + binaries map[string][]*ProviderBinary, targetFolder string, log log.Logger, ) (map[string]string, error) { @@ -113,7 +120,7 @@ func DownloadBinaries( func downloadBinaryForPlatform( binaryName string, - binaryLocations []*provider2.ProviderBinary, + binaryLocations []*ProviderBinary, targetFolder string, log log.Logger, ) (string, error) { @@ -125,7 +132,8 @@ func downloadBinaryForPlatform( // check if binary is correct binaryTargetFolder := filepath.Join(targetFolder, strings.ToLower(binaryName)) binaryPath := getBinaryPath(binary, binaryTargetFolder) - if verifyBinary(binaryPath, binary.Checksum) || fromCache(binary, binaryTargetFolder, log) { + if verifyOrRemoveBinary(binaryPath, binary.Checksum) || + fromCache(binary, binaryTargetFolder, log) { return binaryPath, nil } @@ -143,34 +151,80 @@ func downloadBinaryForPlatform( func downloadWithRetry( binaryName string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, targetFolder string, log log.Logger, ) (string, error) { - var lastErr error - for range retryCount { - binaryPath, err := downloadBinary(binaryName, binary, targetFolder, log) + var binaryPath string + err := retry.OnError(downloadBackoff, isRetriableError, func() error { + path, err := downloadBinary(binaryName, binary, targetFolder, log) if err != nil { - lastErr = err - continue + return err } if binary.Checksum != "" { - if !verifyDownloadedBinary(binaryPath, binary, binaryName, log) { - lastErr = fmt.Errorf("checksum verification failed") - continue + if !verifyDownloadedBinary(path, binary, binaryName, log) { + return errChecksumVerificationFailed } } - toCache(binary, binaryPath, log) - return binaryPath, nil + binaryPath = path + return nil + }) + if err != nil { + return "", fmt.Errorf("failed to download binary %s: %w", binaryName, err) + } + + toCache(binary, binaryPath, log) + return binaryPath, nil +} + +func isRetriableError(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, errChecksumVerificationFailed) { + return false + } + + var httpErr *download.HTTPStatusError + if errors.As(err, &httpErr) { + return isRetriableHTTPStatus(httpErr.StatusCode) + } + + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + if errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + return false +} + +// isRetriableHTTPStatus checks if the HTTP status code is retriable. +// 408 (Request Timeout) and 429 (Too Many Requests) are retriable. +// 5xx server errors are retriable, 4xx client errors are not. +func isRetriableHTTPStatus(statusCode int) bool { + switch statusCode { + case http.StatusRequestTimeout, http.StatusTooManyRequests: + return true + default: + return statusCode >= http.StatusInternalServerError } - return "", fmt.Errorf("failed to download binary %s after %d attempts: %w", binaryName, retryCount, lastErr) } func verifyDownloadedBinary( binaryPath string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, binaryName string, log log.Logger, ) bool { @@ -189,13 +243,14 @@ func verifyDownloadedBinary( return true } -func toCache(binary *provider2.ProviderBinary, binaryPath string, log log.Logger) { +func toCache(binary *ProviderBinary, binaryPath string, log log.Logger) { if !isRemotePath(binary.Path) { return } cachedBinaryPath := getCachedBinaryPath(binary.Path) if err := os.MkdirAll(filepath.Dir(cachedBinaryPath), dirPerms); err != nil { + log.Warnf("error creating cache directory: %v", err) return } @@ -204,14 +259,14 @@ func toCache(binary *provider2.ProviderBinary, binaryPath string, log log.Logger } } -func fromCache(binary *provider2.ProviderBinary, targetFolder string, log log.Logger) bool { +func fromCache(binary *ProviderBinary, targetFolder string, log log.Logger) bool { if !isRemotePath(binary.Path) { return false } binaryPath := getBinaryPath(binary, targetFolder) cachedBinaryPath := getCachedBinaryPath(binary.Path) - if !verifyBinary(cachedBinaryPath, binary.Checksum) { + if !verifyOrRemoveBinary(cachedBinaryPath, binary.Checksum) { return false } @@ -229,10 +284,15 @@ func fromCache(binary *provider2.ProviderBinary, targetFolder string, log log.Lo } func getCachedBinaryPath(url string) string { - return filepath.Join(os.TempDir(), cacheDir, hash.String(url)[:16]) + h := hash.String(url) + cacheBase := os.TempDir() + if userCache, err := os.UserCacheDir(); err == nil { + cacheBase = userCache + } + return filepath.Join(cacheBase, cacheDir, h) } -func verifyBinary(binaryPath, checksum string) bool { +func verifyOrRemoveBinary(binaryPath, checksum string) bool { _, err := os.Stat(binaryPath) if err != nil { return false @@ -250,7 +310,21 @@ func verifyBinary(binaryPath, checksum string) bool { return true } -func getBinaryPath(binary *provider2.ProviderBinary, targetFolder string) string { +// getBinaryFileName extracts or constructs the binary filename from the ProviderBinary. +// If Name is set, it uses that. Otherwise, it derives the filename from Path and adds +// .exe suffix on Windows if needed. +func getBinaryFileName(binary *ProviderBinary) string { + if binary.Name != "" { + return binary.Name + } + name := path.Base(binary.Path) + if runtime.GOOS == windowsOS && !strings.HasSuffix(name, exeSuffix) { + name += exeSuffix + } + return name +} + +func getBinaryPath(binary *ProviderBinary, targetFolder string) string { if filepath.IsAbs(binary.Path) { return binary.Path } @@ -260,17 +334,40 @@ func getBinaryPath(binary *provider2.ProviderBinary, targetFolder string) string } if binary.ArchivePath != "" { - return path.Join(filepath.ToSlash(targetFolder), binary.ArchivePath) + safePath, err := securePath(targetFolder, binary.ArchivePath) + if err != nil { + return filepath.Join(targetFolder, filepath.Base(binary.ArchivePath)) + } + return safePath } - name := binary.Name - if name == "" { - name = path.Base(binary.Path) - if runtime.GOOS == windowsOS && !strings.HasSuffix(name, exeSuffix) { - name += exeSuffix - } + name := getBinaryFileName(binary) + safePath, err := securePath(targetFolder, name) + if err != nil { + return filepath.Join(targetFolder, filepath.Base(name)) } - return path.Join(filepath.ToSlash(targetFolder), name) + return safePath +} + +// securePath ensures that the resolved path stays within the base directory. +// It protects against path traversal attacks using ../ sequences. +func securePath(baseDir, unsafePath string) (string, error) { + fullPath := filepath.Join(baseDir, filepath.Clean(unsafePath)) + absBase, err := filepath.Abs(baseDir) + if err != nil { + return "", fmt.Errorf("failed to resolve base directory: %w", err) + } + + absPath, err := filepath.Abs(fullPath) + if err != nil { + return "", fmt.Errorf("failed to resolve path: %w", err) + } + + if !strings.HasPrefix(absPath, absBase+string(filepath.Separator)) && absPath != absBase { + return "", fmt.Errorf("path %q escapes base directory", unsafePath) + } + + return absPath, nil } func isRemotePath(p string) bool { @@ -279,26 +376,25 @@ func isRemotePath(p string) bool { func downloadBinary( binaryName string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, targetFolder string, log log.Logger, ) (string, error) { - if _, err := os.Stat(binary.Path); err == nil { - return handleLocalBinary(binary, targetFolder) - } - - if !isRemotePath(binary.Path) { - return handleNonHTTPBinary(binary, targetFolder) + if isRemotePath(binary.Path) { + if err := os.MkdirAll(targetFolder, dirPerms); err != nil { + return "", fmt.Errorf("create folder: %w", err) + } + return downloadRemoteBinary(binaryName, binary, targetFolder, log) } - if err := os.MkdirAll(targetFolder, dirPerms); err != nil { - return "", fmt.Errorf("create folder: %w", err) + if _, err := os.Stat(binary.Path); err == nil { + return handleLocalBinary(binary, targetFolder) } - return downloadRemoteBinary(binaryName, binary, targetFolder, log) + return handleNonHTTPBinary(binary, targetFolder) } -func handleLocalBinary(binary *provider2.ProviderBinary, targetFolder string) (string, error) { +func handleLocalBinary(binary *ProviderBinary, targetFolder string) (string, error) { if filepath.IsAbs(binary.Path) { return binary.Path, nil } @@ -316,7 +412,7 @@ func handleLocalBinary(binary *provider2.ProviderBinary, targetFolder string) (s return targetPath, nil } -func handleNonHTTPBinary(binary *provider2.ProviderBinary, targetFolder string) (string, error) { +func handleNonHTTPBinary(binary *ProviderBinary, targetFolder string) (string, error) { targetPath := localTargetPath(binary, targetFolder) if _, err := os.Stat(targetPath); err == nil { return targetPath, nil @@ -326,7 +422,7 @@ func handleNonHTTPBinary(binary *provider2.ProviderBinary, targetFolder string) func downloadRemoteBinary( binaryName string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, targetFolder string, log log.Logger, ) (string, error) { @@ -340,7 +436,9 @@ func downloadRemoteBinary( } if err != nil { - _ = os.Remove(targetPath) + if targetPath != "" { + _ = os.Remove(targetPath) + } return "", err } @@ -353,29 +451,23 @@ func downloadRemoteBinary( func downloadFile( binaryName string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, targetFolder string, log log.Logger, ) (string, error) { - name := binary.Name - if name == "" { - name = path.Base(binary.Path) - if runtime.GOOS == windowsOS && !strings.HasSuffix(name, exeSuffix) { - name += exeSuffix - } - } - targetPath := path.Join(filepath.ToSlash(targetFolder), name) - _, err := os.Stat(targetPath) - if err == nil { - return targetPath, nil - } + name := getBinaryFileName(binary) + targetPath := filepath.Join(targetFolder, name) + + // Remove any existing file to ensure clean download + // (could be partial download from previous failed attempt) + _ = os.Remove(targetPath) return downloadAndSaveFile(binaryName, binary, targetPath, log) } func downloadAndSaveFile( binaryName string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, targetPath string, log log.Logger, ) (string, error) { @@ -404,16 +496,19 @@ func downloadAndSaveFile( func downloadArchive( binaryName string, - binary *provider2.ProviderBinary, + binary *ProviderBinary, targetFolder string, log log.Logger, ) (string, error) { - targetPath := path.Join(filepath.ToSlash(targetFolder), binary.ArchivePath) - _, err := os.Stat(targetPath) - if err == nil { - return targetPath, nil + targetPath, err := securePath(targetFolder, binary.ArchivePath) + if err != nil { + return "", fmt.Errorf("invalid archive path %q: %w", binary.ArchivePath, err) } + // Remove any existing file to ensure clean download + // (could be partial download from previous failed attempt) + _ = os.Remove(targetPath) + return extractArchive(archiveDownloadParams{ binaryName: binaryName, binary: binary, @@ -425,7 +520,7 @@ func downloadArchive( type archiveDownloadParams struct { binaryName string - binary *provider2.ProviderBinary + binary *ProviderBinary targetFolder string targetPath string log log.Logger @@ -485,7 +580,7 @@ func extractZipArchive(body io.ReadCloser, targetFolder, targetPath string) (str } func downloadToTempFile(reader io.Reader) (string, error) { - tempFile, err := os.CreateTemp("", "") + tempFile, err := os.CreateTemp("", "devpod-archive-*") if err != nil { return "", err } @@ -499,7 +594,9 @@ func downloadToTempFile(reader io.Reader) (string, error) { return tempFile.Name(), nil } -func localTargetPath(binary *provider2.ProviderBinary, targetFolder string) string { +// localTargetPath constructs the target path for a local binary. +// Note: Does not add .exe suffix as local paths already have correct extensions. +func localTargetPath(binary *ProviderBinary, targetFolder string) string { name := binary.Name if name == "" { name = path.Base(binary.Path) @@ -507,14 +604,15 @@ func localTargetPath(binary *provider2.ProviderBinary, targetFolder string) stri return filepath.Join(targetFolder, name) } -func copyLocal(binary *provider2.ProviderBinary, targetPath string) error { +func copyLocal(binary *ProviderBinary, targetPath string) error { targetPathStat, err := os.Stat(targetPath) if err == nil { binaryStat, err := os.Stat(binary.Path) if err != nil { return err } - if targetPathStat.Size() == binaryStat.Size() { + if targetPathStat.Size() == binaryStat.Size() && + !binaryStat.ModTime().After(targetPathStat.ModTime()) { return nil } } diff --git a/pkg/workspace/provider.go b/pkg/workspace/provider.go index 903b0acbd..fceb1b0bc 100644 --- a/pkg/workspace/provider.go +++ b/pkg/workspace/provider.go @@ -11,11 +11,10 @@ import ( "strings" devpodhttp "github.com/skevetter/devpod/pkg/http" - providerpkg "github.com/skevetter/devpod/pkg/provider" + "github.com/skevetter/devpod/pkg/provider" "github.com/skevetter/devpod/pkg/types" "github.com/skevetter/devpod/providers" - "github.com/skevetter/devpod/pkg/binaries" "github.com/skevetter/devpod/pkg/config" "github.com/skevetter/devpod/pkg/download" "github.com/skevetter/log" @@ -36,15 +35,15 @@ var ( ) type ProviderWithOptions struct { - Config *providerpkg.ProviderConfig `json:"config,omitempty"` - State *config.ProviderConfig `json:"state,omitempty"` + Config *provider.ProviderConfig `json:"config,omitempty"` + State *config.ProviderConfig `json:"state,omitempty"` } type ProviderParams struct { DevPodConfig *config.Config ProviderName string Raw []byte - Source *providerpkg.ProviderSource + Source *provider.ProviderSource Log log.Logger } @@ -98,28 +97,28 @@ func ProviderFromHost( devPodConfig *config.Config, proHost string, log log.Logger, -) (*providerpkg.ProviderConfig, error) { - proInstanceConfig, err := providerpkg.LoadProInstanceConfig(devPodConfig.DefaultContext, proHost) +) (*provider.ProviderConfig, error) { + proInstanceConfig, err := provider.LoadProInstanceConfig(devPodConfig.DefaultContext, proHost) if err != nil { return nil, fmt.Errorf("load pro instance %s: %w", proHost, err) } - provider, err := FindProvider(devPodConfig, proInstanceConfig.Provider, log) + foundProvider, err := FindProvider(devPodConfig, proInstanceConfig.Provider, log) if err != nil { return nil, fmt.Errorf("find provider: %w", err) } - if !provider.Config.IsProxyProvider() && !provider.Config.IsDaemonProvider() { + if !foundProvider.Config.IsProxyProvider() && !foundProvider.Config.IsDaemonProvider() { return nil, fmt.Errorf("provider is not a pro provider") } - return provider.Config, nil + return foundProvider.Config, nil } func AddProvider( devPodConfig *config.Config, providerName, providerSourceRaw string, log log.Logger, -) (*providerpkg.ProviderConfig, error) { +) (*provider.ProviderConfig, error) { providerRaw, providerSource, err := ResolveProvider(providerSourceRaw, log) if err != nil { return nil, err @@ -134,7 +133,7 @@ func AddProvider( }) } -func AddProviderRaw(p ProviderParams) (*providerpkg.ProviderConfig, error) { +func AddProviderRaw(p ProviderParams) (*provider.ProviderConfig, error) { providerConfig, err := installRawProvider(p) if err != nil { return nil, err @@ -160,7 +159,7 @@ func UpdateProvider( devPodConfig *config.Config, providerName, providerSourceRaw string, log log.Logger, -) (*providerpkg.ProviderConfig, error) { +) (*provider.ProviderConfig, error) { if devPodConfig.Current().Providers[providerName] == nil { return nil, fmt.Errorf("provider %s not found", providerName) } @@ -227,36 +226,50 @@ func ResolveProviderSource(devPodConfig *config.Config, providerName string, log return source, nil } -func ResolveProvider(providerSource string, log log.Logger) ([]byte, *providerpkg.ProviderSource, error) { - retSource := &providerpkg.ProviderSource{Raw: strings.TrimSpace(providerSource)} +func ResolveProvider(providerSource string, log log.Logger) ([]byte, *provider.ProviderSource, error) { + retSource := &provider.ProviderSource{Raw: strings.TrimSpace(providerSource)} if out, ok := resolveInternalProvider(providerSource, retSource); ok { return out, retSource, nil } - if out, ok, err := resolveURLProvider(providerSource, retSource, log); ok { - if err != nil { - return nil, nil, err - } - return out, retSource, nil + if out, err := tryResolveURLProvider(providerSource, retSource, log); hasOutputOrError(out, err) { + return out, retSource, err } - if out, ok := resolveFileProvider(providerSource, retSource); ok { - return out, retSource, nil + if out, err := tryResolveFileProvider(providerSource, retSource); hasOutputOrError(out, err) { + return out, retSource, err } out, source, err := downloadProviderGithub(providerSource, log) - if err != nil { - return nil, nil, fmt.Errorf("download github: %w", err) - } - if len(out) > 0 { - return out, source, nil + if len(out) > 0 || err != nil { + return out, source, err } return nil, nil, fmt.Errorf("provider type not recognized: specify a local file, url, or github repository") } -func downloadProviderGithub(originalPath string, log log.Logger) ([]byte, *providerpkg.ProviderSource, error) { +func hasOutputOrError(out []byte, err error) bool { + return out != nil || err != nil +} + +func tryResolveURLProvider(providerSource string, retSource *provider.ProviderSource, log log.Logger) ([]byte, error) { + out, ok, err := resolveURLProvider(providerSource, retSource, log) + if !ok { + return nil, nil + } + return out, err +} + +func tryResolveFileProvider(providerSource string, retSource *provider.ProviderSource) ([]byte, error) { + out, ok, err := resolveFileProvider(providerSource, retSource) + if !ok { + return nil, nil + } + return out, err +} + +func downloadProviderGithub(originalPath string, log log.Logger) ([]byte, *provider.ProviderSource, error) { path := strings.TrimPrefix(originalPath, githubPrefix) release := "" @@ -270,7 +283,10 @@ func downloadProviderGithub(originalPath string, log log.Logger) ([]byte, *provi if len(splitted) == 1 { path = providerPrefix + path } else if len(splitted) != 2 { - return nil, nil, nil + return nil, nil, fmt.Errorf( + "invalid github path format: expected 'owner/repo' or 'provider-name', got %q", + originalPath, + ) } requestURL := buildGithubURL(path, release) @@ -286,7 +302,7 @@ func downloadProviderGithub(originalPath string, log log.Logger) ([]byte, *provi return nil, nil, err } - return out, &providerpkg.ProviderSource{ + return out, &provider.ProviderSource{ Raw: originalPath, Github: path, }, nil @@ -304,7 +320,7 @@ func loadConfiguredProviders( continue } - providerConfig, err := providerpkg.LoadProviderConfig(devPodConfig.DefaultContext, providerName) + providerConfig, err := provider.LoadProviderConfig(devPodConfig.DefaultContext, providerName) if err != nil { log.Warnf("error loading provider %s: %v", providerName, err) continue @@ -318,7 +334,7 @@ func loadConfiguredProviders( } func loadUnconfiguredProviders(devPodConfig *config.Config, retProviders map[string]*ProviderWithOptions) error { - providerDir, err := providerpkg.GetProvidersDir(devPodConfig.DefaultContext) + providerDir, err := provider.GetProvidersDir(devPodConfig.DefaultContext) if err != nil { return err } @@ -350,7 +366,7 @@ func loadProviderEntry( entry os.DirEntry, retProviders map[string]*ProviderWithOptions, ) error { - providerConfig, err := providerpkg.LoadProviderConfig(devPodConfig.DefaultContext, entry.Name()) + providerConfig, err := provider.LoadProviderConfig(devPodConfig.DefaultContext, entry.Name()) if err != nil { return err } @@ -362,8 +378,8 @@ func loadProviderEntry( return nil } -func installRawProvider(p ProviderParams) (*providerpkg.ProviderConfig, error) { - providerConfig, err := providerpkg.ParseProvider(bytes.NewReader(p.Raw)) +func installRawProvider(p ProviderParams) (*provider.ProviderConfig, error) { + providerConfig, err := provider.ParseProvider(bytes.NewReader(p.Raw)) if err != nil { return nil, err } @@ -377,8 +393,12 @@ func installRawProvider(p ProviderParams) (*providerpkg.ProviderConfig, error) { func installProvider( p ProviderParams, - providerConfig *providerpkg.ProviderConfig, -) (*providerpkg.ProviderConfig, error) { + providerConfig *provider.ProviderConfig, +) (*provider.ProviderConfig, error) { + if p.Source == nil { + return nil, fmt.Errorf("provider source is required") + } + providerConfig.Source = *p.Source if p.ProviderName != "" { providerConfig.Name = p.ProviderName @@ -395,33 +415,41 @@ func installProvider( return providerConfig, nil } -func updateProvider(p ProviderParams) (*providerpkg.ProviderConfig, error) { - providerConfig, err := providerpkg.ParseProvider(bytes.NewReader(p.Raw)) +func updateProvider(p ProviderParams) (*provider.ProviderConfig, error) { + providerConfig, err := parseAndValidateProvider(p) if err != nil { return nil, err } - providerConfig.Source = *p.Source - if p.ProviderName != "" { - providerConfig.Name = p.ProviderName - } - if providerConfig.Options == nil { - providerConfig.Options = map[string]*types.Option{} - } - cleanupOldOptions(p.DevPodConfig, providerConfig) if err := config.SaveConfig(p.DevPodConfig); err != nil { return nil, err } - if err := downloadProviderBinaries(p, providerConfig); err != nil { + if err := downloadAndSaveProvider(p, providerConfig); err != nil { return nil, err } - if err := providerpkg.SaveProviderConfig(p.DevPodConfig.DefaultContext, providerConfig); err != nil { + return providerConfig, nil +} + +func parseAndValidateProvider(p ProviderParams) (*provider.ProviderConfig, error) { + providerConfig, err := provider.ParseProvider(bytes.NewReader(p.Raw)) + if err != nil { return nil, err } + if p.Source == nil { + return nil, fmt.Errorf("provider source is required") + } + + providerConfig.Source = *p.Source + if p.ProviderName != "" { + providerConfig.Name = p.ProviderName + } + if providerConfig.Options == nil { + providerConfig.Options = map[string]*types.Option{} + } return providerConfig, nil } @@ -431,7 +459,7 @@ func checkProviderNotExists(devPodConfig *config.Config, providerName string) er return fmt.Errorf("provider %s already exists", providerName) } - providerDir, err := providerpkg.GetProviderDir(devPodConfig.DefaultContext, providerName) + providerDir, err := provider.GetProviderDir(devPodConfig.DefaultContext, providerName) if err != nil { return err } @@ -443,48 +471,39 @@ func checkProviderNotExists(devPodConfig *config.Config, providerName string) er return nil } -func downloadAndSaveProvider(p ProviderParams, providerConfig *providerpkg.ProviderConfig) error { - binariesDir, err := providerpkg.GetProviderBinariesDir(p.DevPodConfig.DefaultContext, providerConfig.Name) +func downloadAndSaveProvider(p ProviderParams, providerConfig *provider.ProviderConfig) error { + binariesDir, err := provider.GetProviderBinariesDir(p.DevPodConfig.DefaultContext, providerConfig.Name) if err != nil { return fmt.Errorf("get binaries dir: %w", err) } - providerDir, err := providerpkg.GetProviderDir(p.DevPodConfig.DefaultContext, providerConfig.Name) + providerDir, err := provider.GetProviderDir(p.DevPodConfig.DefaultContext, providerConfig.Name) if err != nil { return fmt.Errorf("get provider dir: %w", err) } - if _, err := binaries.DownloadBinaries(providerConfig.Binaries, binariesDir, p.Log); err != nil { + if _, err := provider.DownloadBinaries(providerConfig.Binaries, binariesDir, p.Log); err != nil { _ = os.RemoveAll(providerDir) return fmt.Errorf("download binaries: %w", err) } - return providerpkg.SaveProviderConfig(p.DevPodConfig.DefaultContext, providerConfig) -} - -func cleanupOldOptions(devPodConfig *config.Config, providerConfig *providerpkg.ProviderConfig) { - for optionName := range devPodConfig.Current().Providers[providerConfig.Name].Options { - if _, ok := providerConfig.Options[optionName]; !ok { - delete(devPodConfig.Current().Providers[providerConfig.Name].Options, optionName) - } - } + return provider.SaveProviderConfig(p.DevPodConfig.DefaultContext, providerConfig) } -func downloadProviderBinaries(p ProviderParams, providerConfig *providerpkg.ProviderConfig) error { - binariesDir, err := providerpkg.GetProviderBinariesDir(p.DevPodConfig.DefaultContext, providerConfig.Name) - if err != nil { - return fmt.Errorf("get binaries dir: %w", err) +func cleanupOldOptions(devPodConfig *config.Config, providerConfig *provider.ProviderConfig) { + providerState := devPodConfig.Current().Providers[providerConfig.Name] + if providerState == nil || providerState.Options == nil { + return } - if _, err := binaries.DownloadBinaries(providerConfig.Binaries, binariesDir, p.Log); err != nil { - _ = os.RemoveAll(binariesDir) - return fmt.Errorf("download binaries: %w", err) + for optionName := range providerState.Options { + if _, ok := providerConfig.Options[optionName]; !ok { + delete(providerState.Options, optionName) + } } - - return nil } -func getProviderSource(src providerpkg.ProviderSource, configName string) string { +func getProviderSource(src provider.ProviderSource, configName string) string { switch { case src.Internal: if src.Raw == "" { @@ -502,7 +521,7 @@ func getProviderSource(src providerpkg.ProviderSource, configName string) string } } -func resolveInternalProvider(providerSource string, retSource *providerpkg.ProviderSource) ([]byte, bool) { +func resolveInternalProvider(providerSource string, retSource *provider.ProviderSource) ([]byte, bool) { internalProviders := providers.GetBuiltInProviders() if internalProviders[providerSource] != "" { retSource.Internal = true @@ -513,7 +532,7 @@ func resolveInternalProvider(providerSource string, retSource *providerpkg.Provi func resolveURLProvider( providerSource string, - retSource *providerpkg.ProviderSource, + retSource *provider.ProviderSource, log log.Logger, ) ([]byte, bool, error) { if !strings.HasPrefix(providerSource, httpPrefix) && !strings.HasPrefix(providerSource, httpsPrefix) { @@ -529,27 +548,30 @@ func resolveURLProvider( return out, true, nil } -func resolveFileProvider(providerSource string, retSource *providerpkg.ProviderSource) ([]byte, bool) { +func resolveFileProvider(providerSource string, retSource *provider.ProviderSource) ([]byte, bool, error) { if !strings.HasSuffix(providerSource, yamlExt) && !strings.HasSuffix(providerSource, ymlExt) { - return nil, false + return nil, false, nil } if _, err := os.Stat(providerSource); err != nil { - return nil, false + if os.IsNotExist(err) { + return nil, false, nil + } + return nil, true, fmt.Errorf("stat provider file %q: %w", providerSource, err) } // #nosec G304 - providerSource is user-provided path for loading provider config out, err := os.ReadFile(providerSource) if err != nil { - return nil, false + return nil, true, fmt.Errorf("read provider file %q: %w", providerSource, err) } absPath, err := filepath.Abs(providerSource) if err != nil { - return nil, false + return nil, true, fmt.Errorf("resolve absolute path for %q: %w", providerSource, err) } retSource.File = absPath - return out, true + return out, true, nil } func downloadProvider(url string) ([]byte, error) { diff --git a/pkg/workspace/workspace.go b/pkg/workspace/workspace.go index 471f64913..1eedf18fb 100644 --- a/pkg/workspace/workspace.go +++ b/pkg/workspace/workspace.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "os" "sort" "strings" @@ -30,6 +31,13 @@ var errProvideWorkspaceArg = errors.New( "please provide a workspace name. E.g. 'devpod up ./my-folder', " + "'devpod up github.com/my-org/my-repo' or 'devpod up ubuntu'") +// RemoteCreator defines the interface for clients that support remote workspace creation. +// This interface is implemented by ProxyClient and DaemonClient to enable workspace +// creation on remote platforms. +type RemoteCreator interface { + Create(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error +} + // Resolve takes the `devpod up|build` CLI input and either finds an existing workspace or creates a new one type ResolveParams struct { IDE string @@ -382,7 +390,16 @@ func createWorkspace( return nil, nil, nil, fmt.Errorf("save config: %w", err) } - err := resolveProInstance(ctx, devPodConfig, provider.Config.Name, workspace, log) + err := resolveProInstance(proInstanceParams{ + ctx: ctx, + devPodConfig: devPodConfig, + providerName: provider.Config.Name, + workspace: workspace, + stdin: os.Stdin, + stdout: os.Stdout, + stderr: os.Stderr, + log: log, + }) if err != nil { return nil, nil, nil, err } @@ -688,23 +705,43 @@ func loadExistingWorkspace(devPodConfig *config.Config, workspaceID string, chan return providerWithOptions.Config, workspaceConfig, machineConfig, nil } -func resolveProInstance(ctx context.Context, devPodConfig *config.Config, providerName string, workspace *providerpkg.Workspace, log log.Logger) error { - provider, err := FindProvider(devPodConfig, providerName, log) +type proInstanceParams struct { + ctx context.Context + devPodConfig *config.Config + providerName string + workspace *providerpkg.Workspace + stdin io.Reader + stdout io.Writer + stderr io.Writer + log log.Logger +} + +func resolveProInstance(params proInstanceParams) error { + foundProvider, err := FindProvider(params.devPodConfig, params.providerName, params.log) if err != nil { return err } - workspaceClient, err := getWorkspaceClient(devPodConfig, provider.Config, workspace, nil, log) + workspaceClient, err := getWorkspaceClient( + params.devPodConfig, + foundProvider.Config, + params.workspace, + nil, + params.log, + ) if err != nil { return err } - switch c := workspaceClient.(type) { - case client.ProxyClient: - return c.Create(ctx, os.Stdin, os.Stdout, os.Stderr) - case client.DaemonClient: - return c.Create(ctx, os.Stdin, os.Stdout, os.Stderr) - default: - return fmt.Errorf("client does not support remote workspaces") + if c, ok := workspaceClient.(RemoteCreator); ok { + return c.Create(params.ctx, params.stdin, params.stdout, params.stderr) } + + // This should never happen - indicates a programming error where a proxy/daemon provider + // client does not implement the RemoteCreator interface + return fmt.Errorf( + "internal error: client %T for provider %q does not implement RemoteCreator interface", + workspaceClient, + params.providerName, + ) }