diff --git a/app.go b/app.go index 722277ff55d..de5578f9a8e 100644 --- a/app.go +++ b/app.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "io" + "io/fs" "net" "net/http" "net/http/httputil" @@ -175,6 +176,23 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa // Default: nil Views Views `json:"-"` + // RootDir specifies the base directory for SaveFile/SaveFileToStorage uploads. + // Relative paths are resolved against this directory. + // + // Optional. Default: "" + RootDir string `json:"root_dir"` + + // RootPerms specifies the permissions used when creating RootDir or RootFs prefixes. + // + // Optional. Default: 0o750 + RootPerms fs.FileMode `json:"root_perms"` + + // RootFs specifies the filesystem used for SaveFile/SaveFileToStorage uploads. + // When set, RootDir is treated as a relative prefix within the filesystem. + // + // Optional. Default: nil + RootFs fs.FS `json:"-"` + // Views Layout is the global layout for all template render until override on Render function. // // Default: "" @@ -437,6 +455,15 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa // // Optional. Default: a provider that returns context.Background() ServicesShutdownContextProvider func() context.Context + + uploadRootDir string + uploadRootEval string + uploadRootPath string + uploadRootFSPrefix string + uploadRootFSWriter interface { + fs.FS + OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) + } } // Default TrustProxyConfig @@ -605,6 +632,9 @@ func New(config ...Config) *App { "zstd": ".fiber.zst", } } + if app.config.RootPerms == 0 { + app.config.RootPerms = 0o750 + } if app.config.Immutable { app.toBytes, app.toString = toBytesImmutable, toStringImmutable @@ -642,6 +672,8 @@ func New(config ...Config) *App { app.config.RequestMethods = DefaultMethods } + app.configureUploads() + app.config.TrustProxyConfig.ips = make(map[string]struct{}, len(app.config.TrustProxyConfig.Proxies)) for _, ipAddress := range app.config.TrustProxyConfig.Proxies { app.handleTrustedProxy(ipAddress) diff --git a/app_test.go b/app_test.go index ed6b0780f4f..bcad64c3ad3 100644 --- a/app_test.go +++ b/app_test.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "io/fs" "mime/multipart" "net" "net/http" @@ -81,6 +82,149 @@ func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBod require.Equal(t, expectedBodyError, string(body), "Response body") } +type testUploadFS struct { + mkdirPath string + mkdirPerm fs.FileMode +} + +func (tfs *testUploadFS) Open(_ string) (fs.File, error) { + _ = tfs + return nil, fs.ErrNotExist +} + +func (tfs *testUploadFS) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) { + _ = tfs + return &testUploadFile{buf: &bytes.Buffer{}}, nil +} + +func (tfs *testUploadFS) MkdirAll(path string, perm fs.FileMode) error { + tfs.mkdirPath = path + tfs.mkdirPerm = perm + return nil +} + +func (tfs *testUploadFS) Remove(_ string) error { + _ = tfs + return nil +} + +type testUploadFile struct { + buf *bytes.Buffer +} + +func (tf *testUploadFile) Read(p []byte) (int, error) { + //nolint:wrapcheck // test helper passthrough + return tf.buf.Read(p) +} + +func (tf *testUploadFile) Write(p []byte) (int, error) { + //nolint:wrapcheck // test helper passthrough + return tf.buf.Write(p) +} + +func (tf *testUploadFile) Close() error { + _ = tf + return nil +} + +func (tf *testUploadFile) Stat() (fs.FileInfo, error) { + _ = tf + return testUploadFileInfo{name: "upload"}, nil +} + +type testUploadFileInfo struct { + name string +} + +func (fi testUploadFileInfo) Name() string { return fi.name } +func (fi testUploadFileInfo) Size() int64 { + _ = fi + return 0 +} + +func (fi testUploadFileInfo) Mode() fs.FileMode { + _ = fi + return 0 +} + +func (fi testUploadFileInfo) ModTime() time.Time { + _ = fi + return time.Time{} +} + +func (fi testUploadFileInfo) IsDir() bool { + _ = fi + return false +} + +func (fi testUploadFileInfo) Sys() any { + _ = fi + return nil +} + +func TestRootPermsRootFs(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("root perms are not validated on Windows in this test") + } + + tests := []struct { + name string + rootPerm fs.FileMode + wantPerm fs.FileMode + }{ + { + name: "default", + rootPerm: 0, + wantPerm: 0o750, + }, + { + name: "custom", + rootPerm: 0o700, + wantPerm: 0o700, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + tfs := &testUploadFS{} + New(Config{ + RootDir: "uploads", + RootFs: tfs, + RootPerms: tt.rootPerm, + }) + + if tfs.mkdirPath != "uploads" { + t.Fatalf("expected RootFs prefix %q, got %q", "uploads", tfs.mkdirPath) + } + if tfs.mkdirPerm != tt.wantPerm { + t.Fatalf("expected RootPerms %o, got %o", tt.wantPerm, tfs.mkdirPerm) + } + }) + } +} + +func TestValidateUploadPathPreservesLeadingDot(t *testing.T) { + t.Parallel() + + path := filepath.Join(".hidden", "file.txt") + + normalized, err := validateUploadPath(path) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if !strings.HasPrefix(normalized.osPath, ".") { + t.Fatalf("expected os path %q to preserve leading dot", normalized.osPath) + } + if normalized.slashPath != ".hidden/file.txt" { + t.Fatalf("expected slash path %q, got %q", ".hidden/file.txt", normalized.slashPath) + } +} + func Test_App_Test_Goroutine_Leak_Compare(t *testing.T) { t.Parallel() diff --git a/ctx.go b/ctx.go index 92ce2f69591..761a30700a9 100644 --- a/ctx.go +++ b/ctx.go @@ -7,10 +7,14 @@ package fiber import ( "context" "crypto/tls" + "errors" "fmt" "io" + "io/fs" "maps" "mime/multipart" + "os" + "path/filepath" "strconv" "strings" "sync/atomic" @@ -511,12 +515,57 @@ func (c *DefaultCtx) IsPreflight() bool { } // SaveFile saves any multipart file to disk. -func (*DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error { - return fasthttp.SaveMultipartFile(fileheader, path) +func (c *DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error { + normalized, err := validateUploadPath(path) + if err != nil { + return err + } + + if c.app.config.RootFs != nil { + fsPath := storageUploadPath(c.app.config.uploadRootFSPrefix, normalized.slashPath) + err = ensureNoSymlinkFS(c.app.config.RootFs, fsPath) + if err != nil { + return err + } + return saveMultipartFileToFS(fileheader, fsPath, c.app.config.uploadRootFSWriter) + } + + fullPath := normalized.osPath + if root := c.app.config.uploadRootDir; root != "" { + fullPath = filepath.Join(root, normalized.osPath) + err = ensureUploadPathWithinRoot(c.app.config.uploadRootEval, fullPath) + if err != nil { + return err + } + } + + return fasthttp.SaveMultipartFile(fileheader, fullPath) } // SaveFileToStorage saves any multipart file to an external storage system. func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error { + normalized, err := validateUploadPath(path) + if err != nil { + return err + } + + if c.app.config.RootFs != nil { + fsPath := storageUploadPath(c.app.config.uploadRootFSPrefix, normalized.slashPath) + err = ensureNoSymlinkFS(c.app.config.RootFs, fsPath) + if err != nil { + return err + } + } + if root := c.app.config.uploadRootDir; root != "" { + fullPath := filepath.Join(root, filepath.FromSlash(normalized.slashPath)) + err = ensureUploadPathWithinRoot(c.app.config.uploadRootEval, fullPath) + if err != nil { + return err + } + } + + storagePath := storageUploadPath(c.app.config.uploadRootPath, normalized.slashPath) + file, err := fileheader.Open() if err != nil { return fmt.Errorf("failed to open: %w", err) @@ -546,13 +595,52 @@ func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path st data := append([]byte(nil), buf.Bytes()...) - if err := storage.SetWithContext(c.Context(), path, data, 0); err != nil { + if err := storage.SetWithContext(c.Context(), storagePath, data, 0); err != nil { return fmt.Errorf("failed to store: %w", err) } return nil } +func saveMultipartFileToFS( + fileheader *multipart.FileHeader, + path string, + fsys interface { + fs.FS + OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) + }, +) error { + file, err := fileheader.Open() + if err != nil { + return fmt.Errorf("failed to open multipart file: %w", err) + } + defer file.Close() //nolint:errcheck // not needed + + dst, err := fsys.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return fmt.Errorf("failed to open upload destination: %w", err) + } + writer, ok := dst.(io.Writer) + if !ok { + closeErr := dst.Close() + if closeErr != nil { + return fmt.Errorf("failed to close upload destination: %w", closeErr) + } + return errors.New("failed to open upload destination for write") + } + if _, err = io.Copy(writer, file); err != nil { + closeErr := dst.Close() + if closeErr != nil { + return fmt.Errorf("failed to close upload destination: %w", closeErr) + } + return fmt.Errorf("failed to copy upload data: %w", err) + } + if err := dst.Close(); err != nil { + return fmt.Errorf("failed to close upload destination: %w", err) + } + return nil +} + // Secure returns whether a secure connection was established. func (c *DefaultCtx) Secure() bool { return c.Protocol() == schemeHTTPS diff --git a/ctx_test.go b/ctx_test.go index 228ac6994b1..66aa3a41f0c 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -17,6 +17,7 @@ import ( "errors" "fmt" "io" + "io/fs" "math" "mime/multipart" "net" @@ -4744,25 +4745,20 @@ func Test_Ctx_RouteNormalized(t *testing.T) { func Test_Ctx_SaveFile(t *testing.T) { // TODO We should clean this up t.Parallel() - app := New() + rootDir := t.TempDir() + app := New(Config{RootDir: rootDir}) app.Post("/test", func(c Ctx) error { fh, err := c.Req().FormFile("file") require.NoError(t, err) - tempFile, err := os.CreateTemp(os.TempDir(), "test-") - require.NoError(t, err) - - defer func(file *os.File) { - closeErr := file.Close() - require.NoError(t, closeErr) - closeErr = os.Remove(file.Name()) - require.NoError(t, closeErr) - }(tempFile) - err = c.SaveFile(fh, tempFile.Name()) + relativePath := "upload.txt" + err = c.SaveFile(fh, relativePath) require.NoError(t, err) - bs, err := os.ReadFile(tempFile.Name()) + targetPath := filepath.Join(rootDir, relativePath) + // #nosec G304 -- reading from test-controlled temp directory. + bs, err := os.ReadFile(targetPath) require.NoError(t, err) require.Equal(t, "hello world", string(bs)) return nil @@ -4787,6 +4783,57 @@ func Test_Ctx_SaveFile(t *testing.T) { require.Equal(t, StatusOK, resp.StatusCode, "Status code") } +// go test -run Test_Ctx_SaveFile_RootDirTraversal +func Test_Ctx_SaveFile_RootDirTraversal(t *testing.T) { + t.Parallel() + + tests := map[string]string{ + "traversal": filepath.Join("..", "outside.txt"), + "absolute": filepath.Join(t.TempDir(), "abs.txt"), + } + + for name, targetPath := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + rootDir := t.TempDir() + app := New(Config{RootDir: rootDir}) + + app.Post("/test", func(c Ctx) error { + fh, err := c.FormFile("file") + require.NoError(t, err) + + err = c.SaveFile(fh, targetPath) + require.Error(t, err) + return c.SendStatus(StatusOK) + }) + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + ioWriter, err := writer.CreateFormFile("file", "test") + require.NoError(t, err) + _, err = ioWriter.Write([]byte("hello world")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(MethodPost, "/test", body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) + + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") + + expectedPath := targetPath + if !filepath.IsAbs(expectedPath) { + expectedPath = filepath.Join(rootDir, expectedPath) + } + expectedPath = filepath.Clean(expectedPath) + _, statErr := os.Stat(expectedPath) + require.Error(t, statErr) + }) + } +} + func createMultipartFileHeader(t *testing.T, filename string, data []byte) *multipart.FileHeader { t.Helper() @@ -4814,6 +4861,56 @@ func createMultipartFileHeader(t *testing.T, filename string, data []byte) *mult return files[0] } +type rootDirFS struct { + base string +} + +func (fsys rootDirFS) Open(name string) (fs.File, error) { + return os.Open(filepath.Join(fsys.base, filepath.FromSlash(name))) //nolint:wrapcheck // test helper passes temp paths directly. +} + +func (fsys rootDirFS) OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) { + fullPath := filepath.Join(fsys.base, filepath.FromSlash(name)) + if err := os.MkdirAll(filepath.Dir(fullPath), 0o750); err != nil { + return nil, fmt.Errorf("failed to create directory: %w", err) + } + return os.OpenFile(fullPath, flag, perm) //nolint:wrapcheck,gosec // test helper uses temp paths and returns os errors directly. +} + +func (fsys rootDirFS) MkdirAll(path string, perm fs.FileMode) error { + if err := os.MkdirAll(filepath.Join(fsys.base, filepath.FromSlash(path)), perm); err != nil { + return fmt.Errorf("failed to create root dir: %w", err) + } + return nil +} + +func (fsys rootDirFS) ReadDir(name string) ([]fs.DirEntry, error) { + return os.ReadDir(filepath.Join(fsys.base, filepath.FromSlash(name))) //nolint:wrapcheck // test helper returns os errors directly. +} + +func (fsys rootDirFS) Remove(name string) error { + return os.Remove(filepath.Join(fsys.base, filepath.FromSlash(name))) //nolint:wrapcheck // test helper returns os errors directly. +} + +type recordingStorage struct { + *memory.Storage + setKeys []string +} + +func newRecordingStorage() *recordingStorage { + return &recordingStorage{Storage: memory.New()} +} + +func (s *recordingStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + s.setKeys = append(s.setKeys, key) + return s.Storage.SetWithContext(ctx, key, val, exp) +} + +func (s *recordingStorage) Set(key string, val []byte, exp time.Duration) error { + s.setKeys = append(s.setKeys, key) + return s.Storage.Set(key, val, exp) +} + // go test -run Test_Ctx_SaveFileToStorage func Test_Ctx_SaveFileToStorage(t *testing.T) { t.Parallel() @@ -4856,6 +4953,104 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) { require.Equal(t, StatusOK, resp.StatusCode, "Status code") } +func Test_Ctx_SaveFileToStorage_RootFsPrefix(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + app := New(Config{ + RootDir: "uploads", + RootFs: rootDirFS{base: baseDir}, + }) + storage := newRecordingStorage() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs")) + + err := ctx.SaveFileToStorage(fileHeader, "rootfs.txt", storage) + require.NoError(t, err) + require.Equal(t, []string{"uploads/rootfs.txt"}, storage.setKeys) +} + +func Test_Ctx_SaveFileToStorage_RootFs_SymlinkEscape(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("symlink behavior differs on Windows") + } + + baseDir := t.TempDir() + uploadsDir := filepath.Join(baseDir, "uploads") + require.NoError(t, os.MkdirAll(uploadsDir, 0o750)) + require.NoError(t, os.Symlink(baseDir, filepath.Join(uploadsDir, "link"))) + + app := New(Config{ + RootDir: "uploads", + RootFs: rootDirFS{base: baseDir}, + }) + storage := newRecordingStorage() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs")) + + err := ctx.SaveFileToStorage(fileHeader, "link/rootfs.txt", storage) + require.ErrorIs(t, err, ErrUploadPathEscapesRoot) + require.Empty(t, storage.setKeys) +} + +// go test -run Test_Ctx_SaveFileToStorage_RootDirTraversal +func Test_Ctx_SaveFileToStorage_RootDirTraversal(t *testing.T) { + t.Parallel() + + tests := map[string]string{ + "traversal": filepath.Join("..", "outside"), + "absolute": filepath.Join(t.TempDir(), "abs"), + } + + for name, targetPath := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + rootDir := t.TempDir() + storage := newRecordingStorage() + app := New(Config{RootDir: rootDir}) + + app.Post("/test", func(c Ctx) error { + fh, err := c.FormFile("file") + require.NoError(t, err) + + err = c.SaveFileToStorage(fh, targetPath, storage) + require.Error(t, err) + return c.SendStatus(StatusOK) + }) + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + ioWriter, err := writer.CreateFormFile("file", "test") + require.NoError(t, err) + _, err = ioWriter.Write([]byte("hello world")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(MethodPost, "/test", body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) + + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") + + require.Empty(t, storage.setKeys) + }) + } +} + func Test_Ctx_SaveFileToStorage_LargeUpload(t *testing.T) { t.Parallel() const ( @@ -4926,6 +5121,85 @@ func Test_Ctx_SaveFileToStorage_LimitExceededUnknownSize(t *testing.T) { require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) } +func Test_Ctx_SaveFileToStorage_InvalidPath(t *testing.T) { + t.Parallel() + + app := New() + storage := newRecordingStorage() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello")) + + invalidPaths := []string{"", "..", "/absolute"} + if runtime.GOOS == "windows" { + invalidPaths = append(invalidPaths, `C:\absolute`) + } + + for _, path := range invalidPaths { + err := ctx.SaveFileToStorage(fileHeader, path, storage) + require.ErrorIs(t, err, ErrInvalidUploadPath) + } + + require.Empty(t, storage.setKeys) +} + +func Test_Ctx_SaveFile_RootFs(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + app := New(Config{ + RootDir: "uploads", + RootFs: rootDirFS{base: baseDir}, + }) + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs")) + + err := ctx.SaveFile(fileHeader, "rootfs.txt") + require.NoError(t, err) + + //nolint:gosec // reading from test-controlled temp directory. + content, err := os.ReadFile(filepath.Join(baseDir, "uploads", "rootfs.txt")) + require.NoError(t, err) + require.Equal(t, "hello rootfs", string(content)) +} + +func Test_Ctx_SaveFile_RootFs_SymlinkEscape(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("symlink behavior differs on Windows") + } + + baseDir := t.TempDir() + uploadsDir := filepath.Join(baseDir, "uploads") + require.NoError(t, os.MkdirAll(uploadsDir, 0o750)) + require.NoError(t, os.Symlink(baseDir, filepath.Join(uploadsDir, "link"))) + + app := New(Config{ + RootDir: "uploads", + RootFs: rootDirFS{base: baseDir}, + }) + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs")) + + err := ctx.SaveFile(fileHeader, "link/rootfs.txt") + require.ErrorIs(t, err, ErrUploadPathEscapesRoot) +} + type captureStorage struct { t *testing.T data map[string][]byte @@ -4998,6 +5272,46 @@ func (s *captureStorage) Close() error { return nil } +type errStorage struct { + err error +} + +func (s errStorage) GetWithContext(context.Context, string) ([]byte, error) { + return nil, s.err +} + +func (s errStorage) Get(string) ([]byte, error) { + return nil, s.err +} + +func (s errStorage) SetWithContext(context.Context, string, []byte, time.Duration) error { + return s.err +} + +func (s errStorage) Set(string, []byte, time.Duration) error { + return s.err +} + +func (s errStorage) DeleteWithContext(context.Context, string) error { + return s.err +} + +func (s errStorage) Delete(string) error { + return s.err +} + +func (s errStorage) ResetWithContext(context.Context) error { + return s.err +} + +func (s errStorage) Reset() error { + return s.err +} + +func (s errStorage) Close() error { + return s.err +} + func Test_Ctx_SaveFileToStorage_BufferNotReused(t *testing.T) { t.Parallel() @@ -5026,6 +5340,43 @@ func Test_Ctx_SaveFileToStorage_BufferNotReused(t *testing.T) { require.Equal(t, firstPayload, firstStored, "stored data must not rely on pooled buffers") } +func Test_Ctx_SaveFileToStorage_StorageError(t *testing.T) { + t.Parallel() + + app := New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello")) + expectedErr := errors.New("store failed") + + err := ctx.SaveFileToStorage(fileHeader, "test", errStorage{err: expectedErr}) + require.ErrorIs(t, err, expectedErr) +} + +func Test_Ctx_SaveFileToStorage_RootDirPrefix(t *testing.T) { + t.Parallel() + + rootDir := filepath.Join(t.TempDir(), "uploads") + app := New(Config{RootDir: rootDir}) + storage := newRecordingStorage() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + t.Cleanup(func() { + app.ReleaseCtx(ctx) + }) + + fileHeader := createMultipartFileHeader(t, "rootdir.txt", []byte("hello")) + + err := ctx.SaveFileToStorage(fileHeader, "rootdir.txt", storage) + require.NoError(t, err) + expectedPath := storageUploadPath(storageRootPrefix(rootDir), "rootdir.txt") + require.Equal(t, []string{expectedPath}, storage.setKeys) +} + type mockContextAwareStorage struct { t *testing.T key any diff --git a/docs/api/ctx.md b/docs/api/ctx.md index a4941254299..f35a544c655 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -1751,6 +1751,8 @@ Method is used to save **any** multipart file to disk. func (c fiber.Ctx) SaveFile(fh *multipart.FileHeader, path string) error ``` +Paths must be relative and cannot contain `..` segments or absolute prefixes. When `Config.RootDir` or `Config.RootFs` is set, the path is resolved against that root and attempts to escape it are rejected. Storage keys are prefixed by the configured root when one is set. + ```go title="Example" app.Post("/", func(c fiber.Ctx) error { // Parse the multipart form: @@ -1784,6 +1786,8 @@ Method is used to save **any** multipart file to an external storage system. func (c fiber.Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error ``` +Paths must be relative and cannot contain `..` segments or absolute prefixes. When `Config.RootDir` or `Config.RootFs` is set, the path is resolved against that root and attempts to escape it are rejected. + ```go title="Example" storage := memory.New() diff --git a/docs/api/fiber.md b/docs/api/fiber.md index 0b61b7115c1..d7deaea83df 100644 --- a/docs/api/fiber.md +++ b/docs/api/fiber.md @@ -76,6 +76,9 @@ app := fiber.New(fiber.Config{ | ReadTimeout | `time.Duration` | The amount of time allowed to read the full request, including the body. The default timeout is unlimited. | `0` | | ReduceMemoryUsage | `bool` | Aggressively reduces memory usage at the cost of higher CPU usage if set to true. | `false` | | RequestMethods | `[]string` | RequestMethods provides customizability for HTTP methods. You can add/remove methods as you wish. | `DefaultMethods` | +| RootDir | `string` | Base directory for SaveFile/SaveFileToStorage uploads. Relative paths are resolved against this directory and must not escape it. | `""` | +| RootPerms | `fs.FileMode` | Permissions used when creating RootDir or RootFs prefixes for uploads. | `0o750` | +| RootFs | `fs.FS` | Filesystem used for SaveFile/SaveFileToStorage uploads. When set, RootDir is treated as a relative prefix within the filesystem. | `nil` | | ServerHeader | `string` | Enables the `Server` HTTP header with the given value. | `""` | | StreamRequestBody | `bool` | StreamRequestBody enables request body streaming, and calls the handler sooner when given body is larger than the current limit. | `false` | | StrictRouting | `bool` | When enabled, the router treats `/foo` and `/foo/` as different. Otherwise, the router treats `/foo` and `/foo/` as the same. | `false` | diff --git a/docs/whats_new.md b/docs/whats_new.md index 275b4172833..2475b88a4df 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -77,6 +77,7 @@ We have made several changes to the Fiber app, including: - `ListenerNetwork` (previously `Network`) - **Trusted Proxy Configuration**: The `EnabledTrustedProxyCheck` has been moved to `app.Config.TrustProxy`, and `TrustedProxies` has been moved to `TrustProxyConfig.Proxies`. Additionally, `ProxyHeader` must be set to read client IPs from proxy headers (e.g., `X-Forwarded-For`). - **XMLDecoder Config Property**: The `XMLDecoder` property has been added to allow usage of 3rd-party XML libraries in XML binder. +- **Upload Root Permissions**: The `RootPerms` property controls the permissions used when creating `RootDir` or `RootFs` prefixes for uploads (default `0o750`). ### New Methods diff --git a/error.go b/error.go index 93c6f7cbcea..5fabaa50ca8 100644 --- a/error.go +++ b/error.go @@ -22,6 +22,10 @@ var ( ErrNoViewEngineConfigured = errors.New("fiber: no view engine configured") // ErrAutoCertWithCertFile indicates AutoCertManager cannot be used with CertFile/CertKeyFile. ErrAutoCertWithCertFile = errors.New("tls: AutoCertManager cannot be combined with CertFile/CertKeyFile") + // ErrInvalidUploadPath indicates the upload path is invalid. + ErrInvalidUploadPath = errors.New("upload: path must be relative and must not contain '..' or absolute prefixes") + // ErrUploadPathEscapesRoot indicates the upload path escapes the configured root. + ErrUploadPathEscapesRoot = errors.New("upload: path escapes the configured root") ) // Fiber redirection errors diff --git a/listen_test.go b/listen_test.go index e02f8d0d1df..73119aee62f 100644 --- a/listen_test.go +++ b/listen_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "log" //nolint:depguard // TODO: Required to capture output, use internal log package instead "net" "os" "path/filepath" @@ -21,6 +20,8 @@ import ( "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" "golang.org/x/crypto/acme/autocert" + + fiberlog "github.com/gofiber/fiber/v3/log" ) // go test -run Test_Listen @@ -876,11 +877,11 @@ func captureOutput(f func()) string { defer func() { os.Stdout = stdout os.Stderr = stderr - log.SetOutput(os.Stderr) + fiberlog.SetOutput(os.Stderr) }() os.Stdout = writer os.Stderr = writer - log.SetOutput(writer) + fiberlog.SetOutput(writer) out := make(chan string) go func() { var buf bytes.Buffer diff --git a/upload.go b/upload.go new file mode 100644 index 00000000000..11aef32cc50 --- /dev/null +++ b/upload.go @@ -0,0 +1,332 @@ +package fiber + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + pathpkg "path" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/gofiber/utils/v2" +) + +type uploadPath struct { + osPath string + slashPath string +} + +type tempProbeFile interface { + WriteString(string) (int, error) + Close() error + Name() string +} + +func (app *App) configureUploads() { + if app.config.RootFs != nil { + writer, ok := app.config.RootFs.(interface { + fs.FS + OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) + }) + if !ok { + panic("fiber: RootFs must implement OpenFile for uploads") + } + mkdirer, ok := app.config.RootFs.(interface { + MkdirAll(path string, perm fs.FileMode) error + }) + if !ok { + panic("fiber: RootFs must implement MkdirAll for uploads") + } + + prefix, err := cleanUploadRootPrefix(app.config.RootDir) + if err != nil { + panic(fmt.Sprintf("fiber: invalid RootDir for RootFs: %v", err)) + } + + if prefix != "" { + if err := mkdirer.MkdirAll(prefix, app.config.RootPerms); err != nil { + panic(fmt.Sprintf("fiber: failed to create RootFs prefix %q: %v", prefix, err)) + } + } + + if err := probeUploadFSWritable(writer, prefix); err != nil { + panic(fmt.Sprintf("fiber: RootFs not writable: %v", err)) + } + + app.config.uploadRootFSPrefix = prefix + app.config.uploadRootFSWriter = writer + app.config.uploadRootPath = prefix + return + } + + if app.config.RootDir == "" { + return + } + + rootAbs, err := filepath.Abs(app.config.RootDir) + if err != nil { + panic(fmt.Sprintf("fiber: failed to resolve RootDir: %v", err)) + } + rootAbs = filepath.Clean(rootAbs) + if err = os.MkdirAll(rootAbs, app.config.RootPerms); err != nil { + panic(fmt.Sprintf("fiber: failed to create RootDir %q: %v", rootAbs, err)) + } + rootEval, err := filepath.EvalSymlinks(rootAbs) + if err != nil { + panic(fmt.Sprintf("fiber: failed to resolve RootDir symlinks %q: %v", rootAbs, err)) + } + rootEval = filepath.Clean(rootEval) + if err := probeUploadDirWritable(rootAbs); err != nil { + panic(fmt.Sprintf("fiber: RootDir not writable %q: %v", rootAbs, err)) + } + + app.config.uploadRootDir = rootAbs + app.config.uploadRootEval = rootEval + app.config.uploadRootPath = storageRootPrefix(app.config.RootDir) +} + +func validateUploadPath(path string) (uploadPath, error) { + if path == "" { + return uploadPath{}, ErrInvalidUploadPath + } + if isAbsUploadPath(path) { + return uploadPath{}, ErrInvalidUploadPath + } + if containsDotDot(path) { + return uploadPath{}, ErrInvalidUploadPath + } + + cleanOS := filepath.Clean(path) + + cleanSlash := pathpkg.Clean("/" + filepath.ToSlash(path)) + cleanSlash = utils.TrimLeft(cleanSlash, '/') + + if cleanOS == "." || cleanOS == "" || cleanSlash == "." || cleanSlash == "" { + return uploadPath{}, ErrInvalidUploadPath + } + if !fs.ValidPath(cleanSlash) { + return uploadPath{}, ErrInvalidUploadPath + } + + return uploadPath{ + osPath: cleanOS, + slashPath: cleanSlash, + }, nil +} + +func isAbsUploadPath(path string) bool { + if filepath.IsAbs(path) { + return true + } + if filepath.VolumeName(path) != "" { + return true + } + return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "\\") +} + +func containsDotDot(path string) bool { + return slices.Contains(strings.FieldsFunc(path, func(r rune) bool { + return r == '/' || r == '\\' + }), "..") +} + +func storageRootPrefix(root string) string { + if root == "" { + return "" + } + root = filepath.Clean(root) + if volume := filepath.VolumeName(root); volume != "" { + root = root[len(volume):] + } + root = filepath.ToSlash(root) + root = utils.TrimLeft(root, '/') + if root == "." { + return "" + } + return root +} + +func cleanUploadRootPrefix(root string) (string, error) { + if root == "" || root == "." { + return "", nil + } + if isAbsUploadPath(root) || containsDotDot(root) { + return "", ErrInvalidUploadPath + } + cleanSlash := pathpkg.Clean("/" + filepath.ToSlash(root)) + cleanSlash = utils.TrimLeft(cleanSlash, '/') + if cleanSlash == "." || cleanSlash == "" { + return "", ErrInvalidUploadPath + } + if !fs.ValidPath(cleanSlash) { + return "", ErrInvalidUploadPath + } + return cleanSlash, nil +} + +func ensureUploadPathWithinRoot(rootEval, fullPath string) error { + parent := filepath.Dir(fullPath) + parentEval, err := evalExistingPath(parent) + if err != nil { + return err + } + if !hasPathPrefix(parentEval, rootEval) { + return ErrUploadPathEscapesRoot + } + return nil +} + +func evalExistingPath(path string) (string, error) { + current := path + for { + _, err := os.Lstat(current) + if err == nil { + resolved, resolveErr := filepath.EvalSymlinks(current) + if resolveErr != nil { + return "", fmt.Errorf("failed to resolve symlinks for %q: %w", current, resolveErr) + } + return filepath.Clean(resolved), nil + } + if !os.IsNotExist(err) { + return "", fmt.Errorf("failed to stat %q: %w", current, err) + } + parent := filepath.Dir(current) + if parent == current { + return "", fmt.Errorf("failed to resolve upload path %q: %w", current, err) + } + current = parent + } +} + +func hasPathPrefix(path, prefix string) bool { + if path == prefix { + return true + } + if !strings.HasPrefix(path, prefix) { + return false + } + if strings.HasSuffix(prefix, string(os.PathSeparator)) { + return true + } + if len(path) == len(prefix) { + return true + } + return len(path) > len(prefix) && path[len(prefix)] == os.PathSeparator +} + +func ensureNoSymlinkFS(fsys fs.FS, fullPath string) error { + parts := strings.Split(fullPath, "/") + if len(parts) <= 1 { + return nil + } + for i := 0; i < len(parts)-1; i++ { + name := parts[i] + if name == "" { + continue + } + parent := strings.Join(parts[:i], "/") + if parent == "" { + parent = "." + } + entries, err := fs.ReadDir(fsys, parent) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("failed to read upload directory %q: %w", parent, err) + } + for _, entry := range entries { + if entry.Name() == name && entry.Type()&fs.ModeSymlink != 0 { + return ErrUploadPathEscapesRoot + } + } + } + return nil +} + +func probeUploadDirWritable(root string) error { + return probeUploadDirWritableWith(root, func(dir, pattern string) (tempProbeFile, error) { + return os.CreateTemp(dir, pattern) + }, os.Remove) +} + +func probeUploadDirWritableWith( + root string, + createFile func(dir, pattern string) (tempProbeFile, error), + removeFile func(name string) error, +) error { + tempFile, err := createFile(root, ".fiber-upload-check-*") + if err != nil { + return fmt.Errorf("failed to create probe file: %w", err) + } + if _, err := tempFile.WriteString("fiber"); err != nil { + closeErr := tempFile.Close() + if closeErr != nil { + return fmt.Errorf("failed to close probe file: %w", closeErr) + } + removeErr := removeFile(tempFile.Name()) + if removeErr != nil { + return fmt.Errorf("failed to remove probe file: %w", removeErr) + } + return fmt.Errorf("failed to write probe file: %w", err) + } + if err := tempFile.Close(); err != nil { + removeErr := removeFile(tempFile.Name()) + if removeErr != nil { + return fmt.Errorf("failed to remove probe file: %w", removeErr) + } + return fmt.Errorf("failed to close probe file: %w", err) + } + if err := removeFile(tempFile.Name()); err != nil { + return fmt.Errorf("failed to remove probe file: %w", err) + } + return nil +} + +func probeUploadFSWritable(fsys interface { + fs.FS + OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) +}, prefix string, +) error { + name := pathpkg.Join(prefix, fmt.Sprintf(".fiber-upload-check-%d", time.Now().UnixNano())) + file, err := fsys.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600) + if err != nil { + return fmt.Errorf("failed to create probe file: %w", err) + } + writer, ok := file.(io.Writer) + if !ok { + if closeErr := file.Close(); closeErr != nil { + return fmt.Errorf("failed to close probe file: %w", closeErr) + } + return errors.New("upload file is not writable") + } + if _, err := writer.Write([]byte("fiber")); err != nil { + closeErr := file.Close() + if closeErr != nil { + return fmt.Errorf("failed to close probe file: %w", closeErr) + } + return fmt.Errorf("failed to write probe file: %w", err) + } + if err := file.Close(); err != nil { + return fmt.Errorf("failed to close probe file: %w", err) + } + if remover, ok := fsys.(interface { + Remove(name string) error + }); ok { + if err := remover.Remove(name); err != nil { + return fmt.Errorf("failed to remove probe file: %w", err) + } + } + return nil +} + +func storageUploadPath(prefix, cleanSlash string) string { + if prefix == "" { + return cleanSlash + } + return pathpkg.Join(prefix, cleanSlash) +} diff --git a/upload_test.go b/upload_test.go new file mode 100644 index 00000000000..d84d64b7282 --- /dev/null +++ b/upload_test.go @@ -0,0 +1,763 @@ +package fiber + +import ( + "errors" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "testing/fstest" + "time" +) + +type stubFileInfo struct { + name string +} + +func (fi stubFileInfo) Name() string { return fi.name } +func (fi stubFileInfo) Size() int64 { + _ = fi + return 0 +} + +func (fi stubFileInfo) Mode() fs.FileMode { + _ = fi + return 0 +} + +func (fi stubFileInfo) ModTime() time.Time { + _ = fi + return time.Time{} +} + +func (fi stubFileInfo) IsDir() bool { + _ = fi + return false +} + +func (fi stubFileInfo) Sys() any { + _ = fi + return nil +} + +type noWriteFile struct { + closeErr error +} + +func (f *noWriteFile) Read(_ []byte) (int, error) { + _ = f + return 0, io.EOF +} + +func (f *noWriteFile) Stat() (fs.FileInfo, error) { + _ = f + return stubFileInfo{name: "probe"}, nil +} +func (f *noWriteFile) Close() error { return f.closeErr } + +type writeFile struct { + writeErr error + closeErr error +} + +func (f *writeFile) Read(_ []byte) (int, error) { + _ = f + return 0, io.EOF +} + +func (f *writeFile) Stat() (fs.FileInfo, error) { + _ = f + return stubFileInfo{name: "probe"}, nil +} +func (f *writeFile) Close() error { return f.closeErr } +func (f *writeFile) Write(_ []byte) (int, error) { + if f.writeErr != nil { + return 0, f.writeErr + } + return len("fiber"), nil +} + +type probeFS struct { + file fs.File + openErr error +} + +func (fsys probeFS) Open(_ string) (fs.File, error) { + if fsys.openErr != nil { + return nil, fsys.openErr + } + if fsys.file == nil { + return nil, fs.ErrNotExist + } + return fsys.file, nil +} + +func (fsys probeFS) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) { + if fsys.openErr != nil { + return nil, fsys.openErr + } + return fsys.file, nil +} + +type removeErrFS struct { + probeFS + removeErr error +} + +func (fsys removeErrFS) Remove(_ string) error { + return fsys.removeErr +} + +type readDirErrFS struct { + err error +} + +type rootFSMissingOpenFile struct{} + +func (rootFSMissingOpenFile) Open(_ string) (fs.File, error) { return nil, fs.ErrNotExist } +func (rootFSMissingOpenFile) MkdirAll(_ string, _ fs.FileMode) error { + return nil +} + +type rootFSMissingMkdirAll struct{} + +func (rootFSMissingMkdirAll) Open(_ string) (fs.File, error) { return nil, fs.ErrNotExist } +func (rootFSMissingMkdirAll) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) { + return &writeFile{}, nil +} + +type rootFSMkdirErr struct { + err error +} + +func (fsys rootFSMkdirErr) Open(_ string) (fs.File, error) { + _ = fsys + return nil, fs.ErrNotExist +} + +func (fsys rootFSMkdirErr) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) { + _ = fsys + return &writeFile{}, nil +} +func (fsys rootFSMkdirErr) MkdirAll(_ string, _ fs.FileMode) error { return fsys.err } + +type rootFSNoWriter struct{} + +func (rootFSNoWriter) Open(_ string) (fs.File, error) { return nil, fs.ErrNotExist } +func (rootFSNoWriter) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) { + return &noWriteFile{}, nil +} +func (rootFSNoWriter) MkdirAll(_ string, _ fs.FileMode) error { return nil } + +type probeTempFile struct { + name *string + writeErr error + closeErr error +} + +func (f *probeTempFile) WriteString(_ string) (int, error) { + if f.writeErr != nil { + return 0, f.writeErr + } + return len("fiber"), nil +} + +func (f *probeTempFile) Close() error { + return f.closeErr +} + +func (f *probeTempFile) Name() string { + if f.name == nil { + return "" + } + return *f.name +} + +func stringPtr(value string) *string { + return &value +} + +func (fsys readDirErrFS) Open(_ string) (fs.File, error) { + _ = fsys + return nil, fs.ErrNotExist +} + +func (fsys readDirErrFS) ReadDir(_ string) ([]fs.DirEntry, error) { + return nil, fsys.err +} + +func TestValidateUploadPath(t *testing.T) { + t.Parallel() + + absPath := "/var/uploads/file.txt" + if runtime.GOOS == "windows" { + absPath = `C:\uploads\file.txt` + } + + tests := []struct { + name string + path string + wantErr bool + }{ + {name: "valid", path: "uploads/file.txt"}, + {name: "valid_dot_prefix", path: ".hidden/file.txt"}, + {name: "dot", path: ".", wantErr: true}, + {name: "empty", path: "", wantErr: true}, + {name: "absolute", path: absPath, wantErr: true}, + {name: "dotdot", path: "../file.txt", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + normalized, err := validateUploadPath(tt.path) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error for %q", tt.path) + } + return + } + if err != nil { + t.Fatalf("unexpected error for %q: %v", tt.path, err) + } + if normalized.osPath == "" || normalized.slashPath == "" { + t.Fatalf("expected normalized paths for %q", tt.path) + } + if strings.Contains(normalized.slashPath, `\`) { + t.Fatalf("expected slash path to use forward slashes, got %q", normalized.slashPath) + } + }) + } +} + +func TestIsAbsUploadPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + want bool + }{ + {name: "relative", path: "uploads/file.txt", want: false}, + {name: "slash_abs", path: "/uploads/file.txt", want: true}, + {name: "backslash_abs", path: `\uploads\file.txt`, want: true}, + } + + if runtime.GOOS == "windows" { + tests = append(tests, struct { + name string + path string + want bool + }{name: "volume_abs", path: `C:\uploads\file.txt`, want: true}) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := isAbsUploadPath(tt.path); got != tt.want { + t.Fatalf("expected %v for %q, got %v", tt.want, tt.path, got) + } + }) + } +} + +func TestContainsDotDot(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + want bool + }{ + {name: "dotdot_segment", path: "uploads/../file.txt", want: true}, + {name: "dotdot_prefix", path: "../file.txt", want: true}, + {name: "no_dotdot", path: "uploads/.../file.txt", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := containsDotDot(tt.path); got != tt.want { + t.Fatalf("expected %v for %q, got %v", tt.want, tt.path, got) + } + }) + } +} + +func TestStorageRootPrefix(t *testing.T) { + t.Parallel() + + root := "/var/uploads" + if runtime.GOOS == "windows" { + root = `C:\uploads` + } + + tests := []struct { + name string + root string + want string + }{ + {name: "empty", root: "", want: ""}, + {name: "dot", root: ".", want: ""}, + {name: "rooted", root: root, want: "var/uploads"}, + {name: "relative", root: "uploads", want: "uploads"}, + } + + if runtime.GOOS == "windows" { + tests[2].want = "uploads" + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := storageRootPrefix(tt.root); got != tt.want { + t.Fatalf("expected %q for %q, got %q", tt.want, tt.root, got) + } + }) + } +} + +func TestCleanUploadRootPrefix(t *testing.T) { + t.Parallel() + + abs := "/uploads" + if runtime.GOOS == "windows" { + abs = `C:\uploads` + } + + tests := []struct { + name string + root string + want string + wantErr bool + }{ + {name: "empty", root: "", want: ""}, + {name: "dot", root: ".", want: ""}, + {name: "valid", root: "uploads", want: "uploads"}, + {name: "abs", root: abs, wantErr: true}, + {name: "dotdot", root: "../uploads", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := cleanUploadRootPrefix(tt.root) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error for %q", tt.root) + } + return + } + if err != nil { + t.Fatalf("unexpected error for %q: %v", tt.root, err) + } + if got != tt.want { + t.Fatalf("expected %q for %q, got %q", tt.want, tt.root, got) + } + }) + } +} + +func TestEnsureUploadPathWithinRoot(t *testing.T) { + t.Parallel() + + root := t.TempDir() + rootEval, err := filepath.EvalSymlinks(root) + if err != nil { + t.Fatalf("failed to eval root: %v", err) + } + + inside := filepath.Join(root, "uploads", "file.txt") + if err := ensureUploadPathWithinRoot(rootEval, inside); err != nil { + t.Fatalf("expected inside path to be allowed, got %v", err) + } + + outside := filepath.Join(t.TempDir(), "file.txt") + if err := ensureUploadPathWithinRoot(rootEval, outside); !errors.Is(err, ErrUploadPathEscapesRoot) { + t.Fatalf("expected ErrUploadPathEscapesRoot, got %v", err) + } +} + +func TestEvalExistingPath(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + t.Skip("path resolution differs on this platform") + } + + root := t.TempDir() + target := filepath.Join(root, "missing", "file.txt") + got, err := evalExistingPath(target) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != filepath.Clean(root) { + t.Fatalf("expected %q, got %q", filepath.Clean(root), got) + } +} + +func TestHasPathPrefix(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + prefix string + want bool + }{ + {name: "same", path: filepath.Join("a", "b"), prefix: filepath.Join("a", "b"), want: true}, + {name: "child", path: filepath.Join("a", "b", "c"), prefix: filepath.Join("a", "b"), want: true}, + {name: "sibling", path: filepath.Join("a", "bc"), prefix: filepath.Join("a", "b"), want: false}, + {name: "prefix_with_separator", path: filepath.Join("a", "b", "c"), prefix: filepath.Join("a", "b") + string(os.PathSeparator), want: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := hasPathPrefix(tt.path, tt.prefix); got != tt.want { + t.Fatalf("expected %v for %q/%q, got %v", tt.want, tt.path, tt.prefix, got) + } + }) + } +} + +func TestEnsureNoSymlinkFS(t *testing.T) { + t.Parallel() + + t.Run("missing_parent", func(t *testing.T) { + t.Parallel() + + fsys := fstest.MapFS{} + if err := ensureNoSymlinkFS(fsys, "missing/file.txt"); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("symlink_detected", func(t *testing.T) { + t.Parallel() + + fsys := fstest.MapFS{ + "dir": &fstest.MapFile{Mode: fs.ModeDir}, + "dir/link": &fstest.MapFile{Mode: fs.ModeSymlink}, + } + if err := ensureNoSymlinkFS(fsys, "dir/link/file.txt"); !errors.Is(err, ErrUploadPathEscapesRoot) { + t.Fatalf("expected ErrUploadPathEscapesRoot, got %v", err) + } + }) + + t.Run("read_dir_error", func(t *testing.T) { + t.Parallel() + + fsys := readDirErrFS{err: errors.New("read failure")} + if err := ensureNoSymlinkFS(fsys, "dir/file.txt"); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("no_symlink", func(t *testing.T) { + t.Parallel() + + fsys := fstest.MapFS{ + "dir": &fstest.MapFile{Mode: fs.ModeDir}, + "dir/child": &fstest.MapFile{Mode: fs.ModeDir}, + } + if err := ensureNoSymlinkFS(fsys, "dir/child/file.txt"); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("single_part_path", func(t *testing.T) { + t.Parallel() + + if err := ensureNoSymlinkFS(fstest.MapFS{}, "file.txt"); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("leading_slash", func(t *testing.T) { + t.Parallel() + + if err := ensureNoSymlinkFS(fstest.MapFS{}, "/file.txt"); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) +} + +func TestProbeUploadDirWritable(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("probe behavior differs on Windows") + } + + t.Run("success", func(t *testing.T) { + t.Parallel() + + if err := probeUploadDirWritable(t.TempDir()); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("not_directory", func(t *testing.T) { + t.Parallel() + + file, err := os.CreateTemp(t.TempDir(), "probe") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + t.Cleanup(func() { + if removeErr := os.Remove(file.Name()); removeErr != nil { + t.Fatalf("failed to remove temp file: %v", removeErr) + } + }) + if err := probeUploadDirWritable(file.Name()); err == nil { + t.Fatal("expected error for file path") + } + }) + + t.Run("create_error", func(t *testing.T) { + t.Parallel() + + createErr := errors.New("create failure") + err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) { + return nil, createErr + }, func(string) error { return nil }) + if !errors.Is(err, createErr) { + t.Fatalf("expected create error, got %v", err) + } + }) + + t.Run("write_error", func(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failure") + err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) { + return &probeTempFile{name: stringPtr("probe"), writeErr: writeErr}, nil + }, func(string) error { return nil }) + if !errors.Is(err, writeErr) { + t.Fatalf("expected write error, got %v", err) + } + }) + + t.Run("close_error_after_write", func(t *testing.T) { + t.Parallel() + + closeErr := errors.New("close failure") + err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) { + return &probeTempFile{name: stringPtr("probe"), closeErr: closeErr}, nil + }, func(string) error { return nil }) + if !errors.Is(err, closeErr) { + t.Fatalf("expected close error, got %v", err) + } + }) + + t.Run("remove_error_after_write_failure", func(t *testing.T) { + t.Parallel() + + removeErr := errors.New("remove failure") + err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) { + return &probeTempFile{name: stringPtr("probe"), writeErr: errors.New("write failure")}, nil + }, func(string) error { return removeErr }) + if !errors.Is(err, removeErr) { + t.Fatalf("expected remove error, got %v", err) + } + }) + + t.Run("remove_error_after_close_failure", func(t *testing.T) { + t.Parallel() + + removeErr := errors.New("remove failure") + err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) { + return &probeTempFile{name: stringPtr("probe"), closeErr: errors.New("close failure")}, nil + }, func(string) error { return removeErr }) + if !errors.Is(err, removeErr) { + t.Fatalf("expected remove error, got %v", err) + } + }) + + t.Run("remove_error_after_success", func(t *testing.T) { + t.Parallel() + + removeErr := errors.New("remove failure") + err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) { + return &probeTempFile{name: stringPtr("probe")}, nil + }, func(string) error { return removeErr }) + if !errors.Is(err, removeErr) { + t.Fatalf("expected remove error, got %v", err) + } + }) +} + +func TestConfigureUploads(t *testing.T) { + t.Parallel() + + t.Run("root_dir", func(t *testing.T) { + t.Parallel() + + root := t.TempDir() + app := New(Config{RootDir: root}) + if app.config.uploadRootDir == "" { + t.Fatal("expected uploadRootDir to be set") + } + if app.config.uploadRootPath == "" { + t.Fatal("expected uploadRootPath to be set") + } + }) + + t.Run("missing_openfile", func(t *testing.T) { + t.Parallel() + + assertPanics(t, func() { + New(Config{RootDir: "uploads", RootFs: rootFSMissingOpenFile{}}) + }) + }) + + t.Run("missing_mkdirall", func(t *testing.T) { + t.Parallel() + + assertPanics(t, func() { + New(Config{RootDir: "uploads", RootFs: rootFSMissingMkdirAll{}}) + }) + }) + + t.Run("invalid_rootdir_for_rootfs", func(t *testing.T) { + t.Parallel() + + assertPanics(t, func() { + New(Config{RootDir: filepath.Join(t.TempDir(), "uploads"), RootFs: rootFSNoWriter{}}) + }) + }) + + t.Run("mkdir_error", func(t *testing.T) { + t.Parallel() + + assertPanics(t, func() { + New(Config{ + RootDir: "uploads", + RootFs: rootFSMkdirErr{err: errors.New("mkdir failure")}, + }) + }) + }) + + t.Run("not_writable", func(t *testing.T) { + t.Parallel() + + assertPanics(t, func() { + New(Config{RootDir: "uploads", RootFs: rootFSNoWriter{}}) + }) + }) +} + +func TestEvalExistingPathError(t *testing.T) { + t.Parallel() + + _, err := evalExistingPath("invalid\x00path") + if err == nil { + t.Fatal("expected error") + } +} + +func assertPanics(t *testing.T, fn func()) { + t.Helper() + + defer func() { + if recover() == nil { + t.Fatal("expected panic") + } + }() + fn() +} + +func TestProbeUploadFSWritable(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + fsys := probeFS{file: &writeFile{}} + if err := probeUploadFSWritable(fsys, ""); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + }) + + t.Run("open_error", func(t *testing.T) { + t.Parallel() + + fsys := probeFS{openErr: errors.New("open failure")} + if err := probeUploadFSWritable(fsys, ""); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("not_writable", func(t *testing.T) { + t.Parallel() + + fsys := probeFS{file: &noWriteFile{}} + if err := probeUploadFSWritable(fsys, ""); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("write_error", func(t *testing.T) { + t.Parallel() + + fsys := probeFS{file: &writeFile{writeErr: errors.New("write failure")}} + if err := probeUploadFSWritable(fsys, ""); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("close_error", func(t *testing.T) { + t.Parallel() + + fsys := probeFS{file: &writeFile{closeErr: errors.New("close failure")}} + if err := probeUploadFSWritable(fsys, ""); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("remove_error", func(t *testing.T) { + t.Parallel() + + fsys := removeErrFS{ + probeFS: probeFS{file: &writeFile{}}, + removeErr: errors.New("remove failure"), + } + if err := probeUploadFSWritable(fsys, ""); err == nil { + t.Fatal("expected error") + } + }) +} + +func TestStorageUploadPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prefix string + path string + want string + }{ + {name: "empty_prefix", prefix: "", path: "uploads/file.txt", want: "uploads/file.txt"}, + {name: "with_prefix", prefix: "root", path: "uploads/file.txt", want: "root/uploads/file.txt"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := storageUploadPath(tt.prefix, tt.path); got != tt.want { + t.Fatalf("expected %q, got %q", tt.want, got) + } + }) + } +}