Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/conf/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,5 @@ const (
PathKey
SharingIDKey
SkipHookKey
VhostPrefixKey
)
24 changes: 24 additions & 0 deletions internal/db/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ func GetSharingById(id string) (*model.SharingDB, error) {
return &s, nil
}

// GetSharingByDomain 根据绑定的域名查询 sharing 记录(用于虚拟主机能力)。
// 仅当 sharing.Domain 字段精确匹配时返回;调用方需自行判断 Disabled / Expires / Files 等有效性。
func GetSharingByDomain(domain string) (*model.SharingDB, error) {
var s model.SharingDB
if err := db.Where("domain = ?", domain).First(&s).Error; err != nil {
return nil, errors.Wrapf(err, "failed get sharing by domain")
}
return &s, nil
}

func GetSharings(pageIndex, pageSize int) (sharings []model.SharingDB, count int64, err error) {
sharingDB := db.Model(&model.SharingDB{})
if err := sharingDB.Count(&count).Error; err != nil {
Expand All @@ -38,6 +48,13 @@ func GetSharingsByCreatorId(creator uint, pageIndex, pageSize int) (sharings []m
}

func CreateSharing(s *model.SharingDB) (string, error) {
// domain 非空时做唯一性提前校验
if s.Domain != "" {
var exist model.SharingDB
if err := db.Where("domain = ?", s.Domain).First(&exist).Error; err == nil {
return "", errors.New("domain already used")
}
}
if s.ID == "" {
id := random.String(8)
for len(id) < 12 {
Expand All @@ -61,6 +78,13 @@ func CreateSharing(s *model.SharingDB) (string, error) {
}

func UpdateSharing(s *model.SharingDB) error {
// domain 非空时校验唯一性(排除自身)
if s.Domain != "" {
var exist model.SharingDB
if err := db.Where("domain = ? AND id <> ?", s.Domain, s.ID).First(&exist).Error; err == nil {
return errors.New("domain already used")
}
}
return errors.WithStack(db.Save(s).Error)
}

Expand Down
21 changes: 21 additions & 0 deletions internal/model/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ type SharingDB struct {
Remark string `json:"remark"`
Readme string `json:"readme" gorm:"type:text"`
Header string `json:"header" gorm:"type:text"`
// Domain 绑定的域名,可为空;非空时该条记录额外作为虚拟主机参与 Host 匹配(与旧 VirtualHost.Domain 等价)。
// 唯一性由应用层在 Create/Update 时校验,避免空字符串在 MySQL 下触发 uniqueIndex 冲突。
Domain string `json:"domain" gorm:"index"`
// WebHosting 仅在 Domain 非空时有效;为 true 时启用 Web 托管模式(直接响应文件内容),为 false 时仅做路径重映射。
WebHosting bool `json:"web_hosting"`
Sort
}

Expand Down Expand Up @@ -42,6 +47,22 @@ func (s *Sharing) Valid() bool {
return true
}

// ValidForVhost 虚拟主机场景的有效性检查。
// 与 Valid() 的区别:不检查 Creator.CanShare(),因为 Web Hosting / 路径重映射
// 是服务端功能,不依赖创建者的分享权限位。
func (s *Sharing) ValidForVhost() bool {
if s.Disabled {
return false
}
if len(s.Files) == 0 {
return false
}
if s.Expires != nil && !s.Expires.IsZero() && s.Expires.Before(time.Now()) {
return false
}
return true
}

func (s *Sharing) Verify(pwd string) bool {
return s.Pwd == "" || s.Pwd == pwd
}
99 changes: 94 additions & 5 deletions internal/op/sharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
stdpath "path"
"strings"
"time"

"github.com/OpenListTeam/OpenList/v4/internal/db"
"github.com/OpenListTeam/OpenList/v4/internal/model"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/OpenListTeam/go-cache"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)

func makeJoined(sdb []model.SharingDB) []model.Sharing {
Expand Down Expand Up @@ -42,6 +44,11 @@ func makeJoined(sdb []model.SharingDB) []model.Sharing {
var sharingCache = cache.NewMemCache(cache.WithShards[*model.Sharing](8))
var sharingG singleflight.Group[*model.Sharing]

// domainSharingCache 按虚拟主机 domain 作为 key 缓存对应的 *model.Sharing。
// 允许缓存为 nil 以实现"负缓存"防止穿透。
var domainSharingCache = cache.NewMemCache(cache.WithShards[*model.Sharing](2))
var domainSharingG singleflight.Group[*model.Sharing]

func GetSharingById(id string, refresh ...bool) (*model.Sharing, error) {
if !utils.IsBool(refresh...) {
if sharing, ok := sharingCache.Get(id); ok {
Expand Down Expand Up @@ -71,6 +78,66 @@ func GetSharingById(id string, refresh ...bool) (*model.Sharing, error) {
return sharing, err
}

// GetSharingByDomain 根据 domain 获取可用的虚拟主机 sharing(带缓存)。
// 仅当 sharing.Domain 非空、Disabled=false、Files 非空、Expires 未过期时才视为有效。
// 如果在 DB 中未找到,会负缓存 5 分钟,避免反复穿透 DB。
func GetSharingByDomain(domain string) (*model.Sharing, error) {
domain = strings.ToLower(strings.TrimSpace(domain))
if domain == "" {
return nil, errors.New("empty domain")
}
if s, ok := domainSharingCache.Get(domain); ok {
if s == nil {
log.Debugf("[Sharing] domain cache hit (nil) for %q", domain)
return nil, errors.New("sharing not found by domain")
}
log.Debugf("[Sharing] domain cache hit for %q id=%s", domain, s.ID)
if !s.ValidForVhost() {
return nil, errors.New("sharing not valid")
}
return s, nil
}
sharing, err, _ := domainSharingG.Do(domain, func() (*model.Sharing, error) {
sdb, err := db.GetSharingByDomain(domain)
if err != nil {
if errors.Is(errors.Cause(err), gorm.ErrRecordNotFound) {
log.Debugf("[Sharing] domain=%q not found in db, caching nil", domain)
domainSharingCache.Set(domain, nil, cache.WithEx[*model.Sharing](time.Minute*5))
return nil, errors.New("sharing not found by domain")
}
return nil, errors.WithMessagef(err, "failed get sharing by domain [%s]", domain)
}
// 虚拟主机场景不需要 creator,跳过 creator 查询以避免 CanShare 校验阻断 Web Hosting
var files []string
if err = utils.Json.UnmarshalFromString(sdb.FilesRaw, &files); err != nil {
files = make([]string, 0)
}
s := &model.Sharing{
SharingDB: sdb,
Files: files,
Creator: nil, // 虚拟主机匹配不依赖 creator 权限
}
domainSharingCache.Set(domain, s, cache.WithEx[*model.Sharing](time.Hour))
return s, nil
})
if err != nil {
return nil, err
}
if sharing == nil || !sharing.ValidForVhost() {
return nil, errors.New("sharing not valid for domain")
}
return sharing, nil
}

// invalidateDomainCache 在创建/更新/删除记录时调用,同时传入新/旧 domain 以使两者都失效。
func invalidateDomainCache(domains ...string) {
for _, d := range domains {
if d != "" {
domainSharingCache.Del(d)
}
}
}

func GetSharings(pageIndex, pageSize int) ([]model.Sharing, int64, error) {
s, cnt, err := db.GetSharings(pageIndex, pageSize)
if err != nil {
Expand Down Expand Up @@ -118,7 +185,11 @@ func CreateSharing(sharing *model.Sharing) (id string, err error) {
if err != nil {
return "", errors.WithStack(err)
}
return db.CreateSharing(sharing.SharingDB)
id, err = db.CreateSharing(sharing.SharingDB)
if err == nil {
invalidateDomainCache(sharing.Domain)
}
return id, err
}

func UpdateSharing(sharing *model.Sharing, skipMarshal ...bool) (err error) {
Expand All @@ -129,8 +200,17 @@ func UpdateSharing(sharing *model.Sharing, skipMarshal ...bool) (err error) {
return errors.WithStack(err)
}
}
sharingCache.Del(sharing.ID)
return db.UpdateSharing(sharing.SharingDB)
// 读取旧记录以便同时失效旧 domain 缓存
var oldDomain string
if old, e := db.GetSharingById(sharing.ID); e == nil {
oldDomain = old.Domain
}
err = db.UpdateSharing(sharing.SharingDB)
if err == nil {
sharingCache.Del(sharing.ID)
invalidateDomainCache(oldDomain, sharing.Domain)
}
return err
}

func UpdateSharingId(sharing *model.Sharing, newId string) error {
Expand All @@ -143,8 +223,17 @@ func UpdateSharingId(sharing *model.Sharing, newId string) error {
}

func DeleteSharing(sid string) error {
sharingCache.Del(sid)
return db.DeleteSharingById(sid)
// 先读取 domain 用于失效缓存
var oldDomain string
if old, e := db.GetSharingById(sid); e == nil {
oldDomain = old.Domain
}
err := db.DeleteSharingById(sid)
if err == nil {
sharingCache.Del(sid)
invalidateDomainCache(oldDomain)
}
return err
}

func DeleteSharingsByCreatorId(creatorId uint) error {
Expand Down
12 changes: 12 additions & 0 deletions server/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"html"
stdnet "net"
"net/http"
"strings"

Expand Down Expand Up @@ -155,3 +156,14 @@ func ContentWithValues(ctx context.Context, keyAndValue ...any) context.Context
}
return ctx
}

// StripHostPort 从 Host 头中去掉端口部分,返回纯域名。
// 支持 IPv4、IPv6([::1]:port)及无端口的裸域名/IP。
func StripHostPort(host string) string {
h, _, err := stdnet.SplitHostPort(host)
if err != nil {
// 无端口,原样返回
return host
}
return h
}
70 changes: 69 additions & 1 deletion server/handles/fsread.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ func FsListSplit(c *gin.Context) {
SharingList(c, &req)
return
}
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
req.Path = applyVhostPathMapping(c, req.Path)
user := c.Request.Context().Value(conf.UserKey).(*model.User)
if user.IsGuest() && user.Disabled {
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
Expand Down Expand Up @@ -272,6 +274,11 @@ func FsGetSplit(c *gin.Context) {
SharingGet(c, &req)
return
}
// 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径
// 同时将 vhost.Path 前缀存入 context,供 FsGet 生成 /p/ 链接时去掉前缀
var vhostPrefix string
req.Path, vhostPrefix = applyVhostPathMappingWithPrefix(c, req.Path)
common.GinAppendValues(c, conf.VhostPrefixKey, vhostPrefix)
user := c.Request.Context().Value(conf.UserKey).(*model.User)
if user.IsGuest() && user.Disabled {
common.ErrorStrResp(c, "Guest user is disabled, login please", 401)
Expand Down Expand Up @@ -319,12 +326,14 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) {
rawURL = common.GenerateDownProxyURL(storage.GetStorage(), reqPath)
if rawURL == "" {
query := ""
// 生成 /p/ 链接时,去掉 vhost 路径前缀,保持前端看到的路径一致
downPath := stripVhostPrefix(c, reqPath)
if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) {
query = "?sign=" + sign.Sign(reqPath)
}
rawURL = fmt.Sprintf("%s/p%s%s",
common.GetApiUrl(c),
utils.EncodePath(reqPath, true),
utils.EncodePath(downPath, true),
query)
}
} else {
Expand Down Expand Up @@ -427,3 +436,62 @@ func FsOther(c *gin.Context) {
}
common.SuccessResp(c, res)
}

// applyVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,将请求路径映射到实际路径。
func applyVhostPathMapping(c *gin.Context, reqPath string) string {
mapped, _ := applyVhostPathMappingWithPrefix(c, reqPath)
return mapped
}

// applyVhostPathMappingWithPrefix 根据请求的 Host 头匹配 sharing 中带 Domain 的虚拟主机记录,
// 将请求路径映射到 sharing.Files[0] 之下,同时返回该路径前缀(用于生成下载链接时去掉前缀)。
// 例如:sharing.Files[0]="/123pan/Downloads",reqPath="/",则返回 ("/123pan/Downloads", "/123pan/Downloads")
// 例如:sharing.Files[0]="/123pan/Downloads",reqPath="/subdir",则返回 ("/123pan/Downloads/subdir", "/123pan/Downloads")
// 如果没有匹配的虚拟主机规则,则返回 (原始路径, "")
func applyVhostPathMappingWithPrefix(c *gin.Context, reqPath string) (string, string) {
rawHost := c.Request.Host
domain := common.StripHostPort(rawHost)
if domain == "" {
return reqPath, ""
}
sharing, err := op.GetSharingByDomain(domain)
if err != nil || sharing == nil {
return reqPath, ""
}
if sharing.WebHosting {
// Web 托管模式不做 API 路径重映射
return reqPath, ""
}
if len(sharing.Files) == 0 {
return reqPath, ""
}
root := sharing.Files[0]
// Map request path into the sharing root and verify it does not escape via traversal.
// stdpath.Join calls Clean internally, which collapses ".." segments, so we only need
// to confirm the result still lives under root.
mapped := stdpath.Join(root, reqPath)
if !strings.HasPrefix(mapped, strings.TrimRight(root, "/")+"/") && mapped != root {
utils.Log.Warnf("[VirtualHost] path traversal rejected for API remapping: domain=%q reqPath=%q", domain, reqPath)
return reqPath, ""
}
utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped)
return mapped, root
}

Comment thread
PIKACHUIM marked this conversation as resolved.
// stripVhostPrefix 从 gin context 中取出 vhost 路径前缀,并从 path 中去掉该前缀。
// 用于生成 /p/ 下载链接时,将真实路径还原为前端看到的路径。
func stripVhostPrefix(c *gin.Context, path string) string {
prefix, ok := c.Request.Context().Value(conf.VhostPrefixKey).(string)
if !ok || prefix == "" {
return path
}
if strings.HasPrefix(path, prefix+"/") {
return path[len(prefix):]
}
if path == prefix {
return "/"
}
return path
}


Loading
Loading