From 309ef07cdd8f261ee6f9994f93c4816a94119c26 Mon Sep 17 00:00:00 2001 From: EduSantosBrito Date: Fri, 10 Apr 2026 15:31:12 +0200 Subject: [PATCH] fix(state): add guarded snapshot fallback for state repair --- cmd/sst/state.go | 73 +++++++++++++++-- cmd/sst/state_test.go | 34 ++++++++ pkg/project/provider/aws.go | 59 +++++++++----- pkg/project/provider/cloudflare.go | 100 ++++++++++++++++------- pkg/project/provider/cloudflare_test.go | 37 +++++++++ pkg/project/provider/local.go | 36 ++++++--- pkg/project/provider/provider.go | 52 ++++++++++++ pkg/project/provider/snapshot_test.go | 101 ++++++++++++++++++++++++ pkg/project/workdir.go | 21 ++++- pkg/project/workdir_test.go | 44 +++++++++++ 10 files changed, 494 insertions(+), 63 deletions(-) create mode 100644 cmd/sst/state_test.go create mode 100644 pkg/project/provider/cloudflare_test.go create mode 100644 pkg/project/provider/snapshot_test.go create mode 100644 pkg/project/workdir_test.go diff --git a/cmd/sst/state.go b/cmd/sst/state.go index e45909c03e..f880a92a74 100644 --- a/cmd/sst/state.go +++ b/cmd/sst/state.go @@ -2,11 +2,13 @@ package main import ( "encoding/json" + "errors" "fmt" "os" "strings" "time" + "github.com/pulumi/pulumi/sdk/v3/go/common/apitype" "github.com/sst/sst/v3/cmd/sst/cli" "github.com/sst/sst/v3/cmd/sst/mosaic/ui" "github.com/sst/sst/v3/internal/util" @@ -349,9 +351,22 @@ var CmdState = &cli.Command{ "sst state repair --stage production", "```", "", + "If the current state cannot be read, you can opt into restoring the latest valid", + "snapshot with `--dangerously-revert`. This is unsafe and can orphan or recreate", + "resources on future deploys.", + "", "By default, it runs on your personal stage.", }, "\n"), }, + Flags: []cli.Flag{ + { + Name: "dangerously-revert", + Type: "bool", + Description: cli.Description{ + Short: "Dangerously restore latest valid snapshot if current state cannot be read", + }, + }, + }, Run: func(c *cli.Cli) error { p, err := c.InitProject() if err != nil { @@ -374,18 +389,41 @@ var CmdState = &cli.Command{ } defer workdir.Cleanup() - _, err = workdir.Pull() - if err != nil { - return util.NewReadableError(err, "Could not pull state") + _, pullErr := workdir.Pull() + if pullErr != nil && !errors.Is(pullErr, provider.ErrStateNotFound) { + return util.NewReadableError(pullErr, "Could not pull state") } - checkpoint, err := workdir.Export() - if err != nil { - return util.NewReadableError(err, "Could not export state") + recoveredSnapshotID := "" + var checkpoint *apitype.CheckpointV3 + if pullErr == nil { + checkpoint, err = workdir.Export() + } + if pullErr != nil || err != nil { + if !c.Bool("dangerously-revert") { + return repairSnapshotFlagRequired(pullErr, err) + } + snapshot, snapshotID, recoverErr := provider.LatestValidSnapshot(p.Backend(), p.App().Name, p.App().Stage) + if recoverErr != nil { + return util.NewReadableError(recoverErr, "Could not recover state") + } + err = workdir.ImportRaw(snapshot) + if err != nil { + return util.NewReadableError(err, "Could not restore snapshot") + } + recoveredSnapshotID = snapshotID + checkpoint, err = workdir.Export() + if err != nil { + return util.NewReadableError(err, "Could not export state") + } + } + + if checkpoint == nil { + return util.NewReadableError(nil, "Could not export state") } muts := state.Repair(checkpoint) - err = confirmMutations(muts) + err = confirmRepairMutations(recoveredSnapshotID, muts) if err != nil { return err } @@ -447,6 +485,27 @@ func confirmMutations(muts []state.Mutation) error { return nil } +func confirmRepairMutations(snapshotID string, muts []state.Mutation) error { + if snapshotID != "" { + fmt.Printf("Recovering state from snapshot: %s\n", snapshotID) + } + if len(muts) == 0 { + if snapshotID == "" { + return util.NewReadableError(nil, "No changes made") + } + return nil + } + return confirmMutations(muts) +} + +func repairSnapshotFlagRequired(pullErr error, exportErr error) error { + cause := exportErr + if cause == nil { + cause = pullErr + } + return util.NewReadableError(cause, "State is missing or corrupted. Re-run `sst state repair --dangerously-revert` to restore the latest valid snapshot. This can orphan or recreate resources.") +} + func indent(key string) string { return fmt.Sprintf("%-12s", key) } diff --git a/cmd/sst/state_test.go b/cmd/sst/state_test.go new file mode 100644 index 0000000000..40c33d4f40 --- /dev/null +++ b/cmd/sst/state_test.go @@ -0,0 +1,34 @@ +package main + +import ( + "errors" + "testing" + + "github.com/sst/sst/v3/internal/util" +) + +func TestConfirmRepairMutations_AllowsSnapshotOnlyRecovery(t *testing.T) { + t.Parallel() + + err := confirmRepairMutations("fe628cc5182389b648c70130", nil) + if err != nil { + t.Fatal(err) + } +} + +func TestRepairSnapshotFlagRequired(t *testing.T) { + t.Parallel() + + cause := errors.New("EOF") + err := repairSnapshotFlagRequired(nil, cause) + if err == nil { + t.Fatal("expected error") + } + var readable *util.ReadableError + if !errors.As(err, &readable) { + t.Fatalf("expected readable error, got %T", err) + } + if readable.Error() != "State is missing or corrupted. Re-run `sst state repair --dangerously-revert` to restore the latest valid snapshot. This can orphan or recreate resources." { + t.Fatalf("unexpected error: %q", readable.Error()) + } +} diff --git a/pkg/project/provider/aws.go b/pkg/project/provider/aws.go index 1ee95fc941..6a2bb73961 100644 --- a/pkg/project/provider/aws.go +++ b/pkg/project/provider/aws.go @@ -546,6 +546,43 @@ func (a *AwsHome) pathForData(key, app, stage string) string { return path.Join(key, app, fmt.Sprintf("%v.json", stage)) } +func (a *AwsHome) listData(key, app, stage string) ([]string, error) { + bootstrap, err := a.provider.Bootstrap(a.provider.config.Region) + if err != nil { + return nil, err + } + s3Client := s3.NewFromConfig(a.provider.config) + prefix := path.Join(key, app) + if stage != "" { + prefix = path.Join(prefix, stage) + } + prefix += "/" + result := []string{} + base := path.Join(key, app) + "/" + var continuationToken *string + for { + data, err := s3Client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{ + Bucket: aws.String(bootstrap.State), + Prefix: aws.String(prefix), + ContinuationToken: continuationToken, + }) + if err != nil { + return nil, err + } + for _, obj := range data.Contents { + name := strings.TrimPrefix(*obj.Key, base) + if strings.HasSuffix(name, ".json") { + result = append(result, strings.TrimSuffix(name, ".json")) + } + } + if data.IsTruncated == nil || !*data.IsTruncated { + break + } + continuationToken = data.NextContinuationToken + } + return result, nil +} + func (a *AwsHome) pathForPassphrase(app string, stage string) string { return "/" + strings.Join([]string{"sst", "passphrase", app, stage}, "/") } @@ -705,30 +742,16 @@ func (a *AwsHome) setPassphrase(app, stage, passphrase string) error { } func (a *AwsHome) listStages(app string) ([]string, error) { - bootstrap, err := a.provider.Bootstrap(a.provider.config.Region) - if err != nil { - return nil, err - } - s3Client := s3.NewFromConfig(a.provider.config) - - data, err := s3Client.ListObjects(context.TODO(), &s3.ListObjectsInput{ - Bucket: aws.String(bootstrap.State), - Prefix: aws.String(path.Join("app", app)), - }) - + data, err := a.listData("app", app, "") if err != nil { return nil, err } stages := []string{} - for _, obj := range data.Contents { - filename := path.Base(*obj.Key) - if strings.HasSuffix(filename, ".json") { - stageName := strings.TrimSuffix(filename, ".json") - if hasResources(a, app, stageName) { - stages = append(stages, stageName) - } + for _, stageName := range data { + if hasResources(a, app, stageName) { + stages = append(stages, stageName) } } diff --git a/pkg/project/provider/cloudflare.go b/pkg/project/provider/cloudflare.go index 2f6cbf8a93..cb872817f0 100644 --- a/pkg/project/provider/cloudflare.go +++ b/pkg/project/provider/cloudflare.go @@ -8,6 +8,7 @@ import ( "io" "log/slog" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -81,11 +82,13 @@ type CloudflareHome struct { sync.Mutex provider *CloudflareProvider bootstrap *bootstrap + request func(*cloudflare.API, context.Context, string, string, interface{}) ([]byte, error) } func NewCloudflareHome(provider *CloudflareProvider) *CloudflareHome { return &CloudflareHome{ provider: provider, + request: makeRequestContext, } } @@ -93,6 +96,23 @@ type bootstrap struct { State string `json:"state"` } +type r2Object struct { + Key string `json:"key"` +} + +type r2ResponseInfo struct { + Cursor string `json:"cursor"` + IsTruncated bool `json:"is_truncated"` +} + +type r2Response struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result []r2Object `json:"result"` + ResultInfo r2ResponseInfo `json:"result_info"` +} + func (c *CloudflareHome) cleanup(key, app, stage string) error { return nil } @@ -133,11 +153,23 @@ func (c *CloudflareHome) Bootstrap() error { //go:linkname makeRequestContext github.com/cloudflare/cloudflare-go.(*API).makeRequestContext func makeRequestContext(*cloudflare.API, context.Context, string, string, interface{}) ([]byte, error) +func (c *CloudflareHome) requestContext(method, path string, body interface{}) ([]byte, error) { + request := c.request + if request == nil { + request = makeRequestContext + } + return request(c.provider.api, context.Background(), method, path, body) +} + func (c *CloudflareHome) putData(kind, app, stage string, data io.Reader) error { c.Lock() defer c.Unlock() path := filepath.Join(kind, app, stage) - _, err := makeRequestContext(c.provider.api, context.Background(), http.MethodPut, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, data) + body, err := io.ReadAll(data) + if err != nil { + return err + } + _, err = c.requestContext(http.MethodPut, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, body) if err != nil { return err } @@ -148,7 +180,7 @@ func (c *CloudflareHome) getData(kind, app, stage string) (io.Reader, error) { c.Lock() defer c.Unlock() path := filepath.Join(kind, app, stage) - data, err := makeRequestContext(c.provider.api, context.Background(), http.MethodGet, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil) + data, err := c.requestContext(http.MethodGet, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil) if err != nil { if err.Error() == "The specified key does not exist. (10007)" { return nil, nil @@ -158,11 +190,47 @@ func (c *CloudflareHome) getData(kind, app, stage string) (io.Reader, error) { return bytes.NewReader(data), nil } +func (c *CloudflareHome) listData(kind, app, stage string) ([]string, error) { + c.Lock() + defer c.Unlock() + prefix := filepath.Join(kind, app) + if stage != "" { + prefix = filepath.Join(prefix, stage) + } + prefix += "/" + base := filepath.Join(kind, app) + "/" + pathPrefix := "/accounts/" + c.provider.identifier.Identifier + "/r2/buckets/" + c.bootstrap.State + "/objects?prefix=" + url.QueryEscape(prefix) + result := []string{} + cursor := "" + for { + requestPath := pathPrefix + if cursor != "" { + requestPath += "&cursor=" + url.QueryEscape(cursor) + } + data, err := c.requestContext(http.MethodGet, requestPath, nil) + if err != nil { + return nil, err + } + var response r2Response + if err := json.Unmarshal(data, &response); err != nil { + return nil, err + } + for _, obj := range response.Result { + result = append(result, strings.TrimPrefix(obj.Key, base)) + } + if !response.ResultInfo.IsTruncated || response.ResultInfo.Cursor == "" { + break + } + cursor = response.ResultInfo.Cursor + } + return result, nil +} + func (c *CloudflareHome) removeData(kind, app, stage string) error { c.Lock() defer c.Unlock() path := filepath.Join(kind, app, stage) - _, err := makeRequestContext(c.provider.api, context.Background(), http.MethodDelete, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil) + _, err := c.requestContext(http.MethodDelete, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil) if err != nil { return err } @@ -190,36 +258,14 @@ func (c *CloudflareHome) getPassphrase(app, stage string) (string, error) { } func (c *CloudflareHome) listStages(app string) ([]string, error) { - type r2Object struct { - Key string `json:"key"` - } - - type r2Response struct { - Success bool `json:"success"` - Errors []string `json:"errors"` - Messages []string `json:"messages"` - Result []r2Object `json:"result"` - } - - path := "/accounts/" + c.provider.identifier.Identifier + "/r2/buckets/" + c.bootstrap.State + "/objects?prefix=" + filepath.Join("app", app) - - data, err := makeRequestContext(c.provider.api, context.Background(), http.MethodGet, path, nil) - - if err != nil { - return nil, err - } - - var response r2Response - err = json.Unmarshal(data, &response) + entries, err := c.listData("app", app, "") if err != nil { return nil, err } stages := []string{} - for _, obj := range response.Result { - segments := strings.Split(obj.Key, "/") - stageName := segments[len(segments)-1] + for _, stageName := range entries { if hasResources(c, app, stageName) { stages = append(stages, stageName) } diff --git a/pkg/project/provider/cloudflare_test.go b/pkg/project/provider/cloudflare_test.go new file mode 100644 index 0000000000..fd9d6b8bc0 --- /dev/null +++ b/pkg/project/provider/cloudflare_test.go @@ -0,0 +1,37 @@ +package provider + +import ( + "bytes" + "context" + "testing" + + cloudflare "github.com/cloudflare/cloudflare-go" +) + +func TestCloudflarePutDataUsesBytePayload(t *testing.T) { + var gotBody interface{} + + home := &CloudflareHome{ + provider: &CloudflareProvider{ + identifier: &cloudflare.ResourceContainer{Identifier: "account"}, + }, + bootstrap: &bootstrap{State: "sst-state"}, + request: func(api *cloudflare.API, ctx context.Context, method string, path string, body interface{}) ([]byte, error) { + gotBody = body + return nil, nil + }, + } + + err := home.putData("app", "demo", "dev", bytes.NewReader([]byte("hello"))) + if err != nil { + t.Fatal(err) + } + + body, ok := gotBody.([]byte) + if !ok { + t.Fatalf("expected []byte body, got %T", gotBody) + } + if string(body) != "hello" { + t.Fatalf("unexpected body: %q", string(body)) + } +} diff --git a/pkg/project/provider/local.go b/pkg/project/provider/local.go index e250b24a44..621bd1116f 100644 --- a/pkg/project/provider/local.go +++ b/pkg/project/provider/local.go @@ -89,21 +89,39 @@ func (l *LocalHome) pathForData(key, app, stage string) string { return filepath.Join(global.ConfigDir(), "state", key, app, fmt.Sprintf("%v.json", stage)) } -func (a *LocalHome) listStages(app string) ([]string, error) { - path := filepath.Join(global.ConfigDir(), "state", "app", app) - +func (l *LocalHome) listData(key, app, stage string) ([]string, error) { + path := filepath.Join(global.ConfigDir(), "state", key, app, stage) entries, err := os.ReadDir(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + result := []string{} + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + name := strings.TrimSuffix(entry.Name(), ".json") + if stage != "" { + name = stage + "/" + name + } + result = append(result, name) + } + return result, nil +} + +func (a *LocalHome) listStages(app string) ([]string, error) { + entries, err := a.listData("app", app, "") if err != nil { return nil, err } var stages []string - for _, entry := range entries { - if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".json") { - stageName := strings.TrimSuffix(entry.Name(), ".json") - if hasResources(a, app, stageName) { - stages = append(stages, stageName) - } + for _, stageName := range entries { + if hasResources(a, app, stageName) { + stages = append(stages, stageName) } } diff --git a/pkg/project/provider/provider.go b/pkg/project/provider/provider.go index f2f1a8acb9..d559f76374 100644 --- a/pkg/project/provider/provider.go +++ b/pkg/project/provider/provider.go @@ -10,9 +10,12 @@ import ( "fmt" "io" "os" + "sort" + "strings" "time" "github.com/pulumi/pulumi/pkg/v3/resource/stack" + "github.com/pulumi/pulumi/sdk/v3/go/common/apitype" "github.com/pulumi/pulumi/sdk/v3/go/common/encoding" "github.com/sst/sst/v3/internal/util" "github.com/sst/sst/v3/pkg/flag" @@ -23,6 +26,7 @@ import ( type Home interface { Bootstrap() error getData(key, app, stage string) (io.Reader, error) + listData(key, app, stage string) ([]string, error) putData(key, app, stage string, data io.Reader) error removeData(key, app, stage string) error setPassphrase(app, stage string, passphrase string) error @@ -64,8 +68,56 @@ const SSM_NAME_BOOTSTRAP = "/sst/bootstrap" var ErrLockExists = fmt.Errorf("Concurrent update detected, run `sst unlock --stage=` to delete lock file and retry.") var ErrLockNotFound = fmt.Errorf("Lock not found") +var ErrSnapshotNotFound = fmt.Errorf("snapshot not found") var passphraseCache = map[Home]map[string]string{} +func LatestValidSnapshot(backend Home, app, stage string) ([]byte, string, error) { + snapshots, err := backend.listData("snapshot", app, stage) + if err != nil { + return nil, "", err + } + sort.Strings(snapshots) + for _, snapshot := range snapshots { + if stage != "" && !strings.HasPrefix(snapshot, stage+"/") { + continue + } + reader, err := backend.getData("snapshot", app, snapshot) + if err != nil { + return nil, "", err + } + if reader == nil { + continue + } + data, err := io.ReadAll(reader) + if err != nil { + return nil, "", err + } + if !isValidSnapshot(data) { + continue + } + return data, strings.TrimPrefix(snapshot, stage+"/"), nil + } + return nil, "", ErrSnapshotNotFound +} + +func isValidSnapshot(data []byte) bool { + if len(data) == 0 { + return false + } + var untyped apitype.VersionedCheckpoint + if err := json.Unmarshal(data, &untyped); err != nil { + return false + } + if len(untyped.Checkpoint) == 0 { + return false + } + var checkpoint apitype.CheckpointV3 + if err := json.Unmarshal(untyped.Checkpoint, &checkpoint); err != nil { + return false + } + return checkpoint.Latest != nil +} + func Copy(from Home, to Home, app, stage string) error { reader, err := from.getData("app", app, stage) if err != nil { diff --git a/pkg/project/provider/snapshot_test.go b/pkg/project/provider/snapshot_test.go new file mode 100644 index 0000000000..06b5f162ec --- /dev/null +++ b/pkg/project/provider/snapshot_test.go @@ -0,0 +1,101 @@ +package provider + +import ( + "bytes" + "encoding/json" + "io" + "testing" + + "github.com/pulumi/pulumi/sdk/v3/go/common/apitype" + "github.com/sst/sst/v3/internal/util" +) + +type snapshotTestHome struct { + keys []string + data map[string][]byte +} + +func (s *snapshotTestHome) Bootstrap() error { + return nil +} + +func (s *snapshotTestHome) getData(key, app, stage string) (io.Reader, error) { + data, ok := s.data[key+"/"+app+"/"+stage] + if !ok { + return nil, nil + } + return bytes.NewReader(data), nil +} + +func (s *snapshotTestHome) putData(key, app, stage string, data io.Reader) error { + return nil +} + +func (s *snapshotTestHome) removeData(key, app, stage string) error { + return nil +} + +func (s *snapshotTestHome) setPassphrase(app, stage string, passphrase string) error { + return nil +} + +func (s *snapshotTestHome) getPassphrase(app, stage string) (string, error) { + return "", nil +} + +func (s *snapshotTestHome) listStages(app string) ([]string, error) { + return nil, nil +} + +func (s *snapshotTestHome) listData(key, app, stage string) ([]string, error) { + return s.keys, nil +} + +func (s *snapshotTestHome) cleanup(key, app, stage string) error { + return nil +} + +func (s *snapshotTestHome) info() (util.KeyValuePairs[string], error) { + return nil, nil +} + +func TestLatestValidSnapshotSkipsCorruptedSnapshots(t *testing.T) { + t.Parallel() + + checkpoint := apitype.CheckpointV3{ + Latest: &apitype.DeploymentV3{Resources: []apitype.ResourceV3{}}, + } + rawCheckpoint, err := json.Marshal(checkpoint) + if err != nil { + t.Fatal(err) + } + state, err := json.Marshal(apitype.VersionedCheckpoint{ + Version: 3, + Checkpoint: rawCheckpoint, + }) + if err != nil { + t.Fatal(err) + } + + home := &snapshotTestHome{ + keys: []string{ + "dev/fe6288d0ec9814453df3c388", + "dev/fe628cc5182389b648c70130", + }, + data: map[string][]byte{ + "snapshot/app/dev/fe6288d0ec9814453df3c388": {}, + "snapshot/app/dev/fe628cc5182389b648c70130": state, + }, + } + + recovered, updateID, err := LatestValidSnapshot(home, "app", "dev") + if err != nil { + t.Fatal(err) + } + if updateID != "fe628cc5182389b648c70130" { + t.Fatalf("expected fallback update id, got %q", updateID) + } + if !bytes.Equal(recovered, state) { + t.Fatal("expected recovered snapshot bytes") + } +} diff --git a/pkg/project/workdir.go b/pkg/project/workdir.go index 9796b3b670..7c451aee79 100644 --- a/pkg/project/workdir.go +++ b/pkg/project/workdir.go @@ -3,17 +3,21 @@ package project import ( "encoding/json" "fmt" + "io" "os" "path/filepath" "github.com/pulumi/pulumi/sdk/v3/go/common/apitype" + "github.com/sst/sst/v3/internal/util" "github.com/sst/sst/v3/pkg/flag" "github.com/sst/sst/v3/pkg/project/provider" "github.com/zeebo/xxh3" "golang.org/x/sync/errgroup" ) +const stateCorruptedMessage = "State file is empty or corrupted" + type PulumiWorkdir struct { path string project *Project @@ -126,15 +130,19 @@ func (w *PulumiWorkdir) Export() (*apitype.CheckpointV3, error) { if err != nil { return nil, err } + defer file.Close() err = json.NewDecoder(file).Decode(&untyped) if err != nil { - return nil, err + return nil, util.NewReadableError(err, stateCorruptedMessage) + } + if len(untyped.Checkpoint) == 0 { + return nil, util.NewReadableError(io.EOF, stateCorruptedMessage) } var result apitype.CheckpointV3 err = json.Unmarshal(untyped.Checkpoint, &result) if err != nil { - return nil, err + return nil, util.NewReadableError(err, stateCorruptedMessage) } return &result, nil @@ -162,6 +170,15 @@ func (w *PulumiWorkdir) Import(checkpoint *apitype.CheckpointV3) error { return nil } +func (w *PulumiWorkdir) ImportRaw(data []byte) error { + statePath := w.state() + err := os.MkdirAll(filepath.Dir(statePath), 0755) + if err != nil { + return err + } + return os.WriteFile(statePath, data, 0644) +} + func (w *PulumiWorkdir) EventLogPath() string { return filepath.Join(w.path, "eventlog.json") } diff --git a/pkg/project/workdir_test.go b/pkg/project/workdir_test.go new file mode 100644 index 0000000000..15b310e7a1 --- /dev/null +++ b/pkg/project/workdir_test.go @@ -0,0 +1,44 @@ +package project + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/sst/sst/v3/internal/util" +) + +func TestExportReturnsReadableErrorForEmptyStateFile(t *testing.T) { + t.Parallel() + + workdir := &PulumiWorkdir{ + path: t.TempDir(), + project: &Project{ + app: &App{Name: "app", Stage: "dev"}, + }, + } + + statePath := workdir.state() + err := os.MkdirAll(filepath.Dir(statePath), 0755) + if err != nil { + t.Fatal(err) + } + err = os.WriteFile(statePath, []byte{}, 0644) + if err != nil { + t.Fatal(err) + } + + _, err = workdir.Export() + if err == nil { + t.Fatal("expected export to fail for an empty state file") + } + + var readable *util.ReadableError + if !errors.As(err, &readable) { + t.Fatalf("expected readable error, got %T: %v", err, err) + } + if readable.Error() != "State file is empty or corrupted" { + t.Fatalf("unexpected readable error message: %q", readable.Error()) + } +}