diff --git a/app.go b/app.go
index 722277ff55d..de5578f9a8e 100644
--- a/app.go
+++ b/app.go
@@ -15,6 +15,7 @@ import (
"errors"
"fmt"
"io"
+ "io/fs"
"net"
"net/http"
"net/http/httputil"
@@ -175,6 +176,23 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa
// Default: nil
Views Views `json:"-"`
+ // RootDir specifies the base directory for SaveFile/SaveFileToStorage uploads.
+ // Relative paths are resolved against this directory.
+ //
+ // Optional. Default: ""
+ RootDir string `json:"root_dir"`
+
+ // RootPerms specifies the permissions used when creating RootDir or RootFs prefixes.
+ //
+ // Optional. Default: 0o750
+ RootPerms fs.FileMode `json:"root_perms"`
+
+ // RootFs specifies the filesystem used for SaveFile/SaveFileToStorage uploads.
+ // When set, RootDir is treated as a relative prefix within the filesystem.
+ //
+ // Optional. Default: nil
+ RootFs fs.FS `json:"-"`
+
// Views Layout is the global layout for all template render until override on Render function.
//
// Default: ""
@@ -437,6 +455,15 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa
//
// Optional. Default: a provider that returns context.Background()
ServicesShutdownContextProvider func() context.Context
+
+ uploadRootDir string
+ uploadRootEval string
+ uploadRootPath string
+ uploadRootFSPrefix string
+ uploadRootFSWriter interface {
+ fs.FS
+ OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
+ }
}
// Default TrustProxyConfig
@@ -605,6 +632,9 @@ func New(config ...Config) *App {
"zstd": ".fiber.zst",
}
}
+ if app.config.RootPerms == 0 {
+ app.config.RootPerms = 0o750
+ }
if app.config.Immutable {
app.toBytes, app.toString = toBytesImmutable, toStringImmutable
@@ -642,6 +672,8 @@ func New(config ...Config) *App {
app.config.RequestMethods = DefaultMethods
}
+ app.configureUploads()
+
app.config.TrustProxyConfig.ips = make(map[string]struct{}, len(app.config.TrustProxyConfig.Proxies))
for _, ipAddress := range app.config.TrustProxyConfig.Proxies {
app.handleTrustedProxy(ipAddress)
diff --git a/app_test.go b/app_test.go
index ed6b0780f4f..bcad64c3ad3 100644
--- a/app_test.go
+++ b/app_test.go
@@ -13,6 +13,7 @@ import (
"errors"
"fmt"
"io"
+ "io/fs"
"mime/multipart"
"net"
"net/http"
@@ -81,6 +82,149 @@ func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBod
require.Equal(t, expectedBodyError, string(body), "Response body")
}
+type testUploadFS struct {
+ mkdirPath string
+ mkdirPerm fs.FileMode
+}
+
+func (tfs *testUploadFS) Open(_ string) (fs.File, error) {
+ _ = tfs
+ return nil, fs.ErrNotExist
+}
+
+func (tfs *testUploadFS) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) {
+ _ = tfs
+ return &testUploadFile{buf: &bytes.Buffer{}}, nil
+}
+
+func (tfs *testUploadFS) MkdirAll(path string, perm fs.FileMode) error {
+ tfs.mkdirPath = path
+ tfs.mkdirPerm = perm
+ return nil
+}
+
+func (tfs *testUploadFS) Remove(_ string) error {
+ _ = tfs
+ return nil
+}
+
+type testUploadFile struct {
+ buf *bytes.Buffer
+}
+
+func (tf *testUploadFile) Read(p []byte) (int, error) {
+ //nolint:wrapcheck // test helper passthrough
+ return tf.buf.Read(p)
+}
+
+func (tf *testUploadFile) Write(p []byte) (int, error) {
+ //nolint:wrapcheck // test helper passthrough
+ return tf.buf.Write(p)
+}
+
+func (tf *testUploadFile) Close() error {
+ _ = tf
+ return nil
+}
+
+func (tf *testUploadFile) Stat() (fs.FileInfo, error) {
+ _ = tf
+ return testUploadFileInfo{name: "upload"}, nil
+}
+
+type testUploadFileInfo struct {
+ name string
+}
+
+func (fi testUploadFileInfo) Name() string { return fi.name }
+func (fi testUploadFileInfo) Size() int64 {
+ _ = fi
+ return 0
+}
+
+func (fi testUploadFileInfo) Mode() fs.FileMode {
+ _ = fi
+ return 0
+}
+
+func (fi testUploadFileInfo) ModTime() time.Time {
+ _ = fi
+ return time.Time{}
+}
+
+func (fi testUploadFileInfo) IsDir() bool {
+ _ = fi
+ return false
+}
+
+func (fi testUploadFileInfo) Sys() any {
+ _ = fi
+ return nil
+}
+
+func TestRootPermsRootFs(t *testing.T) {
+ t.Parallel()
+
+ if runtime.GOOS == "windows" {
+ t.Skip("root perms are not validated on Windows in this test")
+ }
+
+ tests := []struct {
+ name string
+ rootPerm fs.FileMode
+ wantPerm fs.FileMode
+ }{
+ {
+ name: "default",
+ rootPerm: 0,
+ wantPerm: 0o750,
+ },
+ {
+ name: "custom",
+ rootPerm: 0o700,
+ wantPerm: 0o700,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ tfs := &testUploadFS{}
+ New(Config{
+ RootDir: "uploads",
+ RootFs: tfs,
+ RootPerms: tt.rootPerm,
+ })
+
+ if tfs.mkdirPath != "uploads" {
+ t.Fatalf("expected RootFs prefix %q, got %q", "uploads", tfs.mkdirPath)
+ }
+ if tfs.mkdirPerm != tt.wantPerm {
+ t.Fatalf("expected RootPerms %o, got %o", tt.wantPerm, tfs.mkdirPerm)
+ }
+ })
+ }
+}
+
+func TestValidateUploadPathPreservesLeadingDot(t *testing.T) {
+ t.Parallel()
+
+ path := filepath.Join(".hidden", "file.txt")
+
+ normalized, err := validateUploadPath(path)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if !strings.HasPrefix(normalized.osPath, ".") {
+ t.Fatalf("expected os path %q to preserve leading dot", normalized.osPath)
+ }
+ if normalized.slashPath != ".hidden/file.txt" {
+ t.Fatalf("expected slash path %q, got %q", ".hidden/file.txt", normalized.slashPath)
+ }
+}
+
func Test_App_Test_Goroutine_Leak_Compare(t *testing.T) {
t.Parallel()
diff --git a/ctx.go b/ctx.go
index 92ce2f69591..761a30700a9 100644
--- a/ctx.go
+++ b/ctx.go
@@ -7,10 +7,14 @@ package fiber
import (
"context"
"crypto/tls"
+ "errors"
"fmt"
"io"
+ "io/fs"
"maps"
"mime/multipart"
+ "os"
+ "path/filepath"
"strconv"
"strings"
"sync/atomic"
@@ -511,12 +515,57 @@ func (c *DefaultCtx) IsPreflight() bool {
}
// SaveFile saves any multipart file to disk.
-func (*DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error {
- return fasthttp.SaveMultipartFile(fileheader, path)
+func (c *DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error {
+ normalized, err := validateUploadPath(path)
+ if err != nil {
+ return err
+ }
+
+ if c.app.config.RootFs != nil {
+ fsPath := storageUploadPath(c.app.config.uploadRootFSPrefix, normalized.slashPath)
+ err = ensureNoSymlinkFS(c.app.config.RootFs, fsPath)
+ if err != nil {
+ return err
+ }
+ return saveMultipartFileToFS(fileheader, fsPath, c.app.config.uploadRootFSWriter)
+ }
+
+ fullPath := normalized.osPath
+ if root := c.app.config.uploadRootDir; root != "" {
+ fullPath = filepath.Join(root, normalized.osPath)
+ err = ensureUploadPathWithinRoot(c.app.config.uploadRootEval, fullPath)
+ if err != nil {
+ return err
+ }
+ }
+
+ return fasthttp.SaveMultipartFile(fileheader, fullPath)
}
// SaveFileToStorage saves any multipart file to an external storage system.
func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
+ normalized, err := validateUploadPath(path)
+ if err != nil {
+ return err
+ }
+
+ if c.app.config.RootFs != nil {
+ fsPath := storageUploadPath(c.app.config.uploadRootFSPrefix, normalized.slashPath)
+ err = ensureNoSymlinkFS(c.app.config.RootFs, fsPath)
+ if err != nil {
+ return err
+ }
+ }
+ if root := c.app.config.uploadRootDir; root != "" {
+ fullPath := filepath.Join(root, filepath.FromSlash(normalized.slashPath))
+ err = ensureUploadPathWithinRoot(c.app.config.uploadRootEval, fullPath)
+ if err != nil {
+ return err
+ }
+ }
+
+ storagePath := storageUploadPath(c.app.config.uploadRootPath, normalized.slashPath)
+
file, err := fileheader.Open()
if err != nil {
return fmt.Errorf("failed to open: %w", err)
@@ -546,13 +595,52 @@ func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path st
data := append([]byte(nil), buf.Bytes()...)
- if err := storage.SetWithContext(c.Context(), path, data, 0); err != nil {
+ if err := storage.SetWithContext(c.Context(), storagePath, data, 0); err != nil {
return fmt.Errorf("failed to store: %w", err)
}
return nil
}
+func saveMultipartFileToFS(
+ fileheader *multipart.FileHeader,
+ path string,
+ fsys interface {
+ fs.FS
+ OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
+ },
+) error {
+ file, err := fileheader.Open()
+ if err != nil {
+ return fmt.Errorf("failed to open multipart file: %w", err)
+ }
+ defer file.Close() //nolint:errcheck // not needed
+
+ dst, err := fsys.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
+ if err != nil {
+ return fmt.Errorf("failed to open upload destination: %w", err)
+ }
+ writer, ok := dst.(io.Writer)
+ if !ok {
+ closeErr := dst.Close()
+ if closeErr != nil {
+ return fmt.Errorf("failed to close upload destination: %w", closeErr)
+ }
+ return errors.New("failed to open upload destination for write")
+ }
+ if _, err = io.Copy(writer, file); err != nil {
+ closeErr := dst.Close()
+ if closeErr != nil {
+ return fmt.Errorf("failed to close upload destination: %w", closeErr)
+ }
+ return fmt.Errorf("failed to copy upload data: %w", err)
+ }
+ if err := dst.Close(); err != nil {
+ return fmt.Errorf("failed to close upload destination: %w", err)
+ }
+ return nil
+}
+
// Secure returns whether a secure connection was established.
func (c *DefaultCtx) Secure() bool {
return c.Protocol() == schemeHTTPS
diff --git a/ctx_test.go b/ctx_test.go
index 228ac6994b1..66aa3a41f0c 100644
--- a/ctx_test.go
+++ b/ctx_test.go
@@ -17,6 +17,7 @@ import (
"errors"
"fmt"
"io"
+ "io/fs"
"math"
"mime/multipart"
"net"
@@ -4744,25 +4745,20 @@ func Test_Ctx_RouteNormalized(t *testing.T) {
func Test_Ctx_SaveFile(t *testing.T) {
// TODO We should clean this up
t.Parallel()
- app := New()
+ rootDir := t.TempDir()
+ app := New(Config{RootDir: rootDir})
app.Post("/test", func(c Ctx) error {
fh, err := c.Req().FormFile("file")
require.NoError(t, err)
- tempFile, err := os.CreateTemp(os.TempDir(), "test-")
- require.NoError(t, err)
-
- defer func(file *os.File) {
- closeErr := file.Close()
- require.NoError(t, closeErr)
- closeErr = os.Remove(file.Name())
- require.NoError(t, closeErr)
- }(tempFile)
- err = c.SaveFile(fh, tempFile.Name())
+ relativePath := "upload.txt"
+ err = c.SaveFile(fh, relativePath)
require.NoError(t, err)
- bs, err := os.ReadFile(tempFile.Name())
+ targetPath := filepath.Join(rootDir, relativePath)
+ // #nosec G304 -- reading from test-controlled temp directory.
+ bs, err := os.ReadFile(targetPath)
require.NoError(t, err)
require.Equal(t, "hello world", string(bs))
return nil
@@ -4787,6 +4783,57 @@ func Test_Ctx_SaveFile(t *testing.T) {
require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}
+// go test -run Test_Ctx_SaveFile_RootDirTraversal
+func Test_Ctx_SaveFile_RootDirTraversal(t *testing.T) {
+ t.Parallel()
+
+ tests := map[string]string{
+ "traversal": filepath.Join("..", "outside.txt"),
+ "absolute": filepath.Join(t.TempDir(), "abs.txt"),
+ }
+
+ for name, targetPath := range tests {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ rootDir := t.TempDir()
+ app := New(Config{RootDir: rootDir})
+
+ app.Post("/test", func(c Ctx) error {
+ fh, err := c.FormFile("file")
+ require.NoError(t, err)
+
+ err = c.SaveFile(fh, targetPath)
+ require.Error(t, err)
+ return c.SendStatus(StatusOK)
+ })
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ ioWriter, err := writer.CreateFormFile("file", "test")
+ require.NoError(t, err)
+ _, err = ioWriter.Write([]byte("hello world"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(MethodPost, "/test", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
+
+ resp, err := app.Test(req)
+ require.NoError(t, err, "app.Test(req)")
+ require.Equal(t, StatusOK, resp.StatusCode, "Status code")
+
+ expectedPath := targetPath
+ if !filepath.IsAbs(expectedPath) {
+ expectedPath = filepath.Join(rootDir, expectedPath)
+ }
+ expectedPath = filepath.Clean(expectedPath)
+ _, statErr := os.Stat(expectedPath)
+ require.Error(t, statErr)
+ })
+ }
+}
+
func createMultipartFileHeader(t *testing.T, filename string, data []byte) *multipart.FileHeader {
t.Helper()
@@ -4814,6 +4861,56 @@ func createMultipartFileHeader(t *testing.T, filename string, data []byte) *mult
return files[0]
}
+type rootDirFS struct {
+ base string
+}
+
+func (fsys rootDirFS) Open(name string) (fs.File, error) {
+ return os.Open(filepath.Join(fsys.base, filepath.FromSlash(name))) //nolint:wrapcheck // test helper passes temp paths directly.
+}
+
+func (fsys rootDirFS) OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error) {
+ fullPath := filepath.Join(fsys.base, filepath.FromSlash(name))
+ if err := os.MkdirAll(filepath.Dir(fullPath), 0o750); err != nil {
+ return nil, fmt.Errorf("failed to create directory: %w", err)
+ }
+ return os.OpenFile(fullPath, flag, perm) //nolint:wrapcheck,gosec // test helper uses temp paths and returns os errors directly.
+}
+
+func (fsys rootDirFS) MkdirAll(path string, perm fs.FileMode) error {
+ if err := os.MkdirAll(filepath.Join(fsys.base, filepath.FromSlash(path)), perm); err != nil {
+ return fmt.Errorf("failed to create root dir: %w", err)
+ }
+ return nil
+}
+
+func (fsys rootDirFS) ReadDir(name string) ([]fs.DirEntry, error) {
+ return os.ReadDir(filepath.Join(fsys.base, filepath.FromSlash(name))) //nolint:wrapcheck // test helper returns os errors directly.
+}
+
+func (fsys rootDirFS) Remove(name string) error {
+ return os.Remove(filepath.Join(fsys.base, filepath.FromSlash(name))) //nolint:wrapcheck // test helper returns os errors directly.
+}
+
+type recordingStorage struct {
+ *memory.Storage
+ setKeys []string
+}
+
+func newRecordingStorage() *recordingStorage {
+ return &recordingStorage{Storage: memory.New()}
+}
+
+func (s *recordingStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
+ s.setKeys = append(s.setKeys, key)
+ return s.Storage.SetWithContext(ctx, key, val, exp)
+}
+
+func (s *recordingStorage) Set(key string, val []byte, exp time.Duration) error {
+ s.setKeys = append(s.setKeys, key)
+ return s.Storage.Set(key, val, exp)
+}
+
// go test -run Test_Ctx_SaveFileToStorage
func Test_Ctx_SaveFileToStorage(t *testing.T) {
t.Parallel()
@@ -4856,6 +4953,104 @@ func Test_Ctx_SaveFileToStorage(t *testing.T) {
require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}
+func Test_Ctx_SaveFileToStorage_RootFsPrefix(t *testing.T) {
+ t.Parallel()
+
+ baseDir := t.TempDir()
+ app := New(Config{
+ RootDir: "uploads",
+ RootFs: rootDirFS{base: baseDir},
+ })
+ storage := newRecordingStorage()
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs"))
+
+ err := ctx.SaveFileToStorage(fileHeader, "rootfs.txt", storage)
+ require.NoError(t, err)
+ require.Equal(t, []string{"uploads/rootfs.txt"}, storage.setKeys)
+}
+
+func Test_Ctx_SaveFileToStorage_RootFs_SymlinkEscape(t *testing.T) {
+ t.Parallel()
+
+ if runtime.GOOS == "windows" {
+ t.Skip("symlink behavior differs on Windows")
+ }
+
+ baseDir := t.TempDir()
+ uploadsDir := filepath.Join(baseDir, "uploads")
+ require.NoError(t, os.MkdirAll(uploadsDir, 0o750))
+ require.NoError(t, os.Symlink(baseDir, filepath.Join(uploadsDir, "link")))
+
+ app := New(Config{
+ RootDir: "uploads",
+ RootFs: rootDirFS{base: baseDir},
+ })
+ storage := newRecordingStorage()
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs"))
+
+ err := ctx.SaveFileToStorage(fileHeader, "link/rootfs.txt", storage)
+ require.ErrorIs(t, err, ErrUploadPathEscapesRoot)
+ require.Empty(t, storage.setKeys)
+}
+
+// go test -run Test_Ctx_SaveFileToStorage_RootDirTraversal
+func Test_Ctx_SaveFileToStorage_RootDirTraversal(t *testing.T) {
+ t.Parallel()
+
+ tests := map[string]string{
+ "traversal": filepath.Join("..", "outside"),
+ "absolute": filepath.Join(t.TempDir(), "abs"),
+ }
+
+ for name, targetPath := range tests {
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+ rootDir := t.TempDir()
+ storage := newRecordingStorage()
+ app := New(Config{RootDir: rootDir})
+
+ app.Post("/test", func(c Ctx) error {
+ fh, err := c.FormFile("file")
+ require.NoError(t, err)
+
+ err = c.SaveFileToStorage(fh, targetPath, storage)
+ require.Error(t, err)
+ return c.SendStatus(StatusOK)
+ })
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ ioWriter, err := writer.CreateFormFile("file", "test")
+ require.NoError(t, err)
+ _, err = ioWriter.Write([]byte("hello world"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(MethodPost, "/test", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))
+
+ resp, err := app.Test(req)
+ require.NoError(t, err, "app.Test(req)")
+ require.Equal(t, StatusOK, resp.StatusCode, "Status code")
+
+ require.Empty(t, storage.setKeys)
+ })
+ }
+}
+
func Test_Ctx_SaveFileToStorage_LargeUpload(t *testing.T) {
t.Parallel()
const (
@@ -4926,6 +5121,85 @@ func Test_Ctx_SaveFileToStorage_LimitExceededUnknownSize(t *testing.T) {
require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge)
}
+func Test_Ctx_SaveFileToStorage_InvalidPath(t *testing.T) {
+ t.Parallel()
+
+ app := New()
+ storage := newRecordingStorage()
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello"))
+
+ invalidPaths := []string{"", "..", "/absolute"}
+ if runtime.GOOS == "windows" {
+ invalidPaths = append(invalidPaths, `C:\absolute`)
+ }
+
+ for _, path := range invalidPaths {
+ err := ctx.SaveFileToStorage(fileHeader, path, storage)
+ require.ErrorIs(t, err, ErrInvalidUploadPath)
+ }
+
+ require.Empty(t, storage.setKeys)
+}
+
+func Test_Ctx_SaveFile_RootFs(t *testing.T) {
+ t.Parallel()
+
+ baseDir := t.TempDir()
+ app := New(Config{
+ RootDir: "uploads",
+ RootFs: rootDirFS{base: baseDir},
+ })
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs"))
+
+ err := ctx.SaveFile(fileHeader, "rootfs.txt")
+ require.NoError(t, err)
+
+ //nolint:gosec // reading from test-controlled temp directory.
+ content, err := os.ReadFile(filepath.Join(baseDir, "uploads", "rootfs.txt"))
+ require.NoError(t, err)
+ require.Equal(t, "hello rootfs", string(content))
+}
+
+func Test_Ctx_SaveFile_RootFs_SymlinkEscape(t *testing.T) {
+ t.Parallel()
+
+ if runtime.GOOS == "windows" {
+ t.Skip("symlink behavior differs on Windows")
+ }
+
+ baseDir := t.TempDir()
+ uploadsDir := filepath.Join(baseDir, "uploads")
+ require.NoError(t, os.MkdirAll(uploadsDir, 0o750))
+ require.NoError(t, os.Symlink(baseDir, filepath.Join(uploadsDir, "link")))
+
+ app := New(Config{
+ RootDir: "uploads",
+ RootFs: rootDirFS{base: baseDir},
+ })
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello rootfs"))
+
+ err := ctx.SaveFile(fileHeader, "link/rootfs.txt")
+ require.ErrorIs(t, err, ErrUploadPathEscapesRoot)
+}
+
type captureStorage struct {
t *testing.T
data map[string][]byte
@@ -4998,6 +5272,46 @@ func (s *captureStorage) Close() error {
return nil
}
+type errStorage struct {
+ err error
+}
+
+func (s errStorage) GetWithContext(context.Context, string) ([]byte, error) {
+ return nil, s.err
+}
+
+func (s errStorage) Get(string) ([]byte, error) {
+ return nil, s.err
+}
+
+func (s errStorage) SetWithContext(context.Context, string, []byte, time.Duration) error {
+ return s.err
+}
+
+func (s errStorage) Set(string, []byte, time.Duration) error {
+ return s.err
+}
+
+func (s errStorage) DeleteWithContext(context.Context, string) error {
+ return s.err
+}
+
+func (s errStorage) Delete(string) error {
+ return s.err
+}
+
+func (s errStorage) ResetWithContext(context.Context) error {
+ return s.err
+}
+
+func (s errStorage) Reset() error {
+ return s.err
+}
+
+func (s errStorage) Close() error {
+ return s.err
+}
+
func Test_Ctx_SaveFileToStorage_BufferNotReused(t *testing.T) {
t.Parallel()
@@ -5026,6 +5340,43 @@ func Test_Ctx_SaveFileToStorage_BufferNotReused(t *testing.T) {
require.Equal(t, firstPayload, firstStored, "stored data must not rely on pooled buffers")
}
+func Test_Ctx_SaveFileToStorage_StorageError(t *testing.T) {
+ t.Parallel()
+
+ app := New()
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootfs.txt", []byte("hello"))
+ expectedErr := errors.New("store failed")
+
+ err := ctx.SaveFileToStorage(fileHeader, "test", errStorage{err: expectedErr})
+ require.ErrorIs(t, err, expectedErr)
+}
+
+func Test_Ctx_SaveFileToStorage_RootDirPrefix(t *testing.T) {
+ t.Parallel()
+
+ rootDir := filepath.Join(t.TempDir(), "uploads")
+ app := New(Config{RootDir: rootDir})
+ storage := newRecordingStorage()
+ ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
+
+ t.Cleanup(func() {
+ app.ReleaseCtx(ctx)
+ })
+
+ fileHeader := createMultipartFileHeader(t, "rootdir.txt", []byte("hello"))
+
+ err := ctx.SaveFileToStorage(fileHeader, "rootdir.txt", storage)
+ require.NoError(t, err)
+ expectedPath := storageUploadPath(storageRootPrefix(rootDir), "rootdir.txt")
+ require.Equal(t, []string{expectedPath}, storage.setKeys)
+}
+
type mockContextAwareStorage struct {
t *testing.T
key any
diff --git a/docs/api/ctx.md b/docs/api/ctx.md
index a4941254299..f35a544c655 100644
--- a/docs/api/ctx.md
+++ b/docs/api/ctx.md
@@ -1751,6 +1751,8 @@ Method is used to save **any** multipart file to disk.
func (c fiber.Ctx) SaveFile(fh *multipart.FileHeader, path string) error
```
+Paths must be relative and cannot contain `..` segments or absolute prefixes. When `Config.RootDir` or `Config.RootFs` is set, the path is resolved against that root and attempts to escape it are rejected. Storage keys are prefixed by the configured root when one is set.
+
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
// Parse the multipart form:
@@ -1784,6 +1786,8 @@ Method is used to save **any** multipart file to an external storage system.
func (c fiber.Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error
```
+Paths must be relative and cannot contain `..` segments or absolute prefixes. When `Config.RootDir` or `Config.RootFs` is set, the path is resolved against that root and attempts to escape it are rejected.
+
```go title="Example"
storage := memory.New()
diff --git a/docs/api/fiber.md b/docs/api/fiber.md
index 0b61b7115c1..d7deaea83df 100644
--- a/docs/api/fiber.md
+++ b/docs/api/fiber.md
@@ -76,6 +76,9 @@ app := fiber.New(fiber.Config{
| ReadTimeout | `time.Duration` | The amount of time allowed to read the full request, including the body. The default timeout is unlimited. | `0` |
| ReduceMemoryUsage | `bool` | Aggressively reduces memory usage at the cost of higher CPU usage if set to true. | `false` |
| RequestMethods | `[]string` | RequestMethods provides customizability for HTTP methods. You can add/remove methods as you wish. | `DefaultMethods` |
+| RootDir | `string` | Base directory for SaveFile/SaveFileToStorage uploads. Relative paths are resolved against this directory and must not escape it. | `""` |
+| RootPerms | `fs.FileMode` | Permissions used when creating RootDir or RootFs prefixes for uploads. | `0o750` |
+| RootFs | `fs.FS` | Filesystem used for SaveFile/SaveFileToStorage uploads. When set, RootDir is treated as a relative prefix within the filesystem. | `nil` |
| | `string` | Enables the `Server` HTTP header with the given value. | `""` |
| StreamRequestBody | `bool` | StreamRequestBody enables request body streaming, and calls the handler sooner when given body is larger than the current limit. | `false` |
| StrictRouting | `bool` | When enabled, the router treats `/foo` and `/foo/` as different. Otherwise, the router treats `/foo` and `/foo/` as the same. | `false` |
diff --git a/docs/whats_new.md b/docs/whats_new.md
index 275b4172833..2475b88a4df 100644
--- a/docs/whats_new.md
+++ b/docs/whats_new.md
@@ -77,6 +77,7 @@ We have made several changes to the Fiber app, including:
- `ListenerNetwork` (previously `Network`)
- **Trusted Proxy Configuration**: The `EnabledTrustedProxyCheck` has been moved to `app.Config.TrustProxy`, and `TrustedProxies` has been moved to `TrustProxyConfig.Proxies`. Additionally, `ProxyHeader` must be set to read client IPs from proxy headers (e.g., `X-Forwarded-For`).
- **XMLDecoder Config Property**: The `XMLDecoder` property has been added to allow usage of 3rd-party XML libraries in XML binder.
+- **Upload Root Permissions**: The `RootPerms` property controls the permissions used when creating `RootDir` or `RootFs` prefixes for uploads (default `0o750`).
### New Methods
diff --git a/error.go b/error.go
index 93c6f7cbcea..5fabaa50ca8 100644
--- a/error.go
+++ b/error.go
@@ -22,6 +22,10 @@ var (
ErrNoViewEngineConfigured = errors.New("fiber: no view engine configured")
// ErrAutoCertWithCertFile indicates AutoCertManager cannot be used with CertFile/CertKeyFile.
ErrAutoCertWithCertFile = errors.New("tls: AutoCertManager cannot be combined with CertFile/CertKeyFile")
+ // ErrInvalidUploadPath indicates the upload path is invalid.
+ ErrInvalidUploadPath = errors.New("upload: path must be relative and must not contain '..' or absolute prefixes")
+ // ErrUploadPathEscapesRoot indicates the upload path escapes the configured root.
+ ErrUploadPathEscapesRoot = errors.New("upload: path escapes the configured root")
)
// Fiber redirection errors
diff --git a/listen_test.go b/listen_test.go
index e02f8d0d1df..73119aee62f 100644
--- a/listen_test.go
+++ b/listen_test.go
@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
- "log" //nolint:depguard // TODO: Required to capture output, use internal log package instead
"net"
"os"
"path/filepath"
@@ -21,6 +20,8 @@ import (
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
"golang.org/x/crypto/acme/autocert"
+
+ fiberlog "github.com/gofiber/fiber/v3/log"
)
// go test -run Test_Listen
@@ -876,11 +877,11 @@ func captureOutput(f func()) string {
defer func() {
os.Stdout = stdout
os.Stderr = stderr
- log.SetOutput(os.Stderr)
+ fiberlog.SetOutput(os.Stderr)
}()
os.Stdout = writer
os.Stderr = writer
- log.SetOutput(writer)
+ fiberlog.SetOutput(writer)
out := make(chan string)
go func() {
var buf bytes.Buffer
diff --git a/upload.go b/upload.go
new file mode 100644
index 00000000000..11aef32cc50
--- /dev/null
+++ b/upload.go
@@ -0,0 +1,332 @@
+package fiber
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "os"
+ pathpkg "path"
+ "path/filepath"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/gofiber/utils/v2"
+)
+
+type uploadPath struct {
+ osPath string
+ slashPath string
+}
+
+type tempProbeFile interface {
+ WriteString(string) (int, error)
+ Close() error
+ Name() string
+}
+
+func (app *App) configureUploads() {
+ if app.config.RootFs != nil {
+ writer, ok := app.config.RootFs.(interface {
+ fs.FS
+ OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
+ })
+ if !ok {
+ panic("fiber: RootFs must implement OpenFile for uploads")
+ }
+ mkdirer, ok := app.config.RootFs.(interface {
+ MkdirAll(path string, perm fs.FileMode) error
+ })
+ if !ok {
+ panic("fiber: RootFs must implement MkdirAll for uploads")
+ }
+
+ prefix, err := cleanUploadRootPrefix(app.config.RootDir)
+ if err != nil {
+ panic(fmt.Sprintf("fiber: invalid RootDir for RootFs: %v", err))
+ }
+
+ if prefix != "" {
+ if err := mkdirer.MkdirAll(prefix, app.config.RootPerms); err != nil {
+ panic(fmt.Sprintf("fiber: failed to create RootFs prefix %q: %v", prefix, err))
+ }
+ }
+
+ if err := probeUploadFSWritable(writer, prefix); err != nil {
+ panic(fmt.Sprintf("fiber: RootFs not writable: %v", err))
+ }
+
+ app.config.uploadRootFSPrefix = prefix
+ app.config.uploadRootFSWriter = writer
+ app.config.uploadRootPath = prefix
+ return
+ }
+
+ if app.config.RootDir == "" {
+ return
+ }
+
+ rootAbs, err := filepath.Abs(app.config.RootDir)
+ if err != nil {
+ panic(fmt.Sprintf("fiber: failed to resolve RootDir: %v", err))
+ }
+ rootAbs = filepath.Clean(rootAbs)
+ if err = os.MkdirAll(rootAbs, app.config.RootPerms); err != nil {
+ panic(fmt.Sprintf("fiber: failed to create RootDir %q: %v", rootAbs, err))
+ }
+ rootEval, err := filepath.EvalSymlinks(rootAbs)
+ if err != nil {
+ panic(fmt.Sprintf("fiber: failed to resolve RootDir symlinks %q: %v", rootAbs, err))
+ }
+ rootEval = filepath.Clean(rootEval)
+ if err := probeUploadDirWritable(rootAbs); err != nil {
+ panic(fmt.Sprintf("fiber: RootDir not writable %q: %v", rootAbs, err))
+ }
+
+ app.config.uploadRootDir = rootAbs
+ app.config.uploadRootEval = rootEval
+ app.config.uploadRootPath = storageRootPrefix(app.config.RootDir)
+}
+
+func validateUploadPath(path string) (uploadPath, error) {
+ if path == "" {
+ return uploadPath{}, ErrInvalidUploadPath
+ }
+ if isAbsUploadPath(path) {
+ return uploadPath{}, ErrInvalidUploadPath
+ }
+ if containsDotDot(path) {
+ return uploadPath{}, ErrInvalidUploadPath
+ }
+
+ cleanOS := filepath.Clean(path)
+
+ cleanSlash := pathpkg.Clean("/" + filepath.ToSlash(path))
+ cleanSlash = utils.TrimLeft(cleanSlash, '/')
+
+ if cleanOS == "." || cleanOS == "" || cleanSlash == "." || cleanSlash == "" {
+ return uploadPath{}, ErrInvalidUploadPath
+ }
+ if !fs.ValidPath(cleanSlash) {
+ return uploadPath{}, ErrInvalidUploadPath
+ }
+
+ return uploadPath{
+ osPath: cleanOS,
+ slashPath: cleanSlash,
+ }, nil
+}
+
+func isAbsUploadPath(path string) bool {
+ if filepath.IsAbs(path) {
+ return true
+ }
+ if filepath.VolumeName(path) != "" {
+ return true
+ }
+ return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "\\")
+}
+
+func containsDotDot(path string) bool {
+ return slices.Contains(strings.FieldsFunc(path, func(r rune) bool {
+ return r == '/' || r == '\\'
+ }), "..")
+}
+
+func storageRootPrefix(root string) string {
+ if root == "" {
+ return ""
+ }
+ root = filepath.Clean(root)
+ if volume := filepath.VolumeName(root); volume != "" {
+ root = root[len(volume):]
+ }
+ root = filepath.ToSlash(root)
+ root = utils.TrimLeft(root, '/')
+ if root == "." {
+ return ""
+ }
+ return root
+}
+
+func cleanUploadRootPrefix(root string) (string, error) {
+ if root == "" || root == "." {
+ return "", nil
+ }
+ if isAbsUploadPath(root) || containsDotDot(root) {
+ return "", ErrInvalidUploadPath
+ }
+ cleanSlash := pathpkg.Clean("/" + filepath.ToSlash(root))
+ cleanSlash = utils.TrimLeft(cleanSlash, '/')
+ if cleanSlash == "." || cleanSlash == "" {
+ return "", ErrInvalidUploadPath
+ }
+ if !fs.ValidPath(cleanSlash) {
+ return "", ErrInvalidUploadPath
+ }
+ return cleanSlash, nil
+}
+
+func ensureUploadPathWithinRoot(rootEval, fullPath string) error {
+ parent := filepath.Dir(fullPath)
+ parentEval, err := evalExistingPath(parent)
+ if err != nil {
+ return err
+ }
+ if !hasPathPrefix(parentEval, rootEval) {
+ return ErrUploadPathEscapesRoot
+ }
+ return nil
+}
+
+func evalExistingPath(path string) (string, error) {
+ current := path
+ for {
+ _, err := os.Lstat(current)
+ if err == nil {
+ resolved, resolveErr := filepath.EvalSymlinks(current)
+ if resolveErr != nil {
+ return "", fmt.Errorf("failed to resolve symlinks for %q: %w", current, resolveErr)
+ }
+ return filepath.Clean(resolved), nil
+ }
+ if !os.IsNotExist(err) {
+ return "", fmt.Errorf("failed to stat %q: %w", current, err)
+ }
+ parent := filepath.Dir(current)
+ if parent == current {
+ return "", fmt.Errorf("failed to resolve upload path %q: %w", current, err)
+ }
+ current = parent
+ }
+}
+
+func hasPathPrefix(path, prefix string) bool {
+ if path == prefix {
+ return true
+ }
+ if !strings.HasPrefix(path, prefix) {
+ return false
+ }
+ if strings.HasSuffix(prefix, string(os.PathSeparator)) {
+ return true
+ }
+ if len(path) == len(prefix) {
+ return true
+ }
+ return len(path) > len(prefix) && path[len(prefix)] == os.PathSeparator
+}
+
+func ensureNoSymlinkFS(fsys fs.FS, fullPath string) error {
+ parts := strings.Split(fullPath, "/")
+ if len(parts) <= 1 {
+ return nil
+ }
+ for i := 0; i < len(parts)-1; i++ {
+ name := parts[i]
+ if name == "" {
+ continue
+ }
+ parent := strings.Join(parts[:i], "/")
+ if parent == "" {
+ parent = "."
+ }
+ entries, err := fs.ReadDir(fsys, parent)
+ if err != nil {
+ if errors.Is(err, fs.ErrNotExist) {
+ return nil
+ }
+ return fmt.Errorf("failed to read upload directory %q: %w", parent, err)
+ }
+ for _, entry := range entries {
+ if entry.Name() == name && entry.Type()&fs.ModeSymlink != 0 {
+ return ErrUploadPathEscapesRoot
+ }
+ }
+ }
+ return nil
+}
+
+func probeUploadDirWritable(root string) error {
+ return probeUploadDirWritableWith(root, func(dir, pattern string) (tempProbeFile, error) {
+ return os.CreateTemp(dir, pattern)
+ }, os.Remove)
+}
+
+func probeUploadDirWritableWith(
+ root string,
+ createFile func(dir, pattern string) (tempProbeFile, error),
+ removeFile func(name string) error,
+) error {
+ tempFile, err := createFile(root, ".fiber-upload-check-*")
+ if err != nil {
+ return fmt.Errorf("failed to create probe file: %w", err)
+ }
+ if _, err := tempFile.WriteString("fiber"); err != nil {
+ closeErr := tempFile.Close()
+ if closeErr != nil {
+ return fmt.Errorf("failed to close probe file: %w", closeErr)
+ }
+ removeErr := removeFile(tempFile.Name())
+ if removeErr != nil {
+ return fmt.Errorf("failed to remove probe file: %w", removeErr)
+ }
+ return fmt.Errorf("failed to write probe file: %w", err)
+ }
+ if err := tempFile.Close(); err != nil {
+ removeErr := removeFile(tempFile.Name())
+ if removeErr != nil {
+ return fmt.Errorf("failed to remove probe file: %w", removeErr)
+ }
+ return fmt.Errorf("failed to close probe file: %w", err)
+ }
+ if err := removeFile(tempFile.Name()); err != nil {
+ return fmt.Errorf("failed to remove probe file: %w", err)
+ }
+ return nil
+}
+
+func probeUploadFSWritable(fsys interface {
+ fs.FS
+ OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
+}, prefix string,
+) error {
+ name := pathpkg.Join(prefix, fmt.Sprintf(".fiber-upload-check-%d", time.Now().UnixNano()))
+ file, err := fsys.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
+ if err != nil {
+ return fmt.Errorf("failed to create probe file: %w", err)
+ }
+ writer, ok := file.(io.Writer)
+ if !ok {
+ if closeErr := file.Close(); closeErr != nil {
+ return fmt.Errorf("failed to close probe file: %w", closeErr)
+ }
+ return errors.New("upload file is not writable")
+ }
+ if _, err := writer.Write([]byte("fiber")); err != nil {
+ closeErr := file.Close()
+ if closeErr != nil {
+ return fmt.Errorf("failed to close probe file: %w", closeErr)
+ }
+ return fmt.Errorf("failed to write probe file: %w", err)
+ }
+ if err := file.Close(); err != nil {
+ return fmt.Errorf("failed to close probe file: %w", err)
+ }
+ if remover, ok := fsys.(interface {
+ Remove(name string) error
+ }); ok {
+ if err := remover.Remove(name); err != nil {
+ return fmt.Errorf("failed to remove probe file: %w", err)
+ }
+ }
+ return nil
+}
+
+func storageUploadPath(prefix, cleanSlash string) string {
+ if prefix == "" {
+ return cleanSlash
+ }
+ return pathpkg.Join(prefix, cleanSlash)
+}
diff --git a/upload_test.go b/upload_test.go
new file mode 100644
index 00000000000..d84d64b7282
--- /dev/null
+++ b/upload_test.go
@@ -0,0 +1,763 @@
+package fiber
+
+import (
+ "errors"
+ "io"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+ "testing/fstest"
+ "time"
+)
+
+type stubFileInfo struct {
+ name string
+}
+
+func (fi stubFileInfo) Name() string { return fi.name }
+func (fi stubFileInfo) Size() int64 {
+ _ = fi
+ return 0
+}
+
+func (fi stubFileInfo) Mode() fs.FileMode {
+ _ = fi
+ return 0
+}
+
+func (fi stubFileInfo) ModTime() time.Time {
+ _ = fi
+ return time.Time{}
+}
+
+func (fi stubFileInfo) IsDir() bool {
+ _ = fi
+ return false
+}
+
+func (fi stubFileInfo) Sys() any {
+ _ = fi
+ return nil
+}
+
+type noWriteFile struct {
+ closeErr error
+}
+
+func (f *noWriteFile) Read(_ []byte) (int, error) {
+ _ = f
+ return 0, io.EOF
+}
+
+func (f *noWriteFile) Stat() (fs.FileInfo, error) {
+ _ = f
+ return stubFileInfo{name: "probe"}, nil
+}
+func (f *noWriteFile) Close() error { return f.closeErr }
+
+type writeFile struct {
+ writeErr error
+ closeErr error
+}
+
+func (f *writeFile) Read(_ []byte) (int, error) {
+ _ = f
+ return 0, io.EOF
+}
+
+func (f *writeFile) Stat() (fs.FileInfo, error) {
+ _ = f
+ return stubFileInfo{name: "probe"}, nil
+}
+func (f *writeFile) Close() error { return f.closeErr }
+func (f *writeFile) Write(_ []byte) (int, error) {
+ if f.writeErr != nil {
+ return 0, f.writeErr
+ }
+ return len("fiber"), nil
+}
+
+type probeFS struct {
+ file fs.File
+ openErr error
+}
+
+func (fsys probeFS) Open(_ string) (fs.File, error) {
+ if fsys.openErr != nil {
+ return nil, fsys.openErr
+ }
+ if fsys.file == nil {
+ return nil, fs.ErrNotExist
+ }
+ return fsys.file, nil
+}
+
+func (fsys probeFS) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) {
+ if fsys.openErr != nil {
+ return nil, fsys.openErr
+ }
+ return fsys.file, nil
+}
+
+type removeErrFS struct {
+ probeFS
+ removeErr error
+}
+
+func (fsys removeErrFS) Remove(_ string) error {
+ return fsys.removeErr
+}
+
+type readDirErrFS struct {
+ err error
+}
+
+type rootFSMissingOpenFile struct{}
+
+func (rootFSMissingOpenFile) Open(_ string) (fs.File, error) { return nil, fs.ErrNotExist }
+func (rootFSMissingOpenFile) MkdirAll(_ string, _ fs.FileMode) error {
+ return nil
+}
+
+type rootFSMissingMkdirAll struct{}
+
+func (rootFSMissingMkdirAll) Open(_ string) (fs.File, error) { return nil, fs.ErrNotExist }
+func (rootFSMissingMkdirAll) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) {
+ return &writeFile{}, nil
+}
+
+type rootFSMkdirErr struct {
+ err error
+}
+
+func (fsys rootFSMkdirErr) Open(_ string) (fs.File, error) {
+ _ = fsys
+ return nil, fs.ErrNotExist
+}
+
+func (fsys rootFSMkdirErr) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) {
+ _ = fsys
+ return &writeFile{}, nil
+}
+func (fsys rootFSMkdirErr) MkdirAll(_ string, _ fs.FileMode) error { return fsys.err }
+
+type rootFSNoWriter struct{}
+
+func (rootFSNoWriter) Open(_ string) (fs.File, error) { return nil, fs.ErrNotExist }
+func (rootFSNoWriter) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) {
+ return &noWriteFile{}, nil
+}
+func (rootFSNoWriter) MkdirAll(_ string, _ fs.FileMode) error { return nil }
+
+type probeTempFile struct {
+ name *string
+ writeErr error
+ closeErr error
+}
+
+func (f *probeTempFile) WriteString(_ string) (int, error) {
+ if f.writeErr != nil {
+ return 0, f.writeErr
+ }
+ return len("fiber"), nil
+}
+
+func (f *probeTempFile) Close() error {
+ return f.closeErr
+}
+
+func (f *probeTempFile) Name() string {
+ if f.name == nil {
+ return ""
+ }
+ return *f.name
+}
+
+func stringPtr(value string) *string {
+ return &value
+}
+
+func (fsys readDirErrFS) Open(_ string) (fs.File, error) {
+ _ = fsys
+ return nil, fs.ErrNotExist
+}
+
+func (fsys readDirErrFS) ReadDir(_ string) ([]fs.DirEntry, error) {
+ return nil, fsys.err
+}
+
+func TestValidateUploadPath(t *testing.T) {
+ t.Parallel()
+
+ absPath := "/var/uploads/file.txt"
+ if runtime.GOOS == "windows" {
+ absPath = `C:\uploads\file.txt`
+ }
+
+ tests := []struct {
+ name string
+ path string
+ wantErr bool
+ }{
+ {name: "valid", path: "uploads/file.txt"},
+ {name: "valid_dot_prefix", path: ".hidden/file.txt"},
+ {name: "dot", path: ".", wantErr: true},
+ {name: "empty", path: "", wantErr: true},
+ {name: "absolute", path: absPath, wantErr: true},
+ {name: "dotdot", path: "../file.txt", wantErr: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateUploadPath(tt.path)
+ if tt.wantErr {
+ if err == nil {
+ t.Fatalf("expected error for %q", tt.path)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error for %q: %v", tt.path, err)
+ }
+ if normalized.osPath == "" || normalized.slashPath == "" {
+ t.Fatalf("expected normalized paths for %q", tt.path)
+ }
+ if strings.Contains(normalized.slashPath, `\`) {
+ t.Fatalf("expected slash path to use forward slashes, got %q", normalized.slashPath)
+ }
+ })
+ }
+}
+
+func TestIsAbsUploadPath(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ path string
+ want bool
+ }{
+ {name: "relative", path: "uploads/file.txt", want: false},
+ {name: "slash_abs", path: "/uploads/file.txt", want: true},
+ {name: "backslash_abs", path: `\uploads\file.txt`, want: true},
+ }
+
+ if runtime.GOOS == "windows" {
+ tests = append(tests, struct {
+ name string
+ path string
+ want bool
+ }{name: "volume_abs", path: `C:\uploads\file.txt`, want: true})
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ if got := isAbsUploadPath(tt.path); got != tt.want {
+ t.Fatalf("expected %v for %q, got %v", tt.want, tt.path, got)
+ }
+ })
+ }
+}
+
+func TestContainsDotDot(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ path string
+ want bool
+ }{
+ {name: "dotdot_segment", path: "uploads/../file.txt", want: true},
+ {name: "dotdot_prefix", path: "../file.txt", want: true},
+ {name: "no_dotdot", path: "uploads/.../file.txt", want: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ if got := containsDotDot(tt.path); got != tt.want {
+ t.Fatalf("expected %v for %q, got %v", tt.want, tt.path, got)
+ }
+ })
+ }
+}
+
+func TestStorageRootPrefix(t *testing.T) {
+ t.Parallel()
+
+ root := "/var/uploads"
+ if runtime.GOOS == "windows" {
+ root = `C:\uploads`
+ }
+
+ tests := []struct {
+ name string
+ root string
+ want string
+ }{
+ {name: "empty", root: "", want: ""},
+ {name: "dot", root: ".", want: ""},
+ {name: "rooted", root: root, want: "var/uploads"},
+ {name: "relative", root: "uploads", want: "uploads"},
+ }
+
+ if runtime.GOOS == "windows" {
+ tests[2].want = "uploads"
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ if got := storageRootPrefix(tt.root); got != tt.want {
+ t.Fatalf("expected %q for %q, got %q", tt.want, tt.root, got)
+ }
+ })
+ }
+}
+
+func TestCleanUploadRootPrefix(t *testing.T) {
+ t.Parallel()
+
+ abs := "/uploads"
+ if runtime.GOOS == "windows" {
+ abs = `C:\uploads`
+ }
+
+ tests := []struct {
+ name string
+ root string
+ want string
+ wantErr bool
+ }{
+ {name: "empty", root: "", want: ""},
+ {name: "dot", root: ".", want: ""},
+ {name: "valid", root: "uploads", want: "uploads"},
+ {name: "abs", root: abs, wantErr: true},
+ {name: "dotdot", root: "../uploads", wantErr: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := cleanUploadRootPrefix(tt.root)
+ if tt.wantErr {
+ if err == nil {
+ t.Fatalf("expected error for %q", tt.root)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error for %q: %v", tt.root, err)
+ }
+ if got != tt.want {
+ t.Fatalf("expected %q for %q, got %q", tt.want, tt.root, got)
+ }
+ })
+ }
+}
+
+func TestEnsureUploadPathWithinRoot(t *testing.T) {
+ t.Parallel()
+
+ root := t.TempDir()
+ rootEval, err := filepath.EvalSymlinks(root)
+ if err != nil {
+ t.Fatalf("failed to eval root: %v", err)
+ }
+
+ inside := filepath.Join(root, "uploads", "file.txt")
+ if err := ensureUploadPathWithinRoot(rootEval, inside); err != nil {
+ t.Fatalf("expected inside path to be allowed, got %v", err)
+ }
+
+ outside := filepath.Join(t.TempDir(), "file.txt")
+ if err := ensureUploadPathWithinRoot(rootEval, outside); !errors.Is(err, ErrUploadPathEscapesRoot) {
+ t.Fatalf("expected ErrUploadPathEscapesRoot, got %v", err)
+ }
+}
+
+func TestEvalExistingPath(t *testing.T) {
+ t.Parallel()
+
+ if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
+ t.Skip("path resolution differs on this platform")
+ }
+
+ root := t.TempDir()
+ target := filepath.Join(root, "missing", "file.txt")
+ got, err := evalExistingPath(target)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != filepath.Clean(root) {
+ t.Fatalf("expected %q, got %q", filepath.Clean(root), got)
+ }
+}
+
+func TestHasPathPrefix(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ path string
+ prefix string
+ want bool
+ }{
+ {name: "same", path: filepath.Join("a", "b"), prefix: filepath.Join("a", "b"), want: true},
+ {name: "child", path: filepath.Join("a", "b", "c"), prefix: filepath.Join("a", "b"), want: true},
+ {name: "sibling", path: filepath.Join("a", "bc"), prefix: filepath.Join("a", "b"), want: false},
+ {name: "prefix_with_separator", path: filepath.Join("a", "b", "c"), prefix: filepath.Join("a", "b") + string(os.PathSeparator), want: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ if got := hasPathPrefix(tt.path, tt.prefix); got != tt.want {
+ t.Fatalf("expected %v for %q/%q, got %v", tt.want, tt.path, tt.prefix, got)
+ }
+ })
+ }
+}
+
+func TestEnsureNoSymlinkFS(t *testing.T) {
+ t.Parallel()
+
+ t.Run("missing_parent", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := fstest.MapFS{}
+ if err := ensureNoSymlinkFS(fsys, "missing/file.txt"); err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ })
+
+ t.Run("symlink_detected", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := fstest.MapFS{
+ "dir": &fstest.MapFile{Mode: fs.ModeDir},
+ "dir/link": &fstest.MapFile{Mode: fs.ModeSymlink},
+ }
+ if err := ensureNoSymlinkFS(fsys, "dir/link/file.txt"); !errors.Is(err, ErrUploadPathEscapesRoot) {
+ t.Fatalf("expected ErrUploadPathEscapesRoot, got %v", err)
+ }
+ })
+
+ t.Run("read_dir_error", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := readDirErrFS{err: errors.New("read failure")}
+ if err := ensureNoSymlinkFS(fsys, "dir/file.txt"); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+
+ t.Run("no_symlink", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := fstest.MapFS{
+ "dir": &fstest.MapFile{Mode: fs.ModeDir},
+ "dir/child": &fstest.MapFile{Mode: fs.ModeDir},
+ }
+ if err := ensureNoSymlinkFS(fsys, "dir/child/file.txt"); err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ })
+
+ t.Run("single_part_path", func(t *testing.T) {
+ t.Parallel()
+
+ if err := ensureNoSymlinkFS(fstest.MapFS{}, "file.txt"); err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ })
+
+ t.Run("leading_slash", func(t *testing.T) {
+ t.Parallel()
+
+ if err := ensureNoSymlinkFS(fstest.MapFS{}, "/file.txt"); err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ })
+}
+
+func TestProbeUploadDirWritable(t *testing.T) {
+ t.Parallel()
+
+ if runtime.GOOS == "windows" {
+ t.Skip("probe behavior differs on Windows")
+ }
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ if err := probeUploadDirWritable(t.TempDir()); err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ })
+
+ t.Run("not_directory", func(t *testing.T) {
+ t.Parallel()
+
+ file, err := os.CreateTemp(t.TempDir(), "probe")
+ if err != nil {
+ t.Fatalf("failed to create temp file: %v", err)
+ }
+ t.Cleanup(func() {
+ if removeErr := os.Remove(file.Name()); removeErr != nil {
+ t.Fatalf("failed to remove temp file: %v", removeErr)
+ }
+ })
+ if err := probeUploadDirWritable(file.Name()); err == nil {
+ t.Fatal("expected error for file path")
+ }
+ })
+
+ t.Run("create_error", func(t *testing.T) {
+ t.Parallel()
+
+ createErr := errors.New("create failure")
+ err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) {
+ return nil, createErr
+ }, func(string) error { return nil })
+ if !errors.Is(err, createErr) {
+ t.Fatalf("expected create error, got %v", err)
+ }
+ })
+
+ t.Run("write_error", func(t *testing.T) {
+ t.Parallel()
+
+ writeErr := errors.New("write failure")
+ err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) {
+ return &probeTempFile{name: stringPtr("probe"), writeErr: writeErr}, nil
+ }, func(string) error { return nil })
+ if !errors.Is(err, writeErr) {
+ t.Fatalf("expected write error, got %v", err)
+ }
+ })
+
+ t.Run("close_error_after_write", func(t *testing.T) {
+ t.Parallel()
+
+ closeErr := errors.New("close failure")
+ err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) {
+ return &probeTempFile{name: stringPtr("probe"), closeErr: closeErr}, nil
+ }, func(string) error { return nil })
+ if !errors.Is(err, closeErr) {
+ t.Fatalf("expected close error, got %v", err)
+ }
+ })
+
+ t.Run("remove_error_after_write_failure", func(t *testing.T) {
+ t.Parallel()
+
+ removeErr := errors.New("remove failure")
+ err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) {
+ return &probeTempFile{name: stringPtr("probe"), writeErr: errors.New("write failure")}, nil
+ }, func(string) error { return removeErr })
+ if !errors.Is(err, removeErr) {
+ t.Fatalf("expected remove error, got %v", err)
+ }
+ })
+
+ t.Run("remove_error_after_close_failure", func(t *testing.T) {
+ t.Parallel()
+
+ removeErr := errors.New("remove failure")
+ err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) {
+ return &probeTempFile{name: stringPtr("probe"), closeErr: errors.New("close failure")}, nil
+ }, func(string) error { return removeErr })
+ if !errors.Is(err, removeErr) {
+ t.Fatalf("expected remove error, got %v", err)
+ }
+ })
+
+ t.Run("remove_error_after_success", func(t *testing.T) {
+ t.Parallel()
+
+ removeErr := errors.New("remove failure")
+ err := probeUploadDirWritableWith(t.TempDir(), func(_, _ string) (tempProbeFile, error) {
+ return &probeTempFile{name: stringPtr("probe")}, nil
+ }, func(string) error { return removeErr })
+ if !errors.Is(err, removeErr) {
+ t.Fatalf("expected remove error, got %v", err)
+ }
+ })
+}
+
+func TestConfigureUploads(t *testing.T) {
+ t.Parallel()
+
+ t.Run("root_dir", func(t *testing.T) {
+ t.Parallel()
+
+ root := t.TempDir()
+ app := New(Config{RootDir: root})
+ if app.config.uploadRootDir == "" {
+ t.Fatal("expected uploadRootDir to be set")
+ }
+ if app.config.uploadRootPath == "" {
+ t.Fatal("expected uploadRootPath to be set")
+ }
+ })
+
+ t.Run("missing_openfile", func(t *testing.T) {
+ t.Parallel()
+
+ assertPanics(t, func() {
+ New(Config{RootDir: "uploads", RootFs: rootFSMissingOpenFile{}})
+ })
+ })
+
+ t.Run("missing_mkdirall", func(t *testing.T) {
+ t.Parallel()
+
+ assertPanics(t, func() {
+ New(Config{RootDir: "uploads", RootFs: rootFSMissingMkdirAll{}})
+ })
+ })
+
+ t.Run("invalid_rootdir_for_rootfs", func(t *testing.T) {
+ t.Parallel()
+
+ assertPanics(t, func() {
+ New(Config{RootDir: filepath.Join(t.TempDir(), "uploads"), RootFs: rootFSNoWriter{}})
+ })
+ })
+
+ t.Run("mkdir_error", func(t *testing.T) {
+ t.Parallel()
+
+ assertPanics(t, func() {
+ New(Config{
+ RootDir: "uploads",
+ RootFs: rootFSMkdirErr{err: errors.New("mkdir failure")},
+ })
+ })
+ })
+
+ t.Run("not_writable", func(t *testing.T) {
+ t.Parallel()
+
+ assertPanics(t, func() {
+ New(Config{RootDir: "uploads", RootFs: rootFSNoWriter{}})
+ })
+ })
+}
+
+func TestEvalExistingPathError(t *testing.T) {
+ t.Parallel()
+
+ _, err := evalExistingPath("invalid\x00path")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+}
+
+func assertPanics(t *testing.T, fn func()) {
+ t.Helper()
+
+ defer func() {
+ if recover() == nil {
+ t.Fatal("expected panic")
+ }
+ }()
+ fn()
+}
+
+func TestProbeUploadFSWritable(t *testing.T) {
+ t.Parallel()
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := probeFS{file: &writeFile{}}
+ if err := probeUploadFSWritable(fsys, ""); err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ })
+
+ t.Run("open_error", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := probeFS{openErr: errors.New("open failure")}
+ if err := probeUploadFSWritable(fsys, ""); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+
+ t.Run("not_writable", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := probeFS{file: &noWriteFile{}}
+ if err := probeUploadFSWritable(fsys, ""); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+
+ t.Run("write_error", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := probeFS{file: &writeFile{writeErr: errors.New("write failure")}}
+ if err := probeUploadFSWritable(fsys, ""); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+
+ t.Run("close_error", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := probeFS{file: &writeFile{closeErr: errors.New("close failure")}}
+ if err := probeUploadFSWritable(fsys, ""); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+
+ t.Run("remove_error", func(t *testing.T) {
+ t.Parallel()
+
+ fsys := removeErrFS{
+ probeFS: probeFS{file: &writeFile{}},
+ removeErr: errors.New("remove failure"),
+ }
+ if err := probeUploadFSWritable(fsys, ""); err == nil {
+ t.Fatal("expected error")
+ }
+ })
+}
+
+func TestStorageUploadPath(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ prefix string
+ path string
+ want string
+ }{
+ {name: "empty_prefix", prefix: "", path: "uploads/file.txt", want: "uploads/file.txt"},
+ {name: "with_prefix", prefix: "root", path: "uploads/file.txt", want: "root/uploads/file.txt"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ if got := storageUploadPath(tt.prefix, tt.path); got != tt.want {
+ t.Fatalf("expected %q, got %q", tt.want, got)
+ }
+ })
+ }
+}