diff --git a/docker-compose.yml b/docker-compose.yml index 00a859f15..1305d3177 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -122,6 +122,10 @@ services: - WEKNORA_SANDBOX_MODE=${WEKNORA_SANDBOX_MODE:-docker} - WEKNORA_SANDBOX_TIMEOUT=${WEKNORA_SANDBOX_TIMEOUT:-60} - WEKNORA_SANDBOX_DOCKER_IMAGE=${WEKNORA_SANDBOX_DOCKER_IMAGE:-wechatopenai/weknora-sandbox:${WEKNORA_VERSION:-latest}} + - WEKNORA_SANDBOX_IMAGE=${WEKNORA_SANDBOX_IMAGE:-} + - WEKNORA_SANDBOX_MAX_CONCURRENT=${WEKNORA_SANDBOX_MAX_CONCURRENT:-} + - WEKNORA_SANDBOX_OPENSANDBOX_API_URL=${WEKNORA_SANDBOX_OPENSANDBOX_API_URL:-} + - WEKNORA_SANDBOX_OPENSANDBOX_API_KEY=${WEKNORA_SANDBOX_OPENSANDBOX_API_KEY:-} - APK_MIRROR_ARG=${APK_MIRROR_ARG:-} depends_on: redis: diff --git a/helm/templates/app.yaml b/helm/templates/app.yaml index b475fcefc..d4a7a0d8f 100644 --- a/helm/templates/app.yaml +++ b/helm/templates/app.yaml @@ -29,6 +29,9 @@ spec: spec: {{- include "weknora.imagePullSecrets" . | nindent 6 }} serviceAccountName: {{ include "weknora.serviceAccountName" . }} + {{- if eq .Values.app.sandbox.mode "kubernetes" }} + automountServiceAccountToken: true + {{- end }} {{- with .Values.app.podSecurityContext | default .Values.global.podSecurityContext }} securityContext: {{- toYaml . | nindent 8 }} @@ -141,6 +144,32 @@ spec: name: {{ include "weknora.secretName" . }} key: NEO4J_PASSWORD {{- end }} + # Sandbox configuration + - name: WEKNORA_SANDBOX_MODE + value: {{ .Values.app.sandbox.mode | quote }} + - name: WEKNORA_SANDBOX_TIMEOUT + value: {{ .Values.app.sandbox.timeout | quote }} + - name: WEKNORA_SANDBOX_IMAGE + value: {{ .Values.app.sandbox.image | quote }} + {{- if eq .Values.app.sandbox.mode "kubernetes" }} + - name: WEKNORA_SANDBOX_KUBE_NAMESPACE + value: {{ .Values.app.sandbox.kubernetes.namespace | quote }} + - name: WEKNORA_SANDBOX_MAX_CONCURRENT + value: {{ .Values.app.sandbox.kubernetes.maxConcurrentSandboxes | quote }} + {{- end }} + {{- if eq .Values.app.sandbox.mode "opensandbox" }} + - name: WEKNORA_SANDBOX_OPENSANDBOX_API_URL + value: {{ .Values.app.sandbox.opensandbox.apiURL | quote }} + - name: WEKNORA_SANDBOX_OPENSANDBOX_API_KEY + {{- if .Values.app.sandbox.opensandbox.existingSecret }} + valueFrom: + secretKeyRef: + name: {{ .Values.app.sandbox.opensandbox.existingSecret }} + key: apiKey + {{- else }} + value: {{ .Values.app.sandbox.opensandbox.apiKey | quote }} + {{- end }} + {{- end }} {{- with .Values.app.extraEnv }} # Additional environment variables {{- toYaml . | nindent 12 }} diff --git a/helm/templates/sandbox-rbac.yaml b/helm/templates/sandbox-rbac.yaml new file mode 100644 index 000000000..563389a20 --- /dev/null +++ b/helm/templates/sandbox-rbac.yaml @@ -0,0 +1,53 @@ +{{/* +Sandbox RBAC resources for kubernetes sandbox mode. +Only created when sandbox mode is "kubernetes". +*/}} +{{- if eq .Values.app.sandbox.mode "kubernetes" }} +--- +apiVersion: v1 +kind: Namespace +metadata: + name: {{ .Values.app.sandbox.kubernetes.namespace }} + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} + app.kubernetes.io/part-of: weknora +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: weknora-sandbox-runner + namespace: {{ .Values.app.sandbox.kubernetes.namespace }} + labels: + app.kubernetes.io/managed-by: {{ .Release.Service }} +automountServiceAccountToken: false +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: weknora-sandbox-manager + namespace: {{ .Values.app.sandbox.kubernetes.namespace }} +rules: + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["create", "delete", "get", "list", "watch"] + - apiGroups: [""] + resources: ["configmaps"] + verbs: ["create", "delete", "get", "list"] + - apiGroups: [""] + resources: ["pods", "pods/log"] + verbs: ["get", "list"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: weknora-sandbox-manager + namespace: {{ .Values.app.sandbox.kubernetes.namespace }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: weknora-sandbox-manager +subjects: + - kind: ServiceAccount + name: {{ include "weknora.serviceAccountName" . }} + namespace: {{ .Release.Namespace }} +{{- end }} diff --git a/helm/templates/sandbox-resources.yaml b/helm/templates/sandbox-resources.yaml new file mode 100644 index 000000000..248fcaf2f --- /dev/null +++ b/helm/templates/sandbox-resources.yaml @@ -0,0 +1,30 @@ +{{/* +Sandbox ResourceQuota and NetworkPolicy for kubernetes sandbox mode. +Only created when sandbox mode is "kubernetes". +*/}} +{{- if eq .Values.app.sandbox.mode "kubernetes" }} +--- +apiVersion: v1 +kind: ResourceQuota +metadata: + name: weknora-sandbox-quota + namespace: {{ .Values.app.sandbox.kubernetes.namespace }} +spec: + hard: + pods: {{ .Values.app.sandbox.kubernetes.quota.maxPods | quote }} + limits.memory: {{ .Values.app.sandbox.kubernetes.quota.maxMemory | quote }} + limits.cpu: {{ .Values.app.sandbox.kubernetes.quota.maxCPU | quote }} +--- +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: weknora-sandbox-deny-all + namespace: {{ .Values.app.sandbox.kubernetes.namespace }} +spec: + podSelector: + matchLabels: + app.kubernetes.io/component: sandbox + policyTypes: + - Ingress + - Egress +{{- end }} diff --git a/helm/values.yaml b/helm/values.yaml index 70aea7b84..0904886f7 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -104,6 +104,43 @@ app: # - name: OLLAMA_BASE_URL # value: "http://ollama:11434" + # -- Sandbox configuration for Agent Skills script execution + sandbox: + # -- Sandbox mode: local, docker, kubernetes, opensandbox, disabled + # In Kubernetes, "local" is recommended as the baseline. + # Use "kubernetes" for container-level isolation without Docker daemon. + # Use "opensandbox" to integrate with an external OpenSandbox service. + mode: local + # -- Script execution timeout in seconds + timeout: 60 + # -- Sandbox container image (used by docker/kubernetes modes) + image: wechatopenai/weknora-sandbox:latest + # -- Kubernetes sandbox mode configuration + kubernetes: + # -- Namespace for sandbox Pods + namespace: weknora-sandbox + # -- Max concurrent sandbox Jobs per WeKnora instance + maxConcurrentSandboxes: 5 + # -- Resource limits per sandbox Pod + podResources: + limits: + cpu: "500m" + memory: "256Mi" + # -- ResourceQuota for sandbox namespace + quota: + maxPods: 10 + maxMemory: "2Gi" + maxCPU: "4" + # -- OpenSandbox mode configuration + opensandbox: + # -- OpenSandbox Server URL + apiURL: "" + # -- API Key (plain text, used when existingSecret is not set) + apiKey: "" + # -- Use an existing Secret for the API key (recommended for production) + # The Secret must contain a key named "apiKey" + existingSecret: "" + # -- Service configuration service: type: ClusterIP diff --git a/internal/application/service/agent_service.go b/internal/application/service/agent_service.go index 936a7703c..8b44e0b07 100644 --- a/internal/application/service/agent_service.go +++ b/internal/application/service/agent_service.go @@ -251,7 +251,10 @@ func (s *agentService) initializeSkillsManager( if sandboxMode == "" { sandboxMode = "disabled" } - dockerImage := os.Getenv("WEKNORA_SANDBOX_DOCKER_IMAGE") + dockerImage := os.Getenv("WEKNORA_SANDBOX_IMAGE") + if dockerImage == "" { + dockerImage = os.Getenv("WEKNORA_SANDBOX_DOCKER_IMAGE") // backward compat + } if dockerImage == "" { dockerImage = sandbox.DefaultDockerImage } @@ -263,6 +266,19 @@ func (s *agentService) initializeSkillsManager( } } + // Read additional env vars for new sandbox modes + kubeNamespace := os.Getenv("WEKNORA_SANDBOX_KUBE_NAMESPACE") + kubeServiceAccount := os.Getenv("WEKNORA_SANDBOX_KUBE_SERVICE_ACCOUNT") + maxConcurrentStr := os.Getenv("WEKNORA_SANDBOX_MAX_CONCURRENT") + maxConcurrent := sandbox.DefaultMaxConcurrentSandboxes + if maxConcurrentStr != "" { + if v, err := strconv.Atoi(maxConcurrentStr); err == nil && v > 0 { + maxConcurrent = v + } + } + opensandboxAPIURL := os.Getenv("WEKNORA_SANDBOX_OPENSANDBOX_API_URL") + opensandboxAPIKey := os.Getenv("WEKNORA_SANDBOX_OPENSANDBOX_API_KEY") + switch sandboxMode { case "docker": sandboxMgr, err = sandbox.NewManagerFromType("docker", true, dockerImage) // Enable fallback to local @@ -276,6 +292,36 @@ func (s *agentService) initializeSkillsManager( logger.Warnf(ctx, "Failed to initialize local sandbox: %v", err) sandboxMgr = sandbox.NewDisabledManager() } + case "kubernetes": + config := sandbox.DefaultConfig() + config.Type = sandbox.SandboxTypeKubernetes + config.FallbackEnabled = true + config.DockerImage = dockerImage + if kubeNamespace != "" { + config.KubeNamespace = kubeNamespace + } + if kubeServiceAccount != "" { + config.KubeServiceAccount = kubeServiceAccount + } + config.MaxConcurrentSandboxes = maxConcurrent + sandboxMgr, err = sandbox.NewManager(config) + if err != nil { + logger.Warnf(ctx, "Failed to initialize kubernetes sandbox, falling back to disabled: %v", err) + sandboxMgr = sandbox.NewDisabledManager() + } + case "opensandbox": + config := sandbox.DefaultConfig() + config.Type = sandbox.SandboxTypeOpenSandbox + config.FallbackEnabled = true + config.DockerImage = dockerImage + config.OpenSandboxAPIURL = opensandboxAPIURL + config.OpenSandboxAPIKey = opensandboxAPIKey + config.MaxConcurrentSandboxes = maxConcurrent + sandboxMgr, err = sandbox.NewManager(config) + if err != nil { + logger.Warnf(ctx, "Failed to initialize opensandbox, falling back to disabled: %v", err) + sandboxMgr = sandbox.NewDisabledManager() + } default: sandboxMgr = sandbox.NewDisabledManager() } diff --git a/internal/sandbox/kubeclient.go b/internal/sandbox/kubeclient.go new file mode 100644 index 000000000..8515cc405 --- /dev/null +++ b/internal/sandbox/kubeclient.go @@ -0,0 +1,471 @@ +package sandbox + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strconv" + "sync" + "time" +) + +const ( + saTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" + saCACertPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" +) + +const ( + tokenCacheTTL = 30 * time.Second + maxResponseBytes = 10 * 1024 * 1024 // 10MB cap on K8s API response reads +) + +// kubeClient is a lightweight Kubernetes API client using only net/http + encoding/json. +// For in-cluster use, tokenPath is set so the token is re-read periodically +// to handle bound SA token rotation (tokens expire, typically after 1 hour). +type kubeClient struct { + apiServer string + token string + tokenPath string + tokenMu sync.Mutex // guards token, tokenCachedAt + tokenCachedAt time.Time + httpClient *http.Client +} + +// newKubeClientInCluster reads the service account token and CA cert from the default paths +// and returns a kubeClient configured for in-cluster use. +func newKubeClientInCluster() (*kubeClient, error) { + // Verify the token file is readable at init time. + if _, err := os.ReadFile(saTokenPath); err != nil { + return nil, fmt.Errorf("failed to read SA token: %w", err) + } + + caBytes, err := os.ReadFile(saCACertPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA cert: %w", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caBytes) { + return nil, fmt.Errorf("failed to parse CA cert") + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: pool, + }, + }, + } + + return &kubeClient{ + apiServer: "https://kubernetes.default.svc", + tokenPath: saTokenPath, + httpClient: httpClient, + }, nil +} + +// newKubeClient creates a kubeClient with the given apiServer, token, and optional httpClient. +// If httpClient is nil, http.DefaultClient is used. +func newKubeClient(apiServer, token string, httpClient *http.Client) *kubeClient { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &kubeClient{ + apiServer: apiServer, + token: token, + httpClient: httpClient, + } +} + +// getToken returns the current bearer token. If tokenPath is set (in-cluster mode), +// it re-reads the file periodically (every tokenCacheTTL) to handle bound SA token rotation. +func (c *kubeClient) getToken() (string, error) { + if c.tokenPath != "" { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + if c.token != "" && time.Since(c.tokenCachedAt) < tokenCacheTTL { + return c.token, nil + } + data, err := os.ReadFile(c.tokenPath) + if err != nil { + return "", err + } + c.token = string(data) + c.tokenCachedAt = time.Now() + return c.token, nil + } + return c.token, nil +} + +// inClusterAvailable returns true if the SA token file exists. +func inClusterAvailable() bool { + _, err := os.Stat(saTokenPath) + return err == nil +} + +// do performs an HTTP request to the K8s API with Bearer token auth. +// Returns the response body bytes, HTTP status code, and any error. +func (c *kubeClient) do(ctx context.Context, method, path string, body any) ([]byte, int, error) { + var reqBody io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return nil, 0, fmt.Errorf("failed to marshal request body: %w", err) + } + reqBody = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, method, c.apiServer+path, reqBody) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + + token, err := c.getToken() + if err != nil { + return nil, 0, fmt.Errorf("failed to read token: %w", err) + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + respBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("failed to read response: %w", err) + } + + return respBytes, resp.StatusCode, nil +} + +// createConfigMap creates a ConfigMap in the given namespace. +func (c *kubeClient) createConfigMap(ctx context.Context, namespace, name string, data map[string]string, labels map[string]string) error { + cm := map[string]any{ + "apiVersion": "v1", + "kind": "ConfigMap", + "metadata": map[string]any{ + "name": name, + "namespace": namespace, + "labels": labels, + }, + "data": data, + } + + path := fmt.Sprintf("/api/v1/namespaces/%s/configmaps", namespace) + _, status, err := c.do(ctx, http.MethodPost, path, cm) + if err != nil { + return err + } + if status != http.StatusCreated && status != http.StatusOK { + return fmt.Errorf("createConfigMap: unexpected status %d", status) + } + return nil +} + +// deleteConfigMap deletes a ConfigMap by name in the given namespace. +func (c *kubeClient) deleteConfigMap(ctx context.Context, namespace, name string) error { + path := fmt.Sprintf("/api/v1/namespaces/%s/configmaps/%s", namespace, name) + _, status, err := c.do(ctx, http.MethodDelete, path, nil) + if err != nil { + return err + } + if status != http.StatusOK && status != http.StatusAccepted && status != http.StatusNoContent { + return fmt.Errorf("deleteConfigMap: unexpected status %d", status) + } + return nil +} + +// jobSpec holds the parameters needed to create a Kubernetes Job. +type jobSpec struct { + Name string + Image string + Command []string + ConfigMapName string + TimeoutSeconds int + ServiceAccountName string + MemoryLimit string + CPULimit string +} + +// jobStatus represents the current state of a Kubernetes Job. +type jobStatus struct { + succeeded bool + failed bool + active bool +} + +// createJob creates a Kubernetes batch/v1 Job from the given jobSpec. +func (c *kubeClient) createJob(ctx context.Context, namespace string, spec *jobSpec) error { + ttl := int32(120) + backoffLimit := int32(0) + activeDeadlineSeconds := int64(spec.TimeoutSeconds) + if activeDeadlineSeconds <= 0 { + activeDeadlineSeconds = 300 + } + runAsUser := int64(1000) + runAsGroup := int64(1000) + runAsNonRoot := true + allowPrivEsc := false + readOnlyRootfs := true + + job := map[string]any{ + "apiVersion": "batch/v1", + "kind": "Job", + "metadata": map[string]any{ + "name": spec.Name, + "namespace": namespace, + "labels": map[string]string{ + "app.kubernetes.io/component": "sandbox", + "app.kubernetes.io/managed-by": "weknora", + }, + }, + "spec": map[string]any{ + "backoffLimit": backoffLimit, + "ttlSecondsAfterFinished": ttl, + "activeDeadlineSeconds": activeDeadlineSeconds, + "template": map[string]any{ + "metadata": map[string]any{ + "labels": map[string]string{ + "app.kubernetes.io/component": "sandbox", + "app.kubernetes.io/managed-by": "weknora", + "batch.kubernetes.io/job-name": spec.Name, + }, + }, + "spec": map[string]any{ + "restartPolicy": "Never", + "automountServiceAccountToken": false, + "serviceAccountName": spec.ServiceAccountName, + "securityContext": map[string]any{ + "runAsUser": runAsUser, + "runAsGroup": runAsGroup, + "runAsNonRoot": runAsNonRoot, + "seccompProfile": map[string]any{ + "type": "RuntimeDefault", + }, + }, + "volumes": []map[string]any{ + { + "name": "workspace", + "configMap": map[string]any{ + "name": spec.ConfigMapName, + }, + }, + { + "name": "tmp", + "emptyDir": map[string]any{ + "medium": "Memory", + "sizeLimit": "64Mi", + }, + }, + }, + "containers": []map[string]any{ + { + "name": "sandbox", + "image": spec.Image, + "command": spec.Command, + "securityContext": map[string]any{ + "allowPrivilegeEscalation": allowPrivEsc, + "readOnlyRootFilesystem": readOnlyRootfs, + "runAsUser": runAsUser, + "runAsGroup": runAsGroup, + "runAsNonRoot": runAsNonRoot, + "capabilities": map[string]any{ + "drop": []string{"ALL"}, + }, + "seccompProfile": map[string]any{ + "type": "RuntimeDefault", + }, + }, + "resources": map[string]any{ + "limits": map[string]any{ + "memory": spec.MemoryLimit, + "cpu": spec.CPULimit, + }, + }, + "volumeMounts": []map[string]any{ + { + "name": "workspace", + "mountPath": "/workspace", + "readOnly": true, + }, + { + "name": "tmp", + "mountPath": "/tmp", + }, + }, + }, + }, + }, + }, + }, + } + + path := fmt.Sprintf("/apis/batch/v1/namespaces/%s/jobs", namespace) + _, status, err := c.do(ctx, http.MethodPost, path, job) + if err != nil { + return err + } + if status != http.StatusCreated && status != http.StatusOK { + return fmt.Errorf("createJob: unexpected status %d", status) + } + return nil +} + +// getJobStatus returns the current status of a Kubernetes Job. +func (c *kubeClient) getJobStatus(ctx context.Context, namespace, name string) (*jobStatus, error) { + path := fmt.Sprintf("/apis/batch/v1/namespaces/%s/jobs/%s", namespace, name) + body, status, err := c.do(ctx, http.MethodGet, path, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("getJobStatus: unexpected status %d", status) + } + + var result struct { + Status struct { + Succeeded int32 `json:"succeeded"` + Failed int32 `json:"failed"` + Active int32 `json:"active"` + } `json:"status"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("getJobStatus: failed to parse response: %w", err) + } + + return &jobStatus{ + succeeded: result.Status.Succeeded > 0, + failed: result.Status.Failed > 0, + active: result.Status.Active > 0, + }, nil +} + +// findJobPod returns the name of the first pod belonging to the given job. +func (c *kubeClient) findJobPod(ctx context.Context, namespace, jobName string) (string, error) { + q := url.Values{} + q.Set("labelSelector", "batch.kubernetes.io/job-name="+jobName) + path := fmt.Sprintf("/api/v1/namespaces/%s/pods?%s", namespace, q.Encode()) + body, status, err := c.do(ctx, http.MethodGet, path, nil) + if err != nil { + return "", err + } + if status != http.StatusOK { + return "", fmt.Errorf("findJobPod: unexpected status %d", status) + } + + var result struct { + Items []struct { + Metadata struct { + Name string `json:"name"` + } `json:"metadata"` + } `json:"items"` + } + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("findJobPod: failed to parse response: %w", err) + } + + if len(result.Items) == 0 { + return "", fmt.Errorf("findJobPod: no pods found for job %s", jobName) + } + + return result.Items[0].Metadata.Name, nil +} + +// getPodLogs returns the logs of a pod, optionally limiting the response size. +func (c *kubeClient) getPodLogs(ctx context.Context, namespace, podName string, limitBytes int64) (string, error) { + path := fmt.Sprintf("/api/v1/namespaces/%s/pods/%s/log", namespace, podName) + if limitBytes > 0 { + path += "?limitBytes=" + strconv.FormatInt(limitBytes, 10) + } + + body, status, err := c.do(ctx, http.MethodGet, path, nil) + if err != nil { + return "", err + } + if status != http.StatusOK { + return "", fmt.Errorf("getPodLogs: unexpected status %d", status) + } + + return string(body), nil +} + +// deleteJob deletes a Kubernetes Job with cascading deletion via Background propagation policy. +func (c *kubeClient) deleteJob(ctx context.Context, namespace, name string) error { + path := fmt.Sprintf("/apis/batch/v1/namespaces/%s/jobs/%s?propagationPolicy=Background", namespace, name) + _, status, err := c.do(ctx, http.MethodDelete, path, nil) + if err != nil { + return err + } + if status != http.StatusOK && status != http.StatusAccepted && status != http.StatusNoContent { + return fmt.Errorf("deleteJob: unexpected status %d", status) + } + return nil +} + +// configMapEntry holds a ConfigMap's name and labels. +type configMapEntry struct { + name string + labels map[string]string +} + +// listConfigMapsWithLabels returns ConfigMap entries (name + labels) matching the label selector. +func (c *kubeClient) listConfigMapsWithLabels(ctx context.Context, namespace, labelSelector string) ([]configMapEntry, error) { + path := fmt.Sprintf("/api/v1/namespaces/%s/configmaps", namespace) + if labelSelector != "" { + q := url.Values{} + q.Set("labelSelector", labelSelector) + path += "?" + q.Encode() + } + + body, status, err := c.do(ctx, http.MethodGet, path, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("listConfigMapsWithLabels: unexpected status %d", status) + } + + var result struct { + Items []struct { + Metadata struct { + Name string `json:"name"` + Labels map[string]string `json:"labels"` + } `json:"metadata"` + } `json:"items"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("listConfigMapsWithLabels: failed to parse response: %w", err) + } + + entries := make([]configMapEntry, 0, len(result.Items)) + for _, item := range result.Items { + entries = append(entries, configMapEntry{ + name: item.Metadata.Name, + labels: item.Metadata.Labels, + }) + } + return entries, nil +} + +// checkAccess returns true if the client can operate in the given namespace. +// It verifies by listing configmaps (which only requires namespace-scoped RBAC). +func (c *kubeClient) checkAccess(ctx context.Context, namespace string) bool { + path := fmt.Sprintf("/api/v1/namespaces/%s/configmaps?limit=1", namespace) + _, status, err := c.do(ctx, http.MethodGet, path, nil) + return err == nil && status == http.StatusOK +} diff --git a/internal/sandbox/kubeclient_test.go b/internal/sandbox/kubeclient_test.go new file mode 100644 index 000000000..1d5609e75 --- /dev/null +++ b/internal/sandbox/kubeclient_test.go @@ -0,0 +1,458 @@ +package sandbox + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// newTestClient creates a kubeClient pointing at the given httptest server URL. +func newTestClient(server *httptest.Server) *kubeClient { + return newKubeClient(server.URL, "test-token", server.Client()) +} + +func TestKubeClientCreateConfigMap(t *testing.T) { + var gotMethod, gotPath string + var gotBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + http.Error(w, "bad body", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + client := newTestClient(server) + ctx := context.Background() + + data := map[string]string{"script.py": "print('hello')"} + labels := map[string]string{"app": "test"} + + err := client.createConfigMap(ctx, "test-ns", "my-cm", data, labels) + if err != nil { + t.Fatalf("createConfigMap failed: %v", err) + } + + if gotMethod != http.MethodPost { + t.Errorf("expected POST, got %s", gotMethod) + } + if gotPath != "/api/v1/namespaces/test-ns/configmaps" { + t.Errorf("unexpected path: %s", gotPath) + } + + meta, ok := gotBody["metadata"].(map[string]any) + if !ok { + t.Fatalf("metadata not found in body") + } + if meta["name"] != "my-cm" { + t.Errorf("unexpected name: %v", meta["name"]) + } + if meta["namespace"] != "test-ns" { + t.Errorf("unexpected namespace: %v", meta["namespace"]) + } +} + +func TestKubeClientDeleteConfigMap(t *testing.T) { + var gotMethod, gotPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + client := newTestClient(server) + err := client.deleteConfigMap(context.Background(), "test-ns", "my-cm") + if err != nil { + t.Fatalf("deleteConfigMap failed: %v", err) + } + + if gotMethod != http.MethodDelete { + t.Errorf("expected DELETE, got %s", gotMethod) + } + if gotPath != "/api/v1/namespaces/test-ns/configmaps/my-cm" { + t.Errorf("unexpected path: %s", gotPath) + } +} + +func TestKubeClientCreateJob(t *testing.T) { + var gotMethod, gotPath string + var gotBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + http.Error(w, "bad body", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + client := newTestClient(server) + ctx := context.Background() + + spec := &jobSpec{ + Name: "sandbox-abc123", + Image: "python:3.11-slim", + Command: []string{"python", "/workspace/script.py"}, + ConfigMapName: "cm-abc123", + TimeoutSeconds: 60, + ServiceAccountName: "sandbox-runner", + MemoryLimit: "256Mi", + CPULimit: "500m", + } + + err := client.createJob(ctx, "test-ns", spec) + if err != nil { + t.Fatalf("createJob failed: %v", err) + } + + if gotMethod != http.MethodPost { + t.Errorf("expected POST, got %s", gotMethod) + } + if gotPath != "/apis/batch/v1/namespaces/test-ns/jobs" { + t.Errorf("unexpected path: %s", gotPath) + } + + if gotBody["apiVersion"] != "batch/v1" { + t.Errorf("unexpected apiVersion: %v", gotBody["apiVersion"]) + } + + meta, _ := gotBody["metadata"].(map[string]any) + if meta["name"] != "sandbox-abc123" { + t.Errorf("unexpected job name: %v", meta["name"]) + } + + // Verify security settings in pod template + jobSpec, _ := gotBody["spec"].(map[string]any) + template, _ := jobSpec["template"].(map[string]any) + podSpec, _ := template["spec"].(map[string]any) + + if podSpec["automountServiceAccountToken"] != false { + t.Errorf("automountServiceAccountToken should be false") + } + if podSpec["restartPolicy"] != "Never" { + t.Errorf("restartPolicy should be Never, got: %v", podSpec["restartPolicy"]) + } + + podSC, _ := podSpec["securityContext"].(map[string]any) + if podSC["runAsNonRoot"] != true { + t.Errorf("runAsNonRoot should be true") + } + + containers, _ := podSpec["containers"].([]any) + if len(containers) == 0 { + t.Fatal("no containers in job spec") + } + container, _ := containers[0].(map[string]any) + csc, _ := container["securityContext"].(map[string]any) + + if csc["allowPrivilegeEscalation"] != false { + t.Errorf("allowPrivilegeEscalation should be false") + } + if csc["readOnlyRootFilesystem"] != true { + t.Errorf("readOnlyRootFilesystem should be true") + } + caps, _ := csc["capabilities"].(map[string]any) + drop, _ := caps["drop"].([]any) + if len(drop) == 0 || drop[0] != "ALL" { + t.Errorf("capabilities.drop should contain ALL, got: %v", drop) + } + + // Verify volume mounts + mounts, _ := container["volumeMounts"].([]any) + if len(mounts) < 2 { + t.Errorf("expected at least 2 volume mounts, got %d", len(mounts)) + } + + // Verify pod labels include job-name + templateMeta, _ := template["metadata"].(map[string]any) + podLabels, _ := templateMeta["labels"].(map[string]any) + if podLabels["batch.kubernetes.io/job-name"] != "sandbox-abc123" { + t.Errorf("pod label batch.kubernetes.io/job-name should match job name") + } +} + +func TestKubeClientGetJobStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "status": { + "succeeded": 1, + "failed": 0, + "active": 0 + } + }`)) + })) + defer server.Close() + + client := newTestClient(server) + status, err := client.getJobStatus(context.Background(), "test-ns", "my-job") + if err != nil { + t.Fatalf("getJobStatus failed: %v", err) + } + + if !status.succeeded { + t.Errorf("expected succeeded=true") + } + if status.failed { + t.Errorf("expected failed=false") + } + if status.active { + t.Errorf("expected active=false") + } +} + +func TestKubeClientGetJobStatusActive(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": {"succeeded": 0, "failed": 0, "active": 1}}`)) + })) + defer server.Close() + + client := newTestClient(server) + status, err := client.getJobStatus(context.Background(), "test-ns", "my-job") + if err != nil { + t.Fatalf("getJobStatus failed: %v", err) + } + + if status.succeeded || status.failed { + t.Errorf("expected only active=true") + } + if !status.active { + t.Errorf("expected active=true") + } +} + +func TestKubeClientGetPodLogs(t *testing.T) { + const logContent = "hello from sandbox\nline 2\n" + var gotQuery string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + w.Write([]byte(logContent)) + })) + defer server.Close() + + client := newTestClient(server) + logs, err := client.getPodLogs(context.Background(), "test-ns", "my-pod", 1024) + if err != nil { + t.Fatalf("getPodLogs failed: %v", err) + } + + if logs != logContent { + t.Errorf("unexpected logs: %q", logs) + } + + if !strings.Contains(gotQuery, "limitBytes=1024") { + t.Errorf("expected limitBytes=1024 in query, got: %s", gotQuery) + } +} + +func TestKubeClientGetPodLogsNoLimit(t *testing.T) { + var gotPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + w.Write([]byte("logs")) + })) + defer server.Close() + + client := newTestClient(server) + _, err := client.getPodLogs(context.Background(), "test-ns", "my-pod", 0) + if err != nil { + t.Fatalf("getPodLogs failed: %v", err) + } + + if strings.Contains(gotPath, "limitBytes") { + t.Errorf("expected no limitBytes in query when limit=0, got: %s", gotPath) + } +} + +func TestKubeClientDeleteJob(t *testing.T) { + var gotMethod, gotPath, gotQuery string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + client := newTestClient(server) + err := client.deleteJob(context.Background(), "test-ns", "my-job") + if err != nil { + t.Fatalf("deleteJob failed: %v", err) + } + + if gotMethod != http.MethodDelete { + t.Errorf("expected DELETE, got %s", gotMethod) + } + if gotPath != "/apis/batch/v1/namespaces/test-ns/jobs/my-job" { + t.Errorf("unexpected path: %s", gotPath) + } + if !strings.Contains(gotQuery, "propagationPolicy=Background") { + t.Errorf("expected propagationPolicy=Background in query, got: %s", gotQuery) + } +} + +func TestKubeClientFindJobPod(t *testing.T) { + var gotQuery string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "items": [ + {"metadata": {"name": "my-job-abc12"}}, + {"metadata": {"name": "my-job-xyz99"}} + ] + }`)) + })) + defer server.Close() + + client := newTestClient(server) + podName, err := client.findJobPod(context.Background(), "test-ns", "my-job") + if err != nil { + t.Fatalf("findJobPod failed: %v", err) + } + + if podName != "my-job-abc12" { + t.Errorf("expected first pod 'my-job-abc12', got %q", podName) + } + + if !strings.Contains(gotQuery, "labelSelector") { + t.Errorf("expected labelSelector in query, got: %s", gotQuery) + } + if !strings.Contains(gotQuery, "my-job") { + t.Errorf("expected job name in query selector, got: %s", gotQuery) + } +} + +func TestKubeClientFindJobPodNotFound(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"items": []}`)) + })) + defer server.Close() + + client := newTestClient(server) + _, err := client.findJobPod(context.Background(), "test-ns", "missing-job") + if err == nil { + t.Fatal("expected error when no pods found") + } +} + +func TestKubeClientListConfigMapsWithLabels(t *testing.T) { + var gotQuery string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "items": [ + {"metadata": {"name": "cm-one", "labels": {"env": "test"}}}, + {"metadata": {"name": "cm-two", "labels": {"env": "prod"}}} + ] + }`)) + })) + defer server.Close() + + client := newTestClient(server) + entries, err := client.listConfigMapsWithLabels(context.Background(), "test-ns", "app=sandbox") + if err != nil { + t.Fatalf("listConfigMapsWithLabels failed: %v", err) + } + + if len(entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(entries)) + } + if entries[0].name != "cm-one" || entries[1].name != "cm-two" { + t.Errorf("unexpected names: %v", entries) + } + if entries[0].labels["env"] != "test" { + t.Errorf("expected label env=test on first entry, got: %v", entries[0].labels) + } + if !strings.Contains(gotQuery, "labelSelector") { + t.Errorf("expected labelSelector in query, got: %s", gotQuery) + } +} + +func TestKubeClientCheckAccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"items": []}`)) + })) + defer server.Close() + + client := newTestClient(server) + if !client.checkAccess(context.Background(), "test-ns") { + t.Errorf("expected checkAccess to return true") + } +} + +func TestKubeClientCheckAccessDenied(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + client := newTestClient(server) + if client.checkAccess(context.Background(), "test-ns") { + t.Errorf("expected checkAccess to return false on 403") + } +} + +func TestKubeClientBearerToken(t *testing.T) { + var gotAuth string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"items": []}`)) + })) + defer server.Close() + + client := newKubeClient(server.URL, "my-secret-token", server.Client()) + _, err := client.listConfigMapsWithLabels(context.Background(), "ns", "") + if err != nil { + t.Fatalf("listConfigMapsWithLabels failed: %v", err) + } + + expected := "Bearer my-secret-token" + if gotAuth != expected { + t.Errorf("expected Authorization header %q, got %q", expected, gotAuth) + } +} + +func TestInClusterAvailable(t *testing.T) { + // In test environments, SA token is not present — expect false. + result := inClusterAvailable() + if result { + t.Log("inClusterAvailable returned true — running inside a K8s cluster") + } else { + t.Log("inClusterAvailable returned false — expected outside cluster") + } +} diff --git a/internal/sandbox/kubernetes.go b/internal/sandbox/kubernetes.go new file mode 100644 index 000000000..29df1f086 --- /dev/null +++ b/internal/sandbox/kubernetes.go @@ -0,0 +1,312 @@ +package sandbox + +import ( + "context" + "fmt" + "strconv" + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + sandboxLabelKey = "weknora-sandbox" + sandboxLabelValue = "true" + sandboxCreatedAtKey = "weknora-sandbox-created-at" + gcOrphanAge = 30 * time.Minute + gcTickerInterval = 5 * time.Minute + gcListTimeout = 30 * time.Second + gcDeleteTimeout = 5 * time.Second + cleanupTimeout = 10 * time.Second + jobPollInterval = 500 * time.Millisecond +) + +// KubernetesSandbox implements the Sandbox interface using Kubernetes Jobs. +type KubernetesSandbox struct { + config *Config + mu sync.Mutex + client *kubeClient + semaphore chan struct{} + stopCh chan struct{} + stopOnce sync.Once +} + +// NewKubernetesSandbox creates a new KubernetesSandbox. +// If config is nil, DefaultConfig() is used. +// The client may be nil; it will be initialized lazily in IsAvailable/Execute. +func NewKubernetesSandbox(config *Config, client *kubeClient) *KubernetesSandbox { + if config == nil { + config = DefaultConfig() + } + + s := &KubernetesSandbox{ + config: config, + client: client, + semaphore: make(chan struct{}, config.MaxConcurrentSandboxes), + stopCh: make(chan struct{}), + } + + go s.gcOrphanResources() + + return s +} + +// Type returns the sandbox type. +func (s *KubernetesSandbox) Type() SandboxType { + return SandboxTypeKubernetes +} + +// IsAvailable checks if the Kubernetes sandbox is available. +func (s *KubernetesSandbox) IsAvailable(ctx context.Context) bool { + s.mu.Lock() + if s.client == nil { + if !inClusterAvailable() { + s.mu.Unlock() + return false + } + c, err := newKubeClientInCluster() + if err != nil { + s.mu.Unlock() + return false + } + s.client = c + } + c := s.client + s.mu.Unlock() + return c.checkAccess(ctx, s.config.KubeNamespace) +} + +// getClient returns the current kubeClient under the mutex. +func (s *KubernetesSandbox) getClient() *kubeClient { + s.mu.Lock() + defer s.mu.Unlock() + return s.client +} + +// Execute runs a script in an ephemeral Kubernetes Job. +func (s *KubernetesSandbox) Execute(ctx context.Context, config *ExecuteConfig) (*ExecuteResult, error) { + if config == nil { + return nil, ErrInvalidScript + } + + scriptContent, scriptName, err := resolveScript(config) + if err != nil { + return nil, err + } + + if s.config.MaxScriptSize > 0 && int64(len(scriptContent)) > s.config.MaxScriptSize { + return nil, fmt.Errorf("%w: script size %d exceeds limit %d", ErrInvalidScript, len(scriptContent), s.config.MaxScriptSize) + } + + client := s.getClient() + if client == nil { + return nil, fmt.Errorf("kubernetes sandbox: client not initialized (call IsAvailable first)") + } + + select { + case s.semaphore <- struct{}{}: + case <-ctx.Done(): + return nil, ctx.Err() + } + defer func() { <-s.semaphore }() + + timeout := config.Timeout + if timeout == 0 { + timeout = s.config.DefaultTimeout + } + if timeout == 0 { + timeout = DefaultTimeout + } + + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + resourceName := fmt.Sprintf("weknora-sandbox-%s", uuid.New().String()[:8]) + + labels := map[string]string{ + sandboxLabelKey: sandboxLabelValue, + sandboxCreatedAtKey: strconv.FormatInt(time.Now().Unix(), 10), + } + + cmData := map[string]string{ + scriptName: scriptContent, + } + if config.Stdin != "" { + cmData[stdinFileName] = config.Stdin + } + + if err := client.createConfigMap(execCtx, s.config.KubeNamespace, resourceName, cmData, labels); err != nil { + return nil, fmt.Errorf("failed to create ConfigMap: %w", err) + } + + defer func() { + cleanCtx, cleanCancel := context.WithTimeout(context.Background(), cleanupTimeout) + defer cleanCancel() + _ = client.deleteJob(cleanCtx, s.config.KubeNamespace, resourceName) + _ = client.deleteConfigMap(cleanCtx, s.config.KubeNamespace, resourceName) + }() + + interpreter := getInterpreter(scriptName) + scriptPath := "/workspace/" + scriptName + + // Build command: use exec-form when possible, shell-form only when stdin piping is needed + var command []string + if config.Stdin != "" { + command = []string{"sh", "-c", buildShellCommand(interpreter, scriptPath, config.Args, true)} + } else { + command = buildExecCommand(interpreter, scriptPath, config.Args) + } + + memoryLimit := s.config.MaxMemory + if config.MemoryLimit > 0 { + memoryLimit = config.MemoryLimit + } + memoryLimitStr := fmt.Sprintf("%dMi", memoryLimit/(1024*1024)) + + cpuLimit := s.config.MaxCPU + if config.CPULimit > 0 { + cpuLimit = config.CPULimit + } + cpuLimitStr := fmt.Sprintf("%dm", int(cpuLimit*1000)) + + spec := &jobSpec{ + Name: resourceName, + Image: s.config.DockerImage, + Command: command, + ConfigMapName: resourceName, + TimeoutSeconds: int(timeout.Seconds()), + ServiceAccountName: s.config.KubeServiceAccount, + MemoryLimit: memoryLimitStr, + CPULimit: cpuLimitStr, + } + + if err := client.createJob(execCtx, s.config.KubeNamespace, spec); err != nil { + return nil, fmt.Errorf("failed to create Job: %w", err) + } + + startTime := time.Now() + + pollTicker := time.NewTicker(jobPollInterval) + defer pollTicker.Stop() + + var finalStatus *jobStatus + for { + select { + case <-execCtx.Done(): + return &ExecuteResult{ + ExitCode: -1, + Killed: true, + Error: ErrTimeout.Error(), + Duration: time.Since(startTime), + }, nil + case <-pollTicker.C: + } + + status, err := client.getJobStatus(execCtx, s.config.KubeNamespace, resourceName) + if err != nil { + // Context may have been cancelled + if execCtx.Err() != nil { + return &ExecuteResult{ + ExitCode: -1, + Killed: true, + Error: ErrTimeout.Error(), + Duration: time.Since(startTime), + }, nil + } + continue + } + + if status.succeeded || status.failed { + finalStatus = status + break + } + } + + duration := time.Since(startTime) + + exitCode := 0 + if finalStatus != nil && finalStatus.failed { + exitCode = 1 + } + + podName, err := client.findJobPod(execCtx, s.config.KubeNamespace, resourceName) + if err != nil { + return &ExecuteResult{ + ExitCode: exitCode, + Duration: duration, + Error: fmt.Sprintf("failed to find pod: %v", err), + }, nil + } + + logs, err := client.getPodLogs(execCtx, s.config.KubeNamespace, podName, s.config.MaxLogSize) + if err != nil { + return &ExecuteResult{ + ExitCode: exitCode, + Duration: duration, + Error: fmt.Sprintf("failed to get pod logs: %v", err), + }, nil + } + + return &ExecuteResult{ + Stdout: logs, + ExitCode: exitCode, + Duration: duration, + }, nil +} + +// Cleanup stops the background GC goroutine and releases sandbox resources. +// Safe to call multiple times. +func (s *KubernetesSandbox) Cleanup(ctx context.Context) error { + s.stopOnce.Do(func() { close(s.stopCh) }) + return nil +} + +// gcOrphanResources periodically cleans up orphaned ConfigMaps left behind by failed executions. +// It uses the weknora-sandbox-created-at label to determine age and only deletes resources +// older than gcOrphanAge. Normal cleanup is handled by defer in Execute; this catches leaks. +func (s *KubernetesSandbox) gcOrphanResources() { + ticker := time.NewTicker(gcTickerInterval) + defer ticker.Stop() + + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + } + + client := s.getClient() + if client == nil { + continue + } + + ctx, cancel := context.WithTimeout(context.Background(), gcListTimeout) + labelSelector := sandboxLabelKey + "=" + sandboxLabelValue + cms, err := client.listConfigMapsWithLabels(ctx, s.config.KubeNamespace, labelSelector) + cancel() + + if err != nil { + continue + } + + now := time.Now().Unix() + for _, cm := range cms { + createdAtStr, ok := cm.labels[sandboxCreatedAtKey] + if !ok { + continue + } + createdAt, err := strconv.ParseInt(createdAtStr, 10, 64) + if err != nil { + continue + } + if now-createdAt < int64(gcOrphanAge.Seconds()) { + continue + } + + cleanCtx, cleanCancel := context.WithTimeout(context.Background(), gcDeleteTimeout) + _ = client.deleteConfigMap(cleanCtx, s.config.KubeNamespace, cm.name) + cleanCancel() + } + } +} diff --git a/internal/sandbox/kubernetes_test.go b/internal/sandbox/kubernetes_test.go new file mode 100644 index 000000000..c02cc6ea0 --- /dev/null +++ b/internal/sandbox/kubernetes_test.go @@ -0,0 +1,439 @@ +package sandbox + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestKubernetesSandboxType(t *testing.T) { + s := NewKubernetesSandbox(nil, nil) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + if s.Type() != SandboxTypeKubernetes { + t.Errorf("expected Type() = %q, got %q", SandboxTypeKubernetes, s.Type()) + } +} + +func TestKubernetesSandboxScriptSizeValidation(t *testing.T) { + cfg := DefaultConfig() + cfg.MaxScriptSize = 100 // very small limit + + s := NewKubernetesSandbox(cfg, nil) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + + // Create a script larger than MaxScriptSize + largeContent := strings.Repeat("x", 200) + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "big.py") + if err := os.WriteFile(scriptPath, []byte(largeContent), 0644); err != nil { + t.Fatalf("failed to write temp script: %v", err) + } + + ctx := context.Background() + _, err := s.Execute(ctx, &ExecuteConfig{ + Script: scriptPath, + }) + + if err == nil { + t.Fatal("expected error for oversized script, got nil") + } + if !strings.Contains(err.Error(), "exceeds limit") { + t.Errorf("expected 'exceeds limit' in error message, got: %v", err) + } +} + +func TestKubernetesSandboxExecute(t *testing.T) { + const mockLogOutput = "hello from kubernetes sandbox\n" + const jobName = "test-sandbox-job" + + // Track which paths were called + type call struct { + method string + path string + } + var calls []call + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, call{r.Method, r.URL.Path}) + + switch { + // POST configmaps + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/configmaps"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + + // POST jobs + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/jobs"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + + // GET job status + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/jobs/"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": {"succeeded": 1, "failed": 0, "active": 0}}`)) + + // GET pods (find job pod via label selector) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/pods") && !strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "items": []map[string]any{ + {"metadata": map[string]any{"name": "sandbox-pod-abc"}}, + }, + } + data, _ := json.Marshal(resp) + w.Write(data) + + // GET pod logs + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(mockLogOutput)) + + // DELETE jobs + case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/jobs/"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + + // DELETE configmaps + case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/configmaps/"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + + // Access check (list configmaps with limit=1) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/configmaps") && strings.Contains(r.URL.RawQuery, "limit=1"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"items": []}`)) + + default: + t.Logf("unhandled request: %s %s", r.Method, r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client := newKubeClient(server.URL, "test-token", server.Client()) + + cfg := DefaultConfig() + cfg.KubeNamespace = "test-ns" + cfg.MaxConcurrentSandboxes = 2 + + s := NewKubernetesSandbox(cfg, client) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + + // Create a temp script file + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "test.py") + scriptContent := `print("hello from kubernetes sandbox")` + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { + t.Fatalf("failed to write temp script: %v", err) + } + + ctx := context.Background() + result, err := s.Execute(ctx, &ExecuteConfig{ + Script: scriptPath, + }) + + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + if result == nil { + t.Fatal("Execute returned nil result") + } + if result.Stdout != mockLogOutput { + t.Errorf("expected Stdout %q, got %q", mockLogOutput, result.Stdout) + } + if result.ExitCode != 0 { + t.Errorf("expected ExitCode=0, got %d", result.ExitCode) + } + + // Verify at minimum: ConfigMap was created, Job was created, Job was polled, logs were fetched + foundCMCreate := false + foundJobCreate := false + foundLogFetch := false + for _, c := range calls { + if c.method == http.MethodPost && strings.Contains(c.path, "/configmaps") { + foundCMCreate = true + } + if c.method == http.MethodPost && strings.Contains(c.path, "/jobs") { + foundJobCreate = true + } + if c.method == http.MethodGet && strings.Contains(c.path, "/log") { + foundLogFetch = true + } + } + if !foundCMCreate { + t.Error("expected ConfigMap creation call") + } + if !foundJobCreate { + t.Error("expected Job creation call") + } + if !foundLogFetch { + t.Error("expected pod log fetch call") + } +} + +func TestKubernetesSandboxExecuteWithScriptContent(t *testing.T) { + const mockLogOutput = "content-based execution\n" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/configmaps"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/jobs"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/jobs/"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": {"succeeded": 1, "failed": 0, "active": 0}}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/pods") && !strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "items": []map[string]any{ + {"metadata": map[string]any{"name": "pod-xyz"}}, + }, + } + data, _ := json.Marshal(resp) + w.Write(data) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(mockLogOutput)) + case r.Method == http.MethodDelete: + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client := newKubeClient(server.URL, "token", server.Client()) + s := NewKubernetesSandbox(nil, client) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + + ctx := context.Background() + result, err := s.Execute(ctx, &ExecuteConfig{ + Script: "myscript.sh", + ScriptContent: `echo "content-based execution"`, + }) + + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + if result.Stdout != mockLogOutput { + t.Errorf("expected Stdout %q, got %q", mockLogOutput, result.Stdout) + } +} + +func TestKubernetesSandboxExecuteJobFailed(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/configmaps"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/jobs"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/jobs/"): + // Job failed + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": {"succeeded": 0, "failed": 1, "active": 0}}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/pods") && !strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "items": []map[string]any{ + {"metadata": map[string]any{"name": "failed-pod"}}, + }, + } + data, _ := json.Marshal(resp) + w.Write(data) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + w.Write([]byte("error output\n")) + case r.Method == http.MethodDelete: + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client := newKubeClient(server.URL, "token", server.Client()) + s := NewKubernetesSandbox(nil, client) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + + ctx := context.Background() + result, err := s.Execute(ctx, &ExecuteConfig{ + Script: "script.py", + ScriptContent: `raise Exception("fail")`, + }) + + if err != nil { + t.Fatalf("Execute should not return Go error for failed jobs, got: %v", err) + } + if result.ExitCode != 1 { + t.Errorf("expected ExitCode=1 for failed job, got %d", result.ExitCode) + } +} + +func TestKubernetesSandboxExecuteWithArgsAndStdin(t *testing.T) { + var capturedJobBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/configmaps"): + // Verify stdin file is included in ConfigMap data + body, _ := io.ReadAll(r.Body) + var cm map[string]any + json.Unmarshal(body, &cm) + data, _ := cm["data"].(map[string]any) + if _, ok := data[".stdin"]; !ok { + t.Error("expected .stdin key in ConfigMap data") + } + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/jobs"): + capturedJobBody, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/jobs/"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": {"succeeded": 1, "failed": 0, "active": 0}}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/pods") && !strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + resp := map[string]any{"items": []map[string]any{{"metadata": map[string]any{"name": "pod-1"}}}} + data, _ := json.Marshal(resp) + w.Write(data) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + w.Write([]byte("output\n")) + case r.Method == http.MethodDelete: + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client := newKubeClient(server.URL, "token", server.Client()) + s := NewKubernetesSandbox(nil, client) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + + _, err := s.Execute(context.Background(), &ExecuteConfig{ + Script: "script.py", + ScriptContent: `import sys; print(sys.stdin.read())`, + Args: []string{"--format", "json"}, + Stdin: `{"key": "value"}`, + }) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + // Verify the Job command uses sh -c with stdin piping + var job map[string]any + json.Unmarshal(capturedJobBody, &job) + spec := job["spec"].(map[string]any)["template"].(map[string]any)["spec"].(map[string]any) + containers := spec["containers"].([]any) + container := containers[0].(map[string]any) + command := container["command"].([]any) + + if len(command) != 3 || command[0] != "sh" || command[1] != "-c" { + t.Fatalf("expected sh -c command for stdin mode, got: %v", command) + } + cmdStr := command[2].(string) + if !strings.Contains(cmdStr, ".stdin") { + t.Errorf("expected .stdin in command, got: %s", cmdStr) + } + if !strings.Contains(cmdStr, "--format") { + t.Errorf("expected --format arg in command, got: %s", cmdStr) + } +} + +func TestKubernetesSandboxExecuteExecFormWithoutStdin(t *testing.T) { + var capturedJobBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/configmaps"): + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/jobs"): + capturedJobBody, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/jobs/"): + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": {"succeeded": 1, "failed": 0, "active": 0}}`)) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/pods") && !strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + resp := map[string]any{"items": []map[string]any{{"metadata": map[string]any{"name": "pod-1"}}}} + data, _ := json.Marshal(resp) + w.Write(data) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/log"): + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok\n")) + case r.Method == http.MethodDelete: + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client := newKubeClient(server.URL, "token", server.Client()) + s := NewKubernetesSandbox(nil, client) + t.Cleanup(func() { s.Cleanup(context.Background()) }) + + _, err := s.Execute(context.Background(), &ExecuteConfig{ + Script: "script.py", + ScriptContent: `print("ok")`, + Args: []string{"--verbose"}, + }) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + // Verify exec-form (no shell wrapping) when there's no stdin + var job map[string]any + json.Unmarshal(capturedJobBody, &job) + spec := job["spec"].(map[string]any)["template"].(map[string]any)["spec"].(map[string]any) + containers := spec["containers"].([]any) + container := containers[0].(map[string]any) + command := container["command"].([]any) + + // Should be exec-form: ["python3", "/workspace/script.py", "--verbose"] + if len(command) < 3 { + t.Fatalf("expected at least 3 elements in exec-form command, got: %v", command) + } + if command[0] == "sh" { + t.Errorf("expected exec-form (no shell) without stdin, got sh -c: %v", command) + } + if command[0] != "python3" { + t.Errorf("expected python3 interpreter, got: %v", command[0]) + } + found := false + for _, c := range command { + if c == "--verbose" { + found = true + } + } + if !found { + t.Errorf("expected --verbose in command args, got: %v", command) + } +} + +func TestKubernetesSandboxCleanup(t *testing.T) { + s := NewKubernetesSandbox(nil, nil) + if err := s.Cleanup(context.Background()); err != nil { + t.Errorf("Cleanup should return nil, got: %v", err) + } +} diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index afd330e68..89387ff2c 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -74,6 +74,30 @@ func (m *DefaultManager) initializeSandbox(ctx context.Context) error { m.sandbox = NewLocalSandbox(m.config) return nil + case SandboxTypeKubernetes: + kubeSandbox := NewKubernetesSandbox(m.config, nil) + if kubeSandbox.IsAvailable(ctx) { + m.sandbox = kubeSandbox + return nil + } + if m.config.FallbackEnabled { + m.sandbox = NewLocalSandbox(m.config) + return nil + } + return fmt.Errorf("kubernetes is not available and fallback is disabled") + + case SandboxTypeOpenSandbox: + osSandbox := NewOpenSandboxSandbox(m.config) + if osSandbox.IsAvailable(ctx) { + m.sandbox = osSandbox + return nil + } + if m.config.FallbackEnabled { + m.sandbox = NewLocalSandbox(m.config) + return nil + } + return fmt.Errorf("opensandbox is not available and fallback is disabled") + default: return fmt.Errorf("unknown sandbox type: %s", m.config.Type) } @@ -233,6 +257,10 @@ func NewManagerFromType(sandboxType string, fallbackEnabled bool, dockerImage st sType = SandboxTypeLocal case "disabled", "": sType = SandboxTypeDisabled + case "kubernetes": + sType = SandboxTypeKubernetes + case "opensandbox": + sType = SandboxTypeOpenSandbox default: return nil, fmt.Errorf("unknown sandbox type: %s", sandboxType) } diff --git a/internal/sandbox/opensandbox.go b/internal/sandbox/opensandbox.go new file mode 100644 index 000000000..f7c0614ca --- /dev/null +++ b/internal/sandbox/opensandbox.go @@ -0,0 +1,150 @@ +package sandbox + +import ( + "context" + "fmt" + "time" +) + +// OpenSandboxSandbox implements the Sandbox interface using the external OpenSandbox REST API. +type OpenSandboxSandbox struct { + config *Config + client *openSandboxClient + semaphore chan struct{} +} + +// NewOpenSandboxSandbox creates a new OpenSandbox-based sandbox. +func NewOpenSandboxSandbox(config *Config) *OpenSandboxSandbox { + if config == nil { + config = DefaultConfig() + } + + var client *openSandboxClient + if config.OpenSandboxAPIURL != "" { + client = newOpenSandboxClient(config.OpenSandboxAPIURL, config.OpenSandboxAPIKey) + } + + return &OpenSandboxSandbox{ + config: config, + client: client, + semaphore: make(chan struct{}, config.MaxConcurrentSandboxes), + } +} + +// Type returns the sandbox type. +func (s *OpenSandboxSandbox) Type() SandboxType { + return SandboxTypeOpenSandbox +} + +// IsAvailable checks if the OpenSandbox API is reachable. +func (s *OpenSandboxSandbox) IsAvailable(ctx context.Context) bool { + if s.client == nil { + return false + } + return s.client.healthCheck(ctx) +} + +// Execute runs a script via the OpenSandbox REST API. +func (s *OpenSandboxSandbox) Execute(ctx context.Context, config *ExecuteConfig) (*ExecuteResult, error) { + if config == nil { + return nil, ErrInvalidScript + } + + scriptContent, scriptName, err := resolveScript(config) + if err != nil { + return nil, fmt.Errorf("opensandbox: %w", err) + } + + if s.config.MaxScriptSize > 0 && int64(len(scriptContent)) > s.config.MaxScriptSize { + return nil, fmt.Errorf("opensandbox: script size %d exceeds limit %d", len(scriptContent), s.config.MaxScriptSize) + } + + select { + case s.semaphore <- struct{}{}: + case <-ctx.Done(): + return nil, ctx.Err() + } + defer func() { <-s.semaphore }() + + timeout := config.Timeout + if timeout == 0 { + timeout = s.config.DefaultTimeout + } + if timeout == 0 { + timeout = DefaultTimeout + } + + execCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Create sandbox + sbResp, err := s.client.createSandbox(execCtx, s.config.DockerImage, timeout) + if err != nil { + return nil, fmt.Errorf("opensandbox: create sandbox: %w", err) + } + + // Defer sandbox cleanup with a fresh timeout context + defer func() { + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), cleanupTimeout) + defer cleanupCancel() + _ = s.client.deleteSandbox(cleanupCtx, sbResp.ID) + }() + + // Get execd URL via server proxy endpoint + execdURL, err := s.client.getExecdURL(execCtx, sbResp.ID) + if err != nil { + return nil, fmt.Errorf("opensandbox: %w", err) + } + + // Wait for execd to become ready + if err := s.client.waitForExecd(execCtx, execdURL); err != nil { + return nil, fmt.Errorf("opensandbox: %w", err) + } + + // Upload script + if err := s.client.uploadFile(execCtx, execdURL, scriptName, scriptContent); err != nil { + return nil, fmt.Errorf("opensandbox: upload script: %w", err) + } + + // Upload stdin content as a file if provided + if config.Stdin != "" { + if err := s.client.uploadFile(execCtx, execdURL, stdinFileName, config.Stdin); err != nil { + return nil, fmt.Errorf("opensandbox: upload stdin: %w", err) + } + } + + // Build and execute command + interpreter := getInterpreter(scriptName) + command := buildShellCommand(interpreter, "/workspace/"+scriptName, config.Args, config.Stdin != "") + + startTime := time.Now() + execResp, err := s.client.executeCommand(execCtx, execdURL, command) + duration := time.Since(startTime) + if err != nil { + return nil, fmt.Errorf("opensandbox: execute command: %w", err) + } + + // Truncate stdout/stderr to MaxLogSize + stdout := execResp.Stdout + stderr := execResp.Stderr + if s.config.MaxLogSize > 0 { + if int64(len(stdout)) > s.config.MaxLogSize { + stdout = stdout[:s.config.MaxLogSize] + } + if int64(len(stderr)) > s.config.MaxLogSize { + stderr = stderr[:s.config.MaxLogSize] + } + } + + return &ExecuteResult{ + Stdout: stdout, + Stderr: stderr, + ExitCode: execResp.ExitCode, + Duration: duration, + }, nil +} + +// Cleanup releases sandbox resources. OpenSandbox sandboxes are cleaned up per-execution. +func (s *OpenSandboxSandbox) Cleanup(ctx context.Context) error { + return nil +} diff --git a/internal/sandbox/opensandbox_client.go b/internal/sandbox/opensandbox_client.go new file mode 100644 index 000000000..2f8e9d3e3 --- /dev/null +++ b/internal/sandbox/opensandbox_client.go @@ -0,0 +1,330 @@ +package sandbox + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "regexp" + "strings" + "time" + +) + +const ( + maxErrorBodyBytes = 4096 // cap on error response body reads + execdPort = 44772 +) + +var validSandboxIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +type openSandboxClient struct { + apiURL string + apiKey string + httpClient *http.Client +} + +// openSandboxCreateResponse is the response from POST /v1/sandboxes. +type openSandboxCreateResponse struct { + ID string `json:"id"` +} + +// openSandboxExecResponse aggregates stdout/stderr/exitCode from the execd streaming response. +type openSandboxExecResponse struct { + Stdout string + Stderr string + ExitCode int +} + +// openSandboxEndpointResponse is the response from GET /v1/sandboxes/{id}/endpoints/{port}. +type openSandboxEndpointResponse struct { + Endpoint string `json:"endpoint"` +} + +func newOpenSandboxClient(apiURL, apiKey string) *openSandboxClient { + if !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "http://") { + apiURL = "https://" + apiURL + } + return &openSandboxClient{ + apiURL: apiURL, + apiKey: apiKey, + httpClient: &http.Client{Timeout: 30 * time.Second}, + } +} + +// readErrorBody reads up to maxErrorBodyBytes from a response body for error messages. +func readErrorBody(body io.Reader) string { + data, _ := io.ReadAll(io.LimitReader(body, maxErrorBodyBytes)) + return string(data) +} + +// createSandbox creates a new sandbox with the given image. +// POST {apiURL}/v1/sandboxes +func (c *openSandboxClient) createSandbox(ctx context.Context, image string, timeout time.Duration) (*openSandboxCreateResponse, error) { + reqBody := map[string]any{ + "image": map[string]string{"uri": image}, + "resourceLimits": map[string]string{"cpu": "500m", "memory": "256Mi"}, + "entrypoint": []string{"sleep", fmt.Sprintf("%d", int(timeout.Seconds())+60)}, + } + if timeout > 0 { + reqBody["timeout"] = int(timeout.Seconds()) + 60 // sandbox lives slightly longer than execution timeout + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("opensandbox: marshal create request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL+"/v1/sandboxes", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("opensandbox: build create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + if c.apiKey != "" { + req.Header.Set("OPEN-SANDBOX-API-KEY", c.apiKey) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("opensandbox: create sandbox: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusAccepted { + return nil, fmt.Errorf("opensandbox: create sandbox: unexpected status %d: %s", resp.StatusCode, readErrorBody(resp.Body)) + } + + var result openSandboxCreateResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("opensandbox: decode create response: %w", err) + } + if !validSandboxIDPattern.MatchString(result.ID) { + return nil, fmt.Errorf("opensandbox: invalid sandbox ID returned by API: %q", result.ID) + } + return &result, nil +} + +// getExecdURL retrieves the execd endpoint for a sandbox and returns a full HTTP URL. +// GET {apiURL}/v1/sandboxes/{id}/endpoints/{port} +func (c *openSandboxClient) getExecdURL(ctx context.Context, sandboxID string) (string, error) { + url := fmt.Sprintf("%s/v1/sandboxes/%s/endpoints/%d", c.apiURL, sandboxID, execdPort) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("opensandbox: build endpoint request: %w", err) + } + if c.apiKey != "" { + req.Header.Set("OPEN-SANDBOX-API-KEY", c.apiKey) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("opensandbox: get endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("opensandbox: get endpoint: unexpected status %d: %s", resp.StatusCode, readErrorBody(resp.Body)) + } + + var result openSandboxEndpointResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("opensandbox: decode endpoint response: %w", err) + } + + ep := result.Endpoint + if !strings.HasPrefix(ep, "http://") && !strings.HasPrefix(ep, "https://") { + ep = "http://" + ep + } + return ep, nil +} + +// waitForExecd polls the execd /ping endpoint until it responds 200 or context expires. +func (c *openSandboxClient) waitForExecd(ctx context.Context, execdURL string) error { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("opensandbox: execd not ready: %w", ctx.Err()) + case <-ticker.C: + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, execdURL+"/ping", nil) + if err != nil { + continue + } + resp, err := c.httpClient.Do(req) + if err != nil { + continue + } + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + } +} + +// uploadFile uploads a file to the sandbox via multipart form. +// POST {execdURL}/files/upload with "metadata" file (JSON {"path": "..."}) + "file" file. +func (c *openSandboxClient) uploadFile(ctx context.Context, execdURL, filename, content string) error { + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + + // metadata must be uploaded as a file part, not a text field + metaPart, err := writer.CreateFormFile("metadata", "metadata.json") + if err != nil { + return fmt.Errorf("opensandbox: create metadata part: %w", err) + } + metaJSON, _ := json.Marshal(map[string]string{"path": "/workspace/" + filename}) + if _, err := metaPart.Write(metaJSON); err != nil { + return fmt.Errorf("opensandbox: write metadata: %w", err) + } + + filePart, err := writer.CreateFormFile("file", filename) + if err != nil { + return fmt.Errorf("opensandbox: create form file: %w", err) + } + if _, err := io.WriteString(filePart, content); err != nil { + return fmt.Errorf("opensandbox: write file content: %w", err) + } + writer.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, execdURL+"/files/upload", &buf) + if err != nil { + return fmt.Errorf("opensandbox: build upload request: %w", err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("opensandbox: upload file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("opensandbox: upload file: unexpected status %d: %s", resp.StatusCode, readErrorBody(resp.Body)) + } + return nil +} + +// executeCommand runs a command inside the sandbox. +// POST {execdURL}/command with body {"command": command}. +// The response is streamed as JSONL events; this method collects stdout/stderr and exit code. +func (c *openSandboxClient) executeCommand(ctx context.Context, execdURL, command string) (*openSandboxExecResponse, error) { + body, err := json.Marshal(map[string]string{"command": command}) + if err != nil { + return nil, fmt.Errorf("opensandbox: marshal exec request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, execdURL+"/command", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("opensandbox: build exec request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("opensandbox: execute command: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("opensandbox: execute command: unexpected status %d: %s", resp.StatusCode, readErrorBody(resp.Body)) + } + + return parseExecdStream(resp.Body) +} + +// execdEvent represents a single event in the execd streaming response. +type execdEvent struct { + Type string `json:"type"` + Text string `json:"text"` + Error *struct { + EValue string `json:"evalue"` + } `json:"error,omitempty"` +} + +// parseExecdStream reads the JSONL streaming response from execd and aggregates results. +func parseExecdStream(r io.Reader) (*openSandboxExecResponse, error) { + var stdout, stderr strings.Builder + exitCode := 0 + + scanner := bufio.NewScanner(io.LimitReader(r, maxResponseBytes)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var event execdEvent + if err := json.Unmarshal([]byte(line), &event); err != nil { + continue // skip unparseable lines + } + + switch event.Type { + case "stdout": + stdout.WriteString(event.Text) + case "stderr": + stderr.WriteString(event.Text) + case "error": + if event.Error != nil { + // evalue contains the exit code as string + fmt.Sscanf(event.Error.EValue, "%d", &exitCode) + } + if exitCode == 0 { + exitCode = 1 // error event with no parseable exit code + } + } + } + + return &openSandboxExecResponse{ + Stdout: stdout.String(), + Stderr: stderr.String(), + ExitCode: exitCode, + }, nil +} + +// deleteSandbox deletes a sandbox by ID. +// DELETE {apiURL}/v1/sandboxes/{id} +func (c *openSandboxClient) deleteSandbox(ctx context.Context, sandboxID string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fmt.Sprintf("%s/v1/sandboxes/%s", c.apiURL, sandboxID), nil) + if err != nil { + return fmt.Errorf("opensandbox: build delete request: %w", err) + } + if c.apiKey != "" { + req.Header.Set("OPEN-SANDBOX-API-KEY", c.apiKey) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("opensandbox: delete sandbox: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("opensandbox: delete sandbox: unexpected status %d: %s", resp.StatusCode, readErrorBody(resp.Body)) + } + return nil +} + +// healthCheck returns true if the OpenSandbox lifecycle API is reachable and healthy. +// GET {apiURL}/health +func (c *openSandboxClient) healthCheck(ctx context.Context) bool { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL+"/health", nil) + if err != nil { + return false + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} diff --git a/internal/sandbox/opensandbox_client_test.go b/internal/sandbox/opensandbox_client_test.go new file mode 100644 index 000000000..0a3a00d37 --- /dev/null +++ b/internal/sandbox/opensandbox_client_test.go @@ -0,0 +1,283 @@ +package sandbox + +import ( + "context" + "encoding/json" + "io" + "mime" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestOpenSandboxClientCreateSandbox(t *testing.T) { + const wantAPIKey = "test-api-key" + const fakeSandboxID = "sb-abc123" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/sandboxes" { + t.Errorf("unexpected path: got %s, want /v1/sandboxes", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("unexpected method: got %s, want POST", r.Method) + } + gotKey := r.Header.Get("OPEN-SANDBOX-API-KEY") + if gotKey != wantAPIKey { + t.Errorf("unexpected API key header: got %q, want %q", gotKey, wantAPIKey) + } + + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode request body: %v", err) + } + // image should be {"uri": "..."} + img, ok := body["image"].(map[string]any) + if !ok { + t.Errorf("expected image to be object, got %T", body["image"]) + } else if img["uri"] != "python:3.11" { + t.Errorf("unexpected image uri: got %v", img["uri"]) + } + // entrypoint should be an array + if _, ok := body["entrypoint"]; !ok { + t.Error("expected entrypoint in request body") + } + // resourceLimits should be present + if _, ok := body["resourceLimits"]; !ok { + t.Error("expected resourceLimits in request body") + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(openSandboxCreateResponse{ID: fakeSandboxID}) + })) + defer server.Close() + + client := newOpenSandboxClient(server.URL, wantAPIKey) + resp, err := client.createSandbox(t.Context(), "python:3.11", 60*time.Second) + if err != nil { + t.Fatalf("createSandbox returned error: %v", err) + } + if resp.ID != fakeSandboxID { + t.Errorf("ID: got %q, want %q", resp.ID, fakeSandboxID) + } +} + +func TestOpenSandboxClientGetExecdURL(t *testing.T) { + const fakeSandboxID = "sb-abc123" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("unexpected method: %s", r.Method) + } + if !strings.Contains(r.URL.Path, "/endpoints/") { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(openSandboxEndpointResponse{ + Endpoint: "localhost:8080/sandboxes/" + fakeSandboxID + "/proxy/44772", + }) + })) + defer server.Close() + + client := newOpenSandboxClient(server.URL, "key") + url, err := client.getExecdURL(t.Context(), fakeSandboxID) + if err != nil { + t.Fatalf("getExecdURL returned error: %v", err) + } + if !strings.HasPrefix(url, "http://") { + t.Errorf("expected http:// prefix, got: %s", url) + } + if !strings.Contains(url, fakeSandboxID) { + t.Errorf("expected sandbox ID in URL, got: %s", url) + } +} + +func TestOpenSandboxClientUploadFile(t *testing.T) { + const wantFilename = "script.py" + const wantContent = "print('hello')" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/files/upload" { + t.Errorf("unexpected path: got %s, want /files/upload", r.URL.Path) + } + + mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil || !strings.HasPrefix(mediaType, "multipart/") { + t.Errorf("expected multipart content-type, got %q", r.Header.Get("Content-Type")) + w.WriteHeader(http.StatusBadRequest) + return + } + + mr := multipart.NewReader(r.Body, params["boundary"]) + gotMetadata := false + gotFile := false + for { + part, err := mr.NextPart() + if err == io.EOF { + break + } + if err != nil { + t.Errorf("read multipart: %v", err) + break + } + data, _ := io.ReadAll(part) + switch part.FormName() { + case "metadata": + gotMetadata = true + var meta map[string]string + if err := json.Unmarshal(data, &meta); err != nil { + t.Errorf("metadata not valid JSON: %v", err) + } else if !strings.Contains(meta["path"], wantFilename) { + t.Errorf("metadata path: got %q, want to contain %q", meta["path"], wantFilename) + } + case "file": + gotFile = true + if string(data) != wantContent { + t.Errorf("file content: got %q, want %q", string(data), wantContent) + } + } + } + if !gotMetadata { + t.Error("missing metadata part") + } + if !gotFile { + t.Error("missing file part") + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newOpenSandboxClient("http://unused", "key") + err := client.uploadFile(t.Context(), server.URL, wantFilename, wantContent) + if err != nil { + t.Fatalf("uploadFile returned error: %v", err) + } +} + +func TestOpenSandboxClientExecuteCommand(t *testing.T) { + const wantCommand = "python3 /workspace/script.py" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/command" { + t.Errorf("unexpected path: got %s, want /command", r.URL.Path) + } + + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode request body: %v", err) + } + if body["command"] != wantCommand { + t.Errorf("command: got %q, want %q", body["command"], wantCommand) + } + + // Return streaming JSONL response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"type":"init","text":"abc123","timestamp":1000}` + "\n\n")) + w.Write([]byte(`{"type":"stdout","text":"hello world\n","timestamp":1001}` + "\n\n")) + w.Write([]byte(`{"type":"execution_complete","execution_time":5,"timestamp":1002}` + "\n\n")) + })) + defer server.Close() + + client := newOpenSandboxClient("http://unused", "key") + resp, err := client.executeCommand(t.Context(), server.URL, wantCommand) + if err != nil { + t.Fatalf("executeCommand returned error: %v", err) + } + if resp.Stdout != "hello world\n" { + t.Errorf("Stdout: got %q, want %q", resp.Stdout, "hello world\n") + } + if resp.ExitCode != 0 { + t.Errorf("ExitCode: got %d, want 0", resp.ExitCode) + } +} + +func TestOpenSandboxClientExecuteCommandError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"type":"stderr","text":"error msg","timestamp":1000}` + "\n")) + w.Write([]byte(`{"type":"error","timestamp":1001,"error":{"ename":"CommandExecError","evalue":"2","traceback":["exit status 2"]}}` + "\n")) + })) + defer server.Close() + + client := newOpenSandboxClient("http://unused", "key") + resp, err := client.executeCommand(t.Context(), server.URL, "bad-cmd") + if err != nil { + t.Fatalf("executeCommand returned error: %v", err) + } + if resp.Stderr != "error msg" { + t.Errorf("Stderr: got %q, want %q", resp.Stderr, "error msg") + } + if resp.ExitCode != 2 { + t.Errorf("ExitCode: got %d, want 2", resp.ExitCode) + } +} + +func TestOpenSandboxClientDeleteSandbox(t *testing.T) { + const wantID = "sb-xyz789" + const wantAPIKey = "delete-key" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wantPath := "/v1/sandboxes/" + wantID + if r.URL.Path != wantPath { + t.Errorf("unexpected path: got %s, want %s", r.URL.Path, wantPath) + } + if r.Method != http.MethodDelete { + t.Errorf("unexpected method: got %s, want DELETE", r.Method) + } + gotKey := r.Header.Get("OPEN-SANDBOX-API-KEY") + if gotKey != wantAPIKey { + t.Errorf("API key header: got %q, want %q", gotKey, wantAPIKey) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := newOpenSandboxClient(server.URL, wantAPIKey) + err := client.deleteSandbox(t.Context(), wantID) + if err != nil { + t.Fatalf("deleteSandbox returned error: %v", err) + } +} + +func TestOpenSandboxClientHealthCheck(t *testing.T) { + t.Run("healthy", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/health" { + t.Errorf("unexpected path: got %s, want /health", r.URL.Path) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"healthy"}`)) + })) + defer server.Close() + + client := newOpenSandboxClient(server.URL, "key") + if !client.healthCheck(t.Context()) { + t.Error("healthCheck returned false, want true") + } + }) + + t.Run("unhealthy", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + client := newOpenSandboxClient(server.URL, "key") + if client.healthCheck(t.Context()) { + t.Error("healthCheck returned true, want false") + } + }) + + t.Run("unreachable", func(t *testing.T) { + client := newOpenSandboxClient("http://127.0.0.1:1", "key") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if client.healthCheck(ctx) { + t.Error("healthCheck returned true for unreachable server") + } + }) +} diff --git a/internal/sandbox/opensandbox_test.go b/internal/sandbox/opensandbox_test.go new file mode 100644 index 000000000..e51743801 --- /dev/null +++ b/internal/sandbox/opensandbox_test.go @@ -0,0 +1,114 @@ +package sandbox + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestOpenSandboxSandboxType(t *testing.T) { + s := NewOpenSandboxSandbox(nil) + if got := s.Type(); got != SandboxTypeOpenSandbox { + t.Errorf("Type() = %q, want %q", got, SandboxTypeOpenSandbox) + } +} + +func TestOpenSandboxSandboxExecute(t *testing.T) { + t.Setenv("SSRF_WHITELIST", "127.0.0.1") + + const fakeSandboxID = "sb-test-001" + const wantStdout = "hello from opensandbox\n" + const wantExitCode = 0 + + // Server 2: execd server (handles file upload, command execution, and ping) + execdServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/ping" && r.Method == http.MethodGet: + w.WriteHeader(http.StatusOK) + + case r.URL.Path == "/files/upload" && r.Method == http.MethodPost: + w.WriteHeader(http.StatusOK) + + case r.URL.Path == "/command" && r.Method == http.MethodPost: + var body map[string]string + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("execd: decode command body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + cmd := body["command"] + if !strings.Contains(cmd, "/workspace/") { + t.Errorf("execd: command %q does not reference /workspace/", cmd) + } + // Return streaming JSONL response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"type":"stdout","text":%q,"timestamp":1000}`+"\n\n", wantStdout) + fmt.Fprintf(w, `{"type":"execution_complete","execution_time":5,"timestamp":1001}`+"\n\n") + + default: + t.Errorf("execd: unexpected request: %s %s", r.Method, r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer execdServer.Close() + + // Server 1: lifecycle server (handles sandbox create/delete/endpoints) + lifecycleServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/v1/sandboxes" && r.Method == http.MethodPost: + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(openSandboxCreateResponse{ID: fakeSandboxID}) + + case strings.Contains(r.URL.Path, "/endpoints/") && r.Method == http.MethodGet: + w.Header().Set("Content-Type", "application/json") + // Return execd server URL as endpoint (strip http:// prefix since client adds it back) + ep := strings.TrimPrefix(execdServer.URL, "http://") + json.NewEncoder(w).Encode(openSandboxEndpointResponse{Endpoint: ep}) + + case strings.HasPrefix(r.URL.Path, "/v1/sandboxes/") && r.Method == http.MethodDelete: + gotID := strings.TrimPrefix(r.URL.Path, "/v1/sandboxes/") + if gotID != fakeSandboxID { + t.Errorf("lifecycle: delete called with id %q, want %q", gotID, fakeSandboxID) + } + w.WriteHeader(http.StatusNoContent) + + default: + t.Errorf("lifecycle: unexpected request: %s %s", r.Method, r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer lifecycleServer.Close() + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "test_script.py") + if err := os.WriteFile(scriptPath, []byte("print('hello from opensandbox')"), 0644); err != nil { + t.Fatalf("write temp script: %v", err) + } + + cfg := DefaultConfig() + cfg.OpenSandboxAPIURL = lifecycleServer.URL + cfg.OpenSandboxAPIKey = "test-key" + + sb := NewOpenSandboxSandbox(cfg) + + result, err := sb.Execute(t.Context(), &ExecuteConfig{ + Script: scriptPath, + }) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + if result.Stdout != wantStdout { + t.Errorf("Stdout: got %q, want %q", result.Stdout, wantStdout) + } + if result.ExitCode != wantExitCode { + t.Errorf("ExitCode: got %d, want %d", result.ExitCode, wantExitCode) + } +} diff --git a/internal/sandbox/sandbox.go b/internal/sandbox/sandbox.go index 29f22cf62..663ae6f7c 100644 --- a/internal/sandbox/sandbox.go +++ b/internal/sandbox/sandbox.go @@ -5,6 +5,10 @@ package sandbox import ( "context" "errors" + "fmt" + "os" + "path/filepath" + "strings" "time" ) @@ -18,6 +22,10 @@ const ( SandboxTypeLocal SandboxType = "local" // SandboxTypeDisabled means script execution is disabled SandboxTypeDisabled SandboxType = "disabled" + // SandboxTypeKubernetes uses K8s Jobs for isolation + SandboxTypeKubernetes SandboxType = "kubernetes" + // SandboxTypeOpenSandbox uses external OpenSandbox service + SandboxTypeOpenSandbox SandboxType = "opensandbox" ) // Default configuration values @@ -25,7 +33,12 @@ const ( DefaultTimeout = 60 * time.Second DefaultMemoryLimit = 256 * 1024 * 1024 // 256MB DefaultCPULimit = 1.0 // 1 CPU core - DefaultDockerImage = "wechatopenai/weknora-sandbox:latest" + DefaultDockerImage = "wechatopenai/weknora-sandbox:latest" + DefaultMaxConcurrentSandboxes = 5 + DefaultMaxScriptSize = 512 * 1024 // 512KB (ConfigMap has 1MB limit; leave room for metadata) + DefaultMaxLogSize = 1 * 1024 * 1024 // 1MB + DefaultKubeNamespace = "weknora-sandbox" + DefaultKubeServiceAccount = "weknora-sandbox-runner" ) // Common errors @@ -149,6 +162,59 @@ func (r *ExecuteResult) GetOutput() string { return r.Stderr } +const stdinFileName = ".stdin" + +// shellQuote wraps s in single quotes, escaping any embedded single quotes. +// Null bytes are stripped to prevent shell string truncation attacks. +func shellQuote(s string) string { + s = strings.ReplaceAll(s, "\x00", "") + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + +// buildShellCommand builds a shell command string from interpreter, script path, and args. +// All arguments are shell-quoted. If hasStdin is true, the command is wrapped with +// cat /workspace/.stdin | ... to pipe stdin content from a pre-uploaded file. +func buildShellCommand(interpreter, scriptPath string, args []string, hasStdin bool) string { + parts := []string{shellQuote(interpreter), shellQuote(scriptPath)} + for _, arg := range args { + parts = append(parts, shellQuote(arg)) + } + cmd := strings.Join(parts, " ") + if hasStdin { + cmd = "cat " + shellQuote("/workspace/"+stdinFileName) + " | " + cmd + } + return cmd +} + +// buildExecCommand builds an exec-form command (no shell) for use as a container entrypoint. +// Returns ["interpreter", "scriptPath", args...]. This form is preferred when no stdin +// piping is needed because it avoids shell interpretation entirely. +func buildExecCommand(interpreter, scriptPath string, args []string) []string { + cmd := []string{interpreter, scriptPath} + cmd = append(cmd, args...) + return cmd +} + +// resolveScript reads script content and name from an ExecuteConfig. +// It returns the script body, a filename suitable for use as an entry point, and any error. +func resolveScript(config *ExecuteConfig) (content string, name string, err error) { + if config.ScriptContent != "" { + name = "script.sh" + if config.Script != "" { + name = filepath.Base(config.Script) + } + return config.ScriptContent, name, nil + } + if config.Script != "" { + data, err := os.ReadFile(config.Script) + if err != nil { + return "", "", fmt.Errorf("%w: %v", ErrScriptNotFound, err) + } + return string(data), filepath.Base(config.Script), nil + } + return "", "", ErrInvalidScript +} + // Config holds sandbox manager configuration type Config struct { // Type is the preferred sandbox type @@ -174,18 +240,44 @@ type Config struct { // MaxCPU is the maximum CPU cores MaxCPU float64 + + // MaxConcurrentSandboxes limits parallel sandbox executions per instance (kubernetes/opensandbox) + MaxConcurrentSandboxes int + + // MaxScriptSize is the maximum script content size in bytes (kubernetes/opensandbox) + MaxScriptSize int64 + + // MaxLogSize is the maximum output size to read in bytes (kubernetes/opensandbox) + MaxLogSize int64 + + // KubeNamespace is the namespace for sandbox Jobs (kubernetes mode) + KubeNamespace string + + // KubeServiceAccount is the ServiceAccount name for sandbox Pods (kubernetes mode) + KubeServiceAccount string + + // OpenSandboxAPIURL is the OpenSandbox server URL (opensandbox mode) + OpenSandboxAPIURL string + + // OpenSandboxAPIKey is the API key for OpenSandbox (opensandbox mode) + OpenSandboxAPIKey string } // DefaultConfig returns a default sandbox configuration func DefaultConfig() *Config { return &Config{ - Type: SandboxTypeLocal, - FallbackEnabled: true, - DefaultTimeout: DefaultTimeout, - DockerImage: DefaultDockerImage, - AllowedCommands: defaultAllowedCommands(), - MaxMemory: DefaultMemoryLimit, - MaxCPU: DefaultCPULimit, + Type: SandboxTypeLocal, + FallbackEnabled: true, + DefaultTimeout: DefaultTimeout, + DockerImage: DefaultDockerImage, + AllowedCommands: defaultAllowedCommands(), + MaxMemory: DefaultMemoryLimit, + MaxCPU: DefaultCPULimit, + MaxConcurrentSandboxes: DefaultMaxConcurrentSandboxes, + MaxScriptSize: DefaultMaxScriptSize, + MaxLogSize: DefaultMaxLogSize, + KubeNamespace: DefaultKubeNamespace, + KubeServiceAccount: DefaultKubeServiceAccount, } } @@ -222,7 +314,7 @@ func ValidateConfig(config *Config) error { } switch config.Type { - case SandboxTypeDocker, SandboxTypeLocal, SandboxTypeDisabled: + case SandboxTypeDocker, SandboxTypeLocal, SandboxTypeDisabled, SandboxTypeKubernetes, SandboxTypeOpenSandbox: // Valid types default: return errors.New("invalid sandbox type") diff --git a/internal/sandbox/sandbox_test.go b/internal/sandbox/sandbox_test.go index dfc1383ce..349f51067 100644 --- a/internal/sandbox/sandbox_test.go +++ b/internal/sandbox/sandbox_test.go @@ -231,6 +231,60 @@ func TestExecuteResultHelpers(t *testing.T) { } } +func TestValidateConfigKubernetes(t *testing.T) { + config := &Config{ + Type: SandboxTypeKubernetes, + DefaultTimeout: 30 * time.Second, + } + if err := ValidateConfig(config); err != nil { + t.Errorf("Expected kubernetes type to be valid, got: %v", err) + } +} + +func TestValidateConfigOpenSandbox(t *testing.T) { + config := &Config{ + Type: SandboxTypeOpenSandbox, + DefaultTimeout: 30 * time.Second, + } + if err := ValidateConfig(config); err != nil { + t.Errorf("Expected opensandbox type to be valid, got: %v", err) + } +} + +func TestDefaultConfigSharedFields(t *testing.T) { + config := DefaultConfig() + if config.MaxConcurrentSandboxes != 5 { + t.Errorf("Expected MaxConcurrentSandboxes=5, got %d", config.MaxConcurrentSandboxes) + } + if config.MaxScriptSize != 512*1024 { + t.Errorf("Expected MaxScriptSize=524288, got %d", config.MaxScriptSize) + } + if config.MaxLogSize != 1048576 { + t.Errorf("Expected MaxLogSize=1048576, got %d", config.MaxLogSize) + } +} + +func TestNewManagerFromTypeKubernetes(t *testing.T) { + manager, err := NewManagerFromType("kubernetes", true, "sandbox:latest") + if err != nil { + t.Fatalf("Expected fallback to succeed, got: %v", err) + } + // Not in K8s, should fall back to local + if manager.GetType() != SandboxTypeLocal { + t.Logf("Manager type: %s (may be kubernetes if running in K8s)", manager.GetType()) + } +} + +func TestNewManagerFromTypeOpenSandbox(t *testing.T) { + manager, err := NewManagerFromType("opensandbox", true, "") + if err != nil { + t.Fatalf("Expected fallback to succeed, got: %v", err) + } + if manager.GetType() != SandboxTypeLocal { + t.Logf("Manager type: %s (may be opensandbox if server running)", manager.GetType()) + } +} + func TestPythonScriptExecution(t *testing.T) { // Create a temporary Python script tmpDir, err := os.MkdirTemp("", "sandbox-test")