diff --git a/pkg/flag/flag.go b/pkg/flag/flag.go index 36a6a53a07..e009271a68 100644 --- a/pkg/flag/flag.go +++ b/pkg/flag/flag.go @@ -23,6 +23,7 @@ var SST_EXPERIMENTAL = isTrue("SST_EXPERIMENTAL") || isTrue("SST_EXPERIMENTAL_RU var SST_RUN_ID = os.Getenv("SST_RUN_ID") var SST_SKIP_APPSYNC = isTrue("SST_SKIP_APPSYNC") var SST_NO_BUN = isTrue("NO_BUN") || isTrue("SST_NO_BUN") +var SST_STATE_COMPRESS = !isFalse("SST_STATE_COMPRESS") func isTrue(name string) bool { val, ok := os.LookupEnv(name) @@ -37,3 +38,17 @@ func isTrue(name string) bool { } return false } + +func isFalse(name string) bool { + val, ok := os.LookupEnv(name) + if !ok { + return false + } + if val == "0" { + return true + } + if val == "false" { + return true + } + return false +} diff --git a/pkg/project/provider/aws.go b/pkg/project/provider/aws.go index 1ee95fc941..45fd2125a0 100644 --- a/pkg/project/provider/aws.go +++ b/pkg/project/provider/aws.go @@ -1,6 +1,8 @@ package provider import ( + "bytes" + "compress/gzip" "context" "encoding/json" "errors" @@ -23,6 +25,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" "github.com/sst/sst/v3/internal/util" + "github.com/sst/sst/v3/pkg/flag" ecrTypes "github.com/aws/aws-sdk-go-v2/service/ecr/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" @@ -574,7 +577,7 @@ func (a *AwsHome) getData(key, app, stage string) (io.Reader, error) { } return nil, err } - return result.Body, nil + return decodeHomeReader(result.Body, result.ContentEncoding) } func (a *AwsHome) putData(key, app, stage string, data io.Reader) error { @@ -584,11 +587,17 @@ func (a *AwsHome) putData(key, app, stage string, data io.Reader) error { } s3Client := s3.NewFromConfig(a.provider.config) + body, contentEncoding, err := prepareHomeUpload(data, flag.SST_STATE_COMPRESS) + if err != nil { + return err + } + _, err = s3Client.PutObject(context.TODO(), &s3.PutObjectInput{ - Bucket: aws.String(bootstrap.State), - Key: aws.String(a.pathForData(key, app, stage)), - Body: data, - ContentType: aws.String("application/json"), + Bucket: aws.String(bootstrap.State), + Key: aws.String(a.pathForData(key, app, stage)), + Body: body, + ContentType: aws.String("application/json"), + ContentEncoding: contentEncoding, }) if err != nil { return err @@ -597,6 +606,64 @@ func (a *AwsHome) putData(key, app, stage string, data io.Reader) error { return nil } +func prepareHomeUpload(data io.Reader, compressEnabled bool) (io.Reader, *string, error) { + var buffer bytes.Buffer + _, err := io.Copy(&buffer, data) + if err != nil { + return nil, nil, err + } + if !compressEnabled { + return bytes.NewReader(buffer.Bytes()), nil, nil + } + + var compressed bytes.Buffer + writer := gzip.NewWriter(&compressed) + _, err = io.Copy(writer, bytes.NewReader(buffer.Bytes())) + if err != nil { + writer.Close() + return nil, nil, err + } + err = writer.Close() + if err != nil { + return nil, nil, err + } + + return bytes.NewReader(compressed.Bytes()), aws.String("gzip"), nil +} + +func decodeHomeReader(body io.ReadCloser, contentEncoding *string) (io.Reader, error) { + if contentEncoding != nil && strings.EqualFold(*contentEncoding, "gzip") { + reader, err := gzip.NewReader(body) + if err != nil { + body.Close() + return nil, err + } + return &gzipReadCloser{ + Reader: reader, + body: body, + }, nil + } + return body, nil +} + +type gzipReadCloser struct { + Reader *gzip.Reader + body io.Closer +} + +func (g *gzipReadCloser) Read(p []byte) (int, error) { + return g.Reader.Read(p) +} + +func (g *gzipReadCloser) Close() error { + readerErr := g.Reader.Close() + bodyErr := g.body.Close() + if readerErr != nil { + return readerErr + } + return bodyErr +} + func (a *AwsHome) removeData(key, app, stage string) error { bootstrap, err := a.provider.Bootstrap(a.provider.config.Region) if err != nil { diff --git a/pkg/project/provider/aws_test.go b/pkg/project/provider/aws_test.go new file mode 100644 index 0000000000..ec0b545245 --- /dev/null +++ b/pkg/project/provider/aws_test.go @@ -0,0 +1,83 @@ +package provider + +import ( + "bytes" + "io" + "testing" +) + +func TestPrepareStateUploadCompressed(t *testing.T) { + t.Parallel() + + input := bytes.Repeat([]byte(`{"resource":"value","nested":{"enabled":true}}`), 1024) + body, contentEncoding, err := prepareHomeUpload(bytes.NewReader(input), true) + if err != nil { + t.Fatalf("prepareHomeUpload returned error: %v", err) + } + if contentEncoding == nil || *contentEncoding != "gzip" { + t.Fatalf("expected gzip content encoding, got %v", contentEncoding) + } + + encoded, err := io.ReadAll(body) + if err != nil { + t.Fatalf("reading encoded data failed: %v", err) + } + if len(encoded) >= len(input) { + t.Fatalf("expected compressed upload to be smaller than raw input, got raw=%d uploaded=%d", len(input), len(encoded)) + } + + reader, err := decodeHomeReader(io.NopCloser(bytes.NewReader(encoded)), contentEncoding) + if err != nil { + t.Fatalf("decodeHomeReader returned error: %v", err) + } + decoded, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("reading decoded data failed: %v", err) + } + if !bytes.Equal(decoded, input) { + t.Fatal("decoded payload did not match original input") + } +} + +func TestPrepareStateUploadUncompressed(t *testing.T) { + t.Parallel() + + input := []byte(`{"resource":"value"}`) + body, contentEncoding, err := prepareHomeUpload(bytes.NewReader(input), false) + if err != nil { + t.Fatalf("prepareHomeUpload returned error: %v", err) + } + if contentEncoding != nil { + t.Fatalf("expected no content encoding, got %v", *contentEncoding) + } + + encoded, err := io.ReadAll(body) + if err != nil { + t.Fatalf("reading encoded data failed: %v", err) + } + if !bytes.Equal(encoded, input) { + t.Fatal("plain upload payload did not match original input") + } + + reader, err := decodeHomeReader(io.NopCloser(bytes.NewReader(encoded)), contentEncoding) + if err != nil { + t.Fatalf("decodeHomeReader returned error: %v", err) + } + decoded, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("reading decoded data failed: %v", err) + } + if !bytes.Equal(decoded, input) { + t.Fatal("decoded payload did not match original input") + } +} + +func TestDecodeStateReaderRejectsInvalidGzip(t *testing.T) { + t.Parallel() + + encoding := "gzip" + _, err := decodeHomeReader(io.NopCloser(bytes.NewReader([]byte("not-gzip"))), &encoding) + if err == nil { + t.Fatal("expected invalid gzip payload to return an error") + } +}