Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ require (
)

require (
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect
github.com/ProtonMail/gluon v0.17.1-0.20230724134000-308be39be96e // indirect
github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f // indirect
Expand Down
274 changes: 26 additions & 248 deletions go.sum

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions internal/conf/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,6 @@ const (
PathKey
SharingIDKey
SkipHookKey
VirtualHostKey
VhostPrefixKey
)
2 changes: 1 addition & 1 deletion internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var db *gorm.DB

func Init(d *gorm.DB) {
db = d
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.SharingDB))
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.SharingDB), new(model.VirtualHost))
if err != nil {
log.Fatalf("failed migrate database: %s", err.Error())
}
Expand Down
45 changes: 45 additions & 0 deletions internal/db/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package db

import (
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/pkg/errors"
)

func GetVirtualHostByDomain(domain string) (*model.VirtualHost, error) {
var v model.VirtualHost
if err := db.Where("domain = ?", domain).First(&v).Error; err != nil {
return nil, errors.Wrapf(err, "failed to select virtual host")
}
return &v, nil
}

func GetVirtualHostById(id uint) (*model.VirtualHost, error) {
var v model.VirtualHost
if err := db.First(&v, id).Error; err != nil {
return nil, errors.Wrapf(err, "failed get virtual host")
}
return &v, nil
}

func CreateVirtualHost(v *model.VirtualHost) error {
return errors.WithStack(db.Create(v).Error)
}

func UpdateVirtualHost(v *model.VirtualHost) error {
return errors.WithStack(db.Save(v).Error)
}

func GetVirtualHosts(pageIndex, pageSize int) (vhosts []model.VirtualHost, count int64, err error) {
vhostDB := db.Model(&model.VirtualHost{})
if err = vhostDB.Count(&count).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get virtual hosts count")
}
if err = vhostDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&vhosts).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed find virtual hosts")
}
return vhosts, count, nil
}

func DeleteVirtualHostById(id uint) error {
return errors.WithStack(db.Delete(&model.VirtualHost{}, id).Error)
}
9 changes: 9 additions & 0 deletions internal/model/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package model

type VirtualHost struct {
ID uint `json:"id" gorm:"primaryKey"`
Enabled bool `json:"enabled"`
Domain string `json:"domain" gorm:"unique" binding:"required"`
Path string `json:"path" binding:"required"`
WebHosting bool `json:"web_hosting"`
}
75 changes: 75 additions & 0 deletions internal/op/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package op

import (
"time"

"github.com/OpenListTeam/OpenList/v4/internal/db"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/OpenListTeam/go-cache"
"github.com/pkg/errors"
"gorm.io/gorm"
)

var vhostCache = cache.NewMemCache(cache.WithShards[*model.VirtualHost](2))

// GetVirtualHostByDomain 根据域名获取虚拟主机配置(带缓存)
func GetVirtualHostByDomain(domain string) (*model.VirtualHost, error) {
if v, ok := vhostCache.Get(domain); ok {
if v == nil {
utils.Log.Debugf("[VirtualHost] cache hit (nil) for domain=%q", domain)
return nil, errors.New("virtual host not found")
}
utils.Log.Debugf("[VirtualHost] cache hit for domain=%q id=%d", domain, v.ID)
return v, nil
}
utils.Log.Debugf("[VirtualHost] cache miss for domain=%q, querying db...", domain)
v, err := db.GetVirtualHostByDomain(domain)
if err != nil {
Comment thread
PIKACHUIM marked this conversation as resolved.
Outdated
if errors.Is(errors.Cause(err), gorm.ErrRecordNotFound) {
utils.Log.Debugf("[VirtualHost] domain=%q not found in db, caching nil", domain)
vhostCache.Set(domain, nil, cache.WithEx[*model.VirtualHost](time.Minute*5))
return nil, errors.New("virtual host not found")
}
utils.Log.Errorf("[VirtualHost] db error for domain=%q: %v", domain, err)
return nil, err
}
utils.Log.Debugf("[VirtualHost] db found domain=%q id=%d enabled=%v web_hosting=%v", domain, v.ID, v.Enabled, v.WebHosting)
vhostCache.Set(domain, v, cache.WithEx[*model.VirtualHost](time.Hour))
return v, nil
}

func GetVirtualHostById(id uint) (*model.VirtualHost, error) {
return db.GetVirtualHostById(id)
}

func CreateVirtualHost(v *model.VirtualHost) error {
v.Path = utils.FixAndCleanPath(v.Path)
vhostCache.Del(v.Domain)
return db.CreateVirtualHost(v)
}

func UpdateVirtualHost(v *model.VirtualHost) error {
v.Path = utils.FixAndCleanPath(v.Path)
old, err := db.GetVirtualHostById(v.ID)
if err != nil {
return err
}
// 如果域名变更,清除旧域名缓存
vhostCache.Del(old.Domain)
vhostCache.Del(v.Domain)
return db.UpdateVirtualHost(v)
Comment thread
PIKACHUIM marked this conversation as resolved.
Outdated
}

func DeleteVirtualHostById(id uint) error {
old, err := db.GetVirtualHostById(id)
if err != nil {
return err
}
vhostCache.Del(old.Domain)
return db.DeleteVirtualHostById(id)
}

func GetVirtualHosts(pageIndex, pageSize int) ([]model.VirtualHost, int64, error) {
return db.GetVirtualHosts(pageIndex, pageSize)
}
75 changes: 74 additions & 1 deletion server/handles/fsread.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handles

import (
"fmt"
"net"
stdpath "path"
"strings"
"time"
Expand Down Expand Up @@ -68,6 +69,8 @@ func FsListSplit(c *gin.Context) {
SharingList(c, &req)
return
}
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
req.Path = applyVhostPathMapping(c, req.Path)
user := c.Request.Context().Value(conf.UserKey).(*model.User)
if user.IsGuest() && user.Disabled {
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
Expand Down Expand Up @@ -273,6 +276,11 @@ func FsGetSplit(c *gin.Context) {
SharingGet(c, &req)
return
}
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
// 同时将 vhost.Path 前缀存入 context,供 FsGet 生成 /p/ 链接时去掉前缀
var vhostPrefix string
req.Path, vhostPrefix = applyVhostPathMappingWithPrefix(c, req.Path)
common.GinWithValue(c, conf.VhostPrefixKey, vhostPrefix)
user := c.Request.Context().Value(conf.UserKey).(*model.User)
if user.IsGuest() && user.Disabled {
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
Expand Down Expand Up @@ -322,12 +330,14 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) {
rawURL = common.GenerateDownProxyURL(storage.GetStorage(), reqPath)
if rawURL == "" {
query := ""
// 生成 /p/ 链接时,去掉 vhost 路径前缀,保持前端看到的路径一致
downPath := stripVhostPrefix(c, reqPath)
if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) {
query = "?sign=" + sign.Sign(reqPath)
}
rawURL = fmt.Sprintf("%s/p%s%s",
common.GetApiUrl(c),
utils.EncodePath(reqPath, true),
utils.EncodePath(downPath, true),
query)
}
} else {
Expand Down Expand Up @@ -432,3 +442,66 @@ func FsOther(c *gin.Context) {
}
common.SuccessResp(c, res)
}

// applyVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,将请求路径映射到实际路径。
func applyVhostPathMapping(c *gin.Context, reqPath string) string {
mapped, _ := applyVhostPathMappingWithPrefix(c, reqPath)
return mapped
}

// applyVhostPathMappingWithPrefix 根据请求的 Host 头匹配虚拟主机规则,
// 将请求路径映射到虚拟主机配置的实际路径,同时返回 vhost.Path 前缀(用于生成下载链接时去掉前缀)。
// 例如:vhost.Path="/123pan/Downloads",reqPath="/",则返回 ("/123pan/Downloads", "/123pan/Downloads")
// 例如:vhost.Path="/123pan/Downloads",reqPath="/subdir",则返回 ("/123pan/Downloads/subdir", "/123pan/Downloads")
// 如果没有匹配的虚拟主机规则,则返回 (原始路径, "")
func applyVhostPathMappingWithPrefix(c *gin.Context, reqPath string) (string, string) {
rawHost := c.Request.Host
domain := stripHostPortForVhost(rawHost)
if domain == "" {
return reqPath, ""
}
vhost, err := op.GetVirtualHostByDomain(domain)
if err != nil || vhost == nil {
return reqPath, ""
}
if !vhost.Enabled || vhost.WebHosting {
// 未启用,或者是 Web 托管模式(Web 托管不做路径重映射)
return reqPath, ""
}
// Map request path into the vhost root and verify it does not escape via traversal.
// stdpath.Join calls Clean internally, which collapses ".." segments, so we only need
// to confirm the result still lives under vhost.Path.
mapped := stdpath.Join(vhost.Path, reqPath)
if !strings.HasPrefix(mapped, strings.TrimRight(vhost.Path, "/")+"/") && mapped != vhost.Path {
utils.Log.Warnf("[VirtualHost] path traversal rejected for API remapping: domain=%q reqPath=%q", domain, reqPath)
return reqPath, ""
}
utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
return mapped, vhost.Path
}

Comment thread
PIKACHUIM marked this conversation as resolved.
// stripVhostPrefix 从 gin context 中取出 vhost 路径前缀,并从 path 中去掉该前缀。
// 用于生成 /p/ 下载链接时,将真实路径还原为前端看到的路径。
func stripVhostPrefix(c *gin.Context, path string) string {
prefix, ok := c.Request.Context().Value(conf.VhostPrefixKey).(string)
if !ok || prefix == "" {
return path
}
if strings.HasPrefix(path, prefix+"/") {
return path[len(prefix):]
}
if path == prefix {
return "/"
}
return path
}

// stripHostPortForVhost removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6).
func stripHostPortForVhost(host string) string {
h, _, err := net.SplitHostPort(host)
if err != nil {
// No port present; return host as-is
return host
}
return h
Comment thread
PIKACHUIM marked this conversation as resolved.
Outdated
}
83 changes: 83 additions & 0 deletions server/handles/virtual_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package handles

import (
"strconv"

"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/internal/op"
"github.com/OpenListTeam/OpenList/v4/server/common"
"github.com/gin-gonic/gin"
)

func ListVirtualHosts(c *gin.Context) {
var req model.PageReq
if err := c.ShouldBind(&req); err != nil {
common.ErrorResp(c, err, 400)
return
}
req.Validate()
vhosts, total, err := op.GetVirtualHosts(req.Page, req.PerPage)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c, common.PageResp{
Content: vhosts,
Total: total,
})
}

func GetVirtualHost(c *gin.Context) {
idStr := c.Query("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ErrorResp(c, err, 400)
return
}
vhost, err := op.GetVirtualHostById(uint(id))
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c, vhost)
}

func CreateVirtualHost(c *gin.Context) {
var req model.VirtualHost
if err := c.ShouldBind(&req); err != nil {
common.ErrorResp(c, err, 400)
return
}
if err := op.CreateVirtualHost(&req); err != nil {
common.ErrorResp(c, err, 500, true)
} else {
common.SuccessResp(c)
}
}

func UpdateVirtualHost(c *gin.Context) {
var req model.VirtualHost
if err := c.ShouldBind(&req); err != nil {
common.ErrorResp(c, err, 400)
return
}
if err := op.UpdateVirtualHost(&req); err != nil {
common.ErrorResp(c, err, 500, true)
} else {
common.SuccessResp(c)
}
}

func DeleteVirtualHost(c *gin.Context) {
idStr := c.Query("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ErrorResp(c, err, 400)
return
}
if err := op.DeleteVirtualHostById(uint(id)); err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c)
}
Loading