Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions cmd/sst/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"time"

"github.com/pulumi/pulumi/sdk/v3/go/common/apitype"
"github.com/sst/sst/v3/cmd/sst/cli"
"github.com/sst/sst/v3/cmd/sst/mosaic/ui"
"github.com/sst/sst/v3/internal/util"
Expand Down Expand Up @@ -349,9 +351,22 @@ var CmdState = &cli.Command{
"sst state repair --stage production",
"```",
"",
"If the current state cannot be read, you can opt into restoring the latest valid",
"snapshot with `--dangerously-revert`. This is unsafe and can orphan or recreate",
"resources on future deploys.",
"",
"By default, it runs on your personal stage.",
}, "\n"),
},
Flags: []cli.Flag{
{
Name: "dangerously-revert",
Type: "bool",
Description: cli.Description{
Short: "Dangerously restore latest valid snapshot if current state cannot be read",
},
},
},
Run: func(c *cli.Cli) error {
p, err := c.InitProject()
if err != nil {
Expand All @@ -374,18 +389,41 @@ var CmdState = &cli.Command{
}
defer workdir.Cleanup()

_, err = workdir.Pull()
if err != nil {
return util.NewReadableError(err, "Could not pull state")
_, pullErr := workdir.Pull()
if pullErr != nil && !errors.Is(pullErr, provider.ErrStateNotFound) {
return util.NewReadableError(pullErr, "Could not pull state")
}

checkpoint, err := workdir.Export()
if err != nil {
return util.NewReadableError(err, "Could not export state")
recoveredSnapshotID := ""
var checkpoint *apitype.CheckpointV3
if pullErr == nil {
checkpoint, err = workdir.Export()
}
if pullErr != nil || err != nil {
if !c.Bool("dangerously-revert") {
return repairSnapshotFlagRequired(pullErr, err)
}
snapshot, snapshotID, recoverErr := provider.LatestValidSnapshot(p.Backend(), p.App().Name, p.App().Stage)
if recoverErr != nil {
return util.NewReadableError(recoverErr, "Could not recover state")
}
err = workdir.ImportRaw(snapshot)
if err != nil {
return util.NewReadableError(err, "Could not restore snapshot")
}
recoveredSnapshotID = snapshotID
checkpoint, err = workdir.Export()
if err != nil {
return util.NewReadableError(err, "Could not export state")
}
}

if checkpoint == nil {
return util.NewReadableError(nil, "Could not export state")
}

muts := state.Repair(checkpoint)
err = confirmMutations(muts)
err = confirmRepairMutations(recoveredSnapshotID, muts)
if err != nil {
return err
}
Expand Down Expand Up @@ -447,6 +485,27 @@ func confirmMutations(muts []state.Mutation) error {
return nil
}

func confirmRepairMutations(snapshotID string, muts []state.Mutation) error {
if snapshotID != "" {
fmt.Printf("Recovering state from snapshot: %s\n", snapshotID)
}
if len(muts) == 0 {
if snapshotID == "" {
return util.NewReadableError(nil, "No changes made")
}
return nil
}
return confirmMutations(muts)
}

func repairSnapshotFlagRequired(pullErr error, exportErr error) error {
cause := exportErr
if cause == nil {
cause = pullErr
}
return util.NewReadableError(cause, "State is missing or corrupted. Re-run `sst state repair --dangerously-revert` to restore the latest valid snapshot. This can orphan or recreate resources.")
}

func indent(key string) string {
return fmt.Sprintf("%-12s", key)
}
Expand Down
34 changes: 34 additions & 0 deletions cmd/sst/state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package main

import (
"errors"
"testing"

"github.com/sst/sst/v3/internal/util"
)

func TestConfirmRepairMutations_AllowsSnapshotOnlyRecovery(t *testing.T) {
t.Parallel()

err := confirmRepairMutations("fe628cc5182389b648c70130", nil)
if err != nil {
t.Fatal(err)
}
}

func TestRepairSnapshotFlagRequired(t *testing.T) {
t.Parallel()

cause := errors.New("EOF")
err := repairSnapshotFlagRequired(nil, cause)
if err == nil {
t.Fatal("expected error")
}
var readable *util.ReadableError
if !errors.As(err, &readable) {
t.Fatalf("expected readable error, got %T", err)
}
if readable.Error() != "State is missing or corrupted. Re-run `sst state repair --dangerously-revert` to restore the latest valid snapshot. This can orphan or recreate resources." {
t.Fatalf("unexpected error: %q", readable.Error())
}
}
59 changes: 41 additions & 18 deletions pkg/project/provider/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,43 @@ func (a *AwsHome) pathForData(key, app, stage string) string {
return path.Join(key, app, fmt.Sprintf("%v.json", stage))
}

func (a *AwsHome) listData(key, app, stage string) ([]string, error) {
bootstrap, err := a.provider.Bootstrap(a.provider.config.Region)
if err != nil {
return nil, err
}
s3Client := s3.NewFromConfig(a.provider.config)
prefix := path.Join(key, app)
if stage != "" {
prefix = path.Join(prefix, stage)
}
prefix += "/"
result := []string{}
base := path.Join(key, app) + "/"
var continuationToken *string
for {
data, err := s3Client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{
Bucket: aws.String(bootstrap.State),
Prefix: aws.String(prefix),
ContinuationToken: continuationToken,
})
if err != nil {
return nil, err
}
for _, obj := range data.Contents {
name := strings.TrimPrefix(*obj.Key, base)
if strings.HasSuffix(name, ".json") {
result = append(result, strings.TrimSuffix(name, ".json"))
}
}
if data.IsTruncated == nil || !*data.IsTruncated {
break
}
continuationToken = data.NextContinuationToken
}
return result, nil
}

func (a *AwsHome) pathForPassphrase(app string, stage string) string {
return "/" + strings.Join([]string{"sst", "passphrase", app, stage}, "/")
}
Expand Down Expand Up @@ -705,30 +742,16 @@ func (a *AwsHome) setPassphrase(app, stage, passphrase string) error {
}

func (a *AwsHome) listStages(app string) ([]string, error) {
bootstrap, err := a.provider.Bootstrap(a.provider.config.Region)
if err != nil {
return nil, err
}
s3Client := s3.NewFromConfig(a.provider.config)

data, err := s3Client.ListObjects(context.TODO(), &s3.ListObjectsInput{
Bucket: aws.String(bootstrap.State),
Prefix: aws.String(path.Join("app", app)),
})

data, err := a.listData("app", app, "")
if err != nil {
return nil, err
}

stages := []string{}

for _, obj := range data.Contents {
filename := path.Base(*obj.Key)
if strings.HasSuffix(filename, ".json") {
stageName := strings.TrimSuffix(filename, ".json")
if hasResources(a, app, stageName) {
stages = append(stages, stageName)
}
for _, stageName := range data {
if hasResources(a, app, stageName) {
stages = append(stages, stageName)
}
}

Expand Down
100 changes: 73 additions & 27 deletions pkg/project/provider/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -81,18 +82,37 @@ type CloudflareHome struct {
sync.Mutex
provider *CloudflareProvider
bootstrap *bootstrap
request func(*cloudflare.API, context.Context, string, string, interface{}) ([]byte, error)
}

func NewCloudflareHome(provider *CloudflareProvider) *CloudflareHome {
return &CloudflareHome{
provider: provider,
request: makeRequestContext,
}
}

type bootstrap struct {
State string `json:"state"`
}

type r2Object struct {
Key string `json:"key"`
}

type r2ResponseInfo struct {
Cursor string `json:"cursor"`
IsTruncated bool `json:"is_truncated"`
}

type r2Response struct {
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
Result []r2Object `json:"result"`
ResultInfo r2ResponseInfo `json:"result_info"`
}

func (c *CloudflareHome) cleanup(key, app, stage string) error {
return nil
}
Expand Down Expand Up @@ -133,11 +153,23 @@ func (c *CloudflareHome) Bootstrap() error {
//go:linkname makeRequestContext github.com/cloudflare/cloudflare-go.(*API).makeRequestContext
func makeRequestContext(*cloudflare.API, context.Context, string, string, interface{}) ([]byte, error)

func (c *CloudflareHome) requestContext(method, path string, body interface{}) ([]byte, error) {
request := c.request
if request == nil {
request = makeRequestContext
}
return request(c.provider.api, context.Background(), method, path, body)
}

func (c *CloudflareHome) putData(kind, app, stage string, data io.Reader) error {
c.Lock()
defer c.Unlock()
path := filepath.Join(kind, app, stage)
_, err := makeRequestContext(c.provider.api, context.Background(), http.MethodPut, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, data)
body, err := io.ReadAll(data)
if err != nil {
return err
}
_, err = c.requestContext(http.MethodPut, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, body)
if err != nil {
return err
}
Expand All @@ -148,7 +180,7 @@ func (c *CloudflareHome) getData(kind, app, stage string) (io.Reader, error) {
c.Lock()
defer c.Unlock()
path := filepath.Join(kind, app, stage)
data, err := makeRequestContext(c.provider.api, context.Background(), http.MethodGet, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil)
data, err := c.requestContext(http.MethodGet, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil)
if err != nil {
if err.Error() == "The specified key does not exist. (10007)" {
return nil, nil
Expand All @@ -158,11 +190,47 @@ func (c *CloudflareHome) getData(kind, app, stage string) (io.Reader, error) {
return bytes.NewReader(data), nil
}

func (c *CloudflareHome) listData(kind, app, stage string) ([]string, error) {
c.Lock()
defer c.Unlock()
prefix := filepath.Join(kind, app)
if stage != "" {
prefix = filepath.Join(prefix, stage)
}
prefix += "/"
base := filepath.Join(kind, app) + "/"
pathPrefix := "/accounts/" + c.provider.identifier.Identifier + "/r2/buckets/" + c.bootstrap.State + "/objects?prefix=" + url.QueryEscape(prefix)
result := []string{}
cursor := ""
for {
requestPath := pathPrefix
if cursor != "" {
requestPath += "&cursor=" + url.QueryEscape(cursor)
}
data, err := c.requestContext(http.MethodGet, requestPath, nil)
if err != nil {
return nil, err
}
var response r2Response
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
for _, obj := range response.Result {
result = append(result, strings.TrimPrefix(obj.Key, base))
}
if !response.ResultInfo.IsTruncated || response.ResultInfo.Cursor == "" {
break
}
cursor = response.ResultInfo.Cursor
}
return result, nil
}

func (c *CloudflareHome) removeData(kind, app, stage string) error {
c.Lock()
defer c.Unlock()
path := filepath.Join(kind, app, stage)
_, err := makeRequestContext(c.provider.api, context.Background(), http.MethodDelete, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil)
_, err := c.requestContext(http.MethodDelete, "/accounts/"+c.provider.identifier.Identifier+"/r2/buckets/"+c.bootstrap.State+"/objects/"+path, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -190,36 +258,14 @@ func (c *CloudflareHome) getPassphrase(app, stage string) (string, error) {
}

func (c *CloudflareHome) listStages(app string) ([]string, error) {
type r2Object struct {
Key string `json:"key"`
}

type r2Response struct {
Success bool `json:"success"`
Errors []string `json:"errors"`
Messages []string `json:"messages"`
Result []r2Object `json:"result"`
}

path := "/accounts/" + c.provider.identifier.Identifier + "/r2/buckets/" + c.bootstrap.State + "/objects?prefix=" + filepath.Join("app", app)

data, err := makeRequestContext(c.provider.api, context.Background(), http.MethodGet, path, nil)

if err != nil {
return nil, err
}

var response r2Response
err = json.Unmarshal(data, &response)
entries, err := c.listData("app", app, "")
if err != nil {
return nil, err
}

stages := []string{}

for _, obj := range response.Result {
segments := strings.Split(obj.Key, "/")
stageName := segments[len(segments)-1]
for _, stageName := range entries {
if hasResources(c, app, stageName) {
stages = append(stages, stageName)
}
Expand Down
Loading
Loading