Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pkg/flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var SST_EXPERIMENTAL = isTrue("SST_EXPERIMENTAL") || isTrue("SST_EXPERIMENTAL_RU
var SST_RUN_ID = os.Getenv("SST_RUN_ID")
var SST_SKIP_APPSYNC = isTrue("SST_SKIP_APPSYNC")
var SST_NO_BUN = isTrue("NO_BUN") || isTrue("SST_NO_BUN")
var SST_STATE_COMPRESS = !isFalse("SST_STATE_COMPRESS")

func isTrue(name string) bool {
val, ok := os.LookupEnv(name)
Expand All @@ -37,3 +38,17 @@ func isTrue(name string) bool {
}
return false
}

func isFalse(name string) bool {
val, ok := os.LookupEnv(name)
if !ok {
return false
}
if val == "0" {
return true
}
if val == "false" {
return true
}
return false
}
77 changes: 72 additions & 5 deletions pkg/project/provider/aws.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package provider

import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"errors"
Expand All @@ -23,6 +25,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go"
"github.com/sst/sst/v3/internal/util"
"github.com/sst/sst/v3/pkg/flag"

ecrTypes "github.com/aws/aws-sdk-go-v2/service/ecr/types"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
Expand Down Expand Up @@ -574,7 +577,7 @@ func (a *AwsHome) getData(key, app, stage string) (io.Reader, error) {
}
return nil, err
}
return result.Body, nil
return decodeHomeReader(result.Body, result.ContentEncoding)
}

func (a *AwsHome) putData(key, app, stage string, data io.Reader) error {
Expand All @@ -584,11 +587,17 @@ func (a *AwsHome) putData(key, app, stage string, data io.Reader) error {
}
s3Client := s3.NewFromConfig(a.provider.config)

body, contentEncoding, err := prepareHomeUpload(data, flag.SST_STATE_COMPRESS)
if err != nil {
return err
}

_, err = s3Client.PutObject(context.TODO(), &s3.PutObjectInput{
Bucket: aws.String(bootstrap.State),
Key: aws.String(a.pathForData(key, app, stage)),
Body: data,
ContentType: aws.String("application/json"),
Bucket: aws.String(bootstrap.State),
Key: aws.String(a.pathForData(key, app, stage)),
Body: body,
ContentType: aws.String("application/json"),
ContentEncoding: contentEncoding,
})
if err != nil {
return err
Expand All @@ -597,6 +606,64 @@ func (a *AwsHome) putData(key, app, stage string, data io.Reader) error {
return nil
}

func prepareHomeUpload(data io.Reader, compressEnabled bool) (io.Reader, *string, error) {
var buffer bytes.Buffer
_, err := io.Copy(&buffer, data)
if err != nil {
return nil, nil, err
}
if !compressEnabled {
return bytes.NewReader(buffer.Bytes()), nil, nil
}

var compressed bytes.Buffer
writer := gzip.NewWriter(&compressed)
_, err = io.Copy(writer, bytes.NewReader(buffer.Bytes()))
if err != nil {
writer.Close()
return nil, nil, err
}
err = writer.Close()
if err != nil {
return nil, nil, err
}

return bytes.NewReader(compressed.Bytes()), aws.String("gzip"), nil
}

func decodeHomeReader(body io.ReadCloser, contentEncoding *string) (io.Reader, error) {
if contentEncoding != nil && strings.EqualFold(*contentEncoding, "gzip") {
reader, err := gzip.NewReader(body)
if err != nil {
body.Close()
return nil, err
}
return &gzipReadCloser{
Reader: reader,
body: body,
}, nil
}
return body, nil
}

type gzipReadCloser struct {
Reader *gzip.Reader
body io.Closer
}

func (g *gzipReadCloser) Read(p []byte) (int, error) {
return g.Reader.Read(p)
}

func (g *gzipReadCloser) Close() error {
readerErr := g.Reader.Close()
bodyErr := g.body.Close()
if readerErr != nil {
return readerErr
}
return bodyErr
}

func (a *AwsHome) removeData(key, app, stage string) error {
bootstrap, err := a.provider.Bootstrap(a.provider.config.Region)
if err != nil {
Expand Down
83 changes: 83 additions & 0 deletions pkg/project/provider/aws_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package provider

import (
"bytes"
"io"
"testing"
)

func TestPrepareStateUploadCompressed(t *testing.T) {
t.Parallel()

input := bytes.Repeat([]byte(`{"resource":"value","nested":{"enabled":true}}`), 1024)
body, contentEncoding, err := prepareHomeUpload(bytes.NewReader(input), true)
if err != nil {
t.Fatalf("prepareHomeUpload returned error: %v", err)
}
if contentEncoding == nil || *contentEncoding != "gzip" {
t.Fatalf("expected gzip content encoding, got %v", contentEncoding)
}

encoded, err := io.ReadAll(body)
if err != nil {
t.Fatalf("reading encoded data failed: %v", err)
}
if len(encoded) >= len(input) {
t.Fatalf("expected compressed upload to be smaller than raw input, got raw=%d uploaded=%d", len(input), len(encoded))
}

reader, err := decodeHomeReader(io.NopCloser(bytes.NewReader(encoded)), contentEncoding)
if err != nil {
t.Fatalf("decodeHomeReader returned error: %v", err)
}
decoded, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("reading decoded data failed: %v", err)
}
if !bytes.Equal(decoded, input) {
t.Fatal("decoded payload did not match original input")
}
}

func TestPrepareStateUploadUncompressed(t *testing.T) {
t.Parallel()

input := []byte(`{"resource":"value"}`)
body, contentEncoding, err := prepareHomeUpload(bytes.NewReader(input), false)
if err != nil {
t.Fatalf("prepareHomeUpload returned error: %v", err)
}
if contentEncoding != nil {
t.Fatalf("expected no content encoding, got %v", *contentEncoding)
}

encoded, err := io.ReadAll(body)
if err != nil {
t.Fatalf("reading encoded data failed: %v", err)
}
if !bytes.Equal(encoded, input) {
t.Fatal("plain upload payload did not match original input")
}

reader, err := decodeHomeReader(io.NopCloser(bytes.NewReader(encoded)), contentEncoding)
if err != nil {
t.Fatalf("decodeHomeReader returned error: %v", err)
}
decoded, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("reading decoded data failed: %v", err)
}
if !bytes.Equal(decoded, input) {
t.Fatal("decoded payload did not match original input")
}
}

func TestDecodeStateReaderRejectsInvalidGzip(t *testing.T) {
t.Parallel()

encoding := "gzip"
_, err := decodeHomeReader(io.NopCloser(bytes.NewReader([]byte("not-gzip"))), &encoding)
if err == nil {
t.Fatal("expected invalid gzip payload to return an error")
}
}