diff --git a/internal/conf/config.go b/internal/conf/config.go index 4a1172815..0f6676116 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -106,6 +106,10 @@ type SFTP struct { Listen string `json:"listen" env:"LISTEN"` } +type MCP struct { + Enable bool `json:"enable" env:"ENABLE"` +} + type Config struct { Force bool `json:"force" env:"FORCE"` SiteURL string `json:"site_url" env:"SITE_URL"` @@ -131,6 +135,7 @@ type Config struct { S3 S3 `json:"s3" envPrefix:"S3_"` FTP FTP `json:"ftp" envPrefix:"FTP_"` SFTP SFTP `json:"sftp" envPrefix:"SFTP_"` + MCP MCP `json:"mcp" envPrefix:"MCP_"` LastLaunchedVersion string `json:"last_launched_version"` ProxyAddress string `json:"proxy_address" env:"PROXY_ADDRESS"` } @@ -244,6 +249,9 @@ func DefaultConfig(dataDir string) *Config { Enable: false, Listen: ":5222", }, + MCP: MCP{ + Enable: false, + }, LastLaunchedVersion: "", ProxyAddress: "", } diff --git a/server/mcp.go b/server/mcp.go new file mode 100644 index 000000000..c036ce389 --- /dev/null +++ b/server/mcp.go @@ -0,0 +1,18 @@ +package server + +import ( + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/mcp" + "github.com/gin-gonic/gin" +) + +func MCP(g *gin.RouterGroup) { + if !conf.Conf.MCP.Enable { + g.Any("/mcp", func(c *gin.Context) { + common.ErrorStrResp(c, "MCP server is not enabled", 403) + }) + return + } + mcp.Register(g) +} diff --git a/server/mcp/call.go b/server/mcp/call.go new file mode 100644 index 000000000..5a19b9b4b --- /dev/null +++ b/server/mcp/call.go @@ -0,0 +1,88 @@ +package mcp + +import ( + "encoding/json" + "net/http" + + "github.com/gin-gonic/gin" +) + +type toolCallParams struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +type toolResultContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +func (s *Server) handleToolsCall(c *gin.Context, req request) (int, response) { + var params toolCallParams + if len(req.Params) == 0 { + return http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32602, Message: "invalid tools/call params"}, + } + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil || params.Name == "" { + return http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32602, Message: "invalid tools/call params"}, + } + } + + var ( + result any + err *rpcError + ) + switch params.Name { + case "openlist.fs.list": + result, err = s.callFSList(c, params.Arguments) + case "openlist.fs.get": + result, err = s.callFSGet(c, params.Arguments) + case "openlist.fs.link": + result, err = s.callFSLink(c, params.Arguments) + default: + return http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32601, Message: "unknown tool"}, + } + } + + if err != nil { + return http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "content": []toolResultContent{ + {Type: "text", Text: err.Message}, + }, + "isError": true, + }, + } + } + + resultJSON, marshalErr := json.Marshal(result) + if marshalErr != nil { + return http.StatusInternalServerError, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32603, Message: "failed to encode tool result"}, + } + } + + return http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "content": []toolResultContent{ + {Type: "text", Text: string(resultJSON)}, + }, + "structuredContent": result, + }, + } +} diff --git a/server/mcp/call_test.go b/server/mcp/call_test.go new file mode 100644 index 000000000..b922ffd52 --- /dev/null +++ b/server/mcp/call_test.go @@ -0,0 +1,183 @@ +package mcp + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" +) + +func TestToolsListRequiresInitializedSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32002 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestToolsListSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s2": {id: "s2", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":2, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "s2") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + tools, ok := result["tools"].([]any) + if !ok || len(tools) != 3 { + t.Fatalf("unexpected tools payload: %#v", result["tools"]) + } + names := map[string]bool{} + for _, rawTool := range tools { + currentTool, ok := rawTool.(map[string]any) + if !ok { + t.Fatalf("unexpected tool payload: %#v", rawTool) + } + name, _ := currentTool["name"].(string) + names[name] = true + } + if !names["openlist.fs.list"] || !names["openlist.fs.get"] || !names["openlist.fs.link"] { + t.Fatalf("unexpected tool names: %#v", names) + } +} + +func TestToolsCallUnknownTool(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s3": {id: "s3", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":3, + "method":"tools/call", + "params":{"name":"openlist.fs.unknown","arguments":{"path":"/"}} + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "s3") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32601 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestToolsCallInvalidParams(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s4": {id: "s4", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":4, + "method":"tools/call", + "params":{"name":"openlist.fs.list","arguments":"bad"} + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "s4") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error != nil { + t.Fatalf("expected tool error result, got protocol error: %+v", resp.Error) + } +} + +func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) response { + t.Helper() + + var resp response + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + return resp +} diff --git a/server/mcp/fs_get.go b/server/mcp/fs_get.go new file mode 100644 index 000000000..550cc036f --- /dev/null +++ b/server/mcp/fs_get.go @@ -0,0 +1,168 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + stdpath "path" + "strings" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/internal/sign" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/handles" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type fsGetArgs struct { + Path string `json:"path"` + Password string `json:"password"` +} + +func (s *Server) callFSGet(c *gin.Context, raw json.RawMessage) (any, *rpcError) { + args, mcpErr := parseFSGetArgs(raw) + if mcpErr != nil { + return nil, mcpErr + } + + user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) + if !ok || user == nil { + return nil, &rpcError{Code: -32603, Message: "missing user context"} + } + if user.IsGuest() && user.Disabled { + return nil, &rpcError{Code: -32001, Message: "guest user is disabled"} + } + + reqPath, err := user.JoinPath(args.Path) + if err != nil { + return nil, &rpcError{Code: -32003, Message: err.Error()} + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if !common.CanAccess(user, meta, reqPath, args.Password) { + return nil, &rpcError{Code: -32003, Message: "password is incorrect or you have no permission"} + } + + ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) + obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{ + WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), + }) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + rawURL, provider, err := buildFSGetRawURL(ctx, c, reqPath, obj, meta) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + parentPath := stdpath.Dir(reqPath) + var related []model.Obj + sameLevelFiles, err := fs.List(ctx, parentPath, &fs.ListArgs{}) + if err == nil { + related = filterRelatedObjs(sameLevelFiles, obj) + } + + parentMeta, _ := op.GetNearestMeta(parentPath) + thumb, _ := model.GetThumb(obj) + mountDetails, _ := model.GetStorageDetails(obj) + return handles.FsGetResp{ + ObjResp: handles.ObjResp{ + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Created: obj.CreateTime(), + Sign: common.Sign(obj, parentPath, isEncrypt(meta, reqPath)), + Thumb: thumb, + Type: utils.GetFileType(obj.GetName()), + HashInfoStr: obj.GetHash().String(), + HashInfo: obj.GetHash().Export(), + MountDetails: mountDetails, + }, + RawURL: rawURL, + Readme: getReadme(meta, reqPath), + Header: getHeader(meta, reqPath), + Provider: provider, + Related: toObjResp(related, parentPath, isEncrypt(parentMeta, parentPath)), + }, nil +} + +func parseFSGetArgs(raw json.RawMessage) (*fsGetArgs, *rpcError) { + args := &fsGetArgs{} + if len(raw) == 0 || string(raw) == "null" { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.get arguments"} + } + + if err := json.Unmarshal(raw, args); err != nil { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.get arguments"} + } + if args.Path == "" { + return nil, &rpcError{Code: -32602, Message: "path is required"} + } + return args, nil +} + +func buildFSGetRawURL(ctx context.Context, c *gin.Context, reqPath string, obj model.Obj, meta *model.Meta) (string, string, error) { + storage, storageErr := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + provider, ok := model.GetProvider(obj) + if !ok && storageErr == nil { + provider = storage.Config().Name + } + if obj.IsDir() { + return "", provider, nil + } + if storageErr != nil { + return "", provider, storageErr + } + + if storage.Config().MustProxy() || storage.GetStorage().WebProxy { + rawURL := common.GenerateDownProxyURL(storage.GetStorage(), reqPath) + if rawURL != "" { + return rawURL, provider, nil + } + query := "" + if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { + query = "?sign=" + sign.Sign(reqPath) + } + return fmt.Sprintf("%s/p%s%s", common.GetApiUrl(ctx), utils.EncodePath(reqPath, true), query), provider, nil + } + + if url, ok := model.GetUrl(obj); ok { + return url, provider, nil + } + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + Redirect: true, + }) + if err != nil { + return "", provider, err + } + defer link.Close() + return link.URL, provider, nil +} + +func filterRelatedObjs(objs []model.Obj, obj model.Obj) []model.Obj { + related := make([]model.Obj, 0) + nameWithoutExt := strings.TrimSuffix(obj.GetName(), stdpath.Ext(obj.GetName())) + for _, current := range objs { + if current.GetName() == obj.GetName() { + continue + } + if strings.HasPrefix(current.GetName(), nameWithoutExt) { + related = append(related, current) + } + } + return related +} diff --git a/server/mcp/fs_link.go b/server/mcp/fs_link.go new file mode 100644 index 000000000..43ddad852 --- /dev/null +++ b/server/mcp/fs_link.go @@ -0,0 +1,200 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + stdpath "path" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/internal/sign" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type fsLinkArgs struct { + Path string `json:"path"` + Password string `json:"password"` + Type string `json:"type"` +} + +type fsLinkResp struct { + Path string `json:"path"` + Name string `json:"name"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir"` + Modified time.Time `json:"modified"` + Provider string `json:"provider"` + URL string `json:"url"` + URLType string `json:"url_type"` + DirectURL string `json:"direct_url,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + DownloadURL string `json:"download_url,omitempty"` + Header http.Header `json:"header,omitempty"` + ContentLength int64 `json:"content_length,omitempty"` + Concurrency int `json:"concurrency,omitempty"` + PartSize int `json:"part_size,omitempty"` +} + +func (s *Server) callFSLink(c *gin.Context, raw json.RawMessage) (any, *rpcError) { + args, mcpErr := parseFSLinkArgs(raw) + if mcpErr != nil { + return nil, mcpErr + } + + user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) + if !ok || user == nil { + return nil, &rpcError{Code: -32603, Message: "missing user context"} + } + if user.IsGuest() && user.Disabled { + return nil, &rpcError{Code: -32001, Message: "guest user is disabled"} + } + + reqPath, err := user.JoinPath(args.Path) + if err != nil { + return nil, &rpcError{Code: -32003, Message: err.Error()} + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if !common.CanAccess(user, meta, reqPath, args.Password) { + return nil, &rpcError{Code: -32003, Message: "password is incorrect or you have no permission"} + } + + ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) + obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{ + WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), + }) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if obj.IsDir() { + return nil, &rpcError{Code: -32003, Message: "path is a directory"} + } + + storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + linkInfo, err := buildFSLinkInfo(ctx, c, reqPath, args, obj, meta, storage) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + return linkInfo, nil +} + +func parseFSLinkArgs(raw json.RawMessage) (*fsLinkArgs, *rpcError) { + args := &fsLinkArgs{} + if len(raw) == 0 || string(raw) == "null" { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.link arguments"} + } + if err := json.Unmarshal(raw, args); err != nil { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.link arguments"} + } + if args.Path == "" { + return nil, &rpcError{Code: -32602, Message: "path is required"} + } + return args, nil +} + +func buildFSLinkInfo(ctx context.Context, c *gin.Context, reqPath string, args *fsLinkArgs, obj model.Obj, meta *model.Meta, storage driver.Driver) (*fsLinkResp, error) { + provider, ok := model.GetProvider(obj) + if !ok { + provider = storage.Config().Name + } + + resp := &fsLinkResp{ + Path: reqPath, + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Provider: provider, + DownloadURL: signedFileURL(ctx, "/d", reqPath, meta, args.Type), + } + + if canProxyFile(storage, stdpath.Base(reqPath)) { + proxyURL := proxyFileURL(ctx, reqPath, meta, storage.GetStorage(), args.Type) + resp.ProxyURL = proxyURL + } + + if common.ShouldProxy(storage, stdpath.Base(reqPath)) { + resp.URL = resp.ProxyURL + resp.URLType = "proxy" + return resp, nil + } + + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + Type: args.Type, + Redirect: true, + }) + if err != nil { + return nil, err + } + defer link.Close() + + resp.DirectURL = link.URL + resp.URL = link.URL + resp.URLType = "direct" + resp.Header = link.Header + resp.ContentLength = link.ContentLength + resp.Concurrency = link.Concurrency + resp.PartSize = link.PartSize + return resp, nil +} + +func canProxyFile(storage driver.Driver, filename string) bool { + if storage.Config().MustProxy() || storage.GetStorage().WebProxy { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.ProxyTypes], utils.Ext(filename)) { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.TextTypes], utils.Ext(filename)) { + return true + } + return false +} + +func proxyFileURL(ctx context.Context, reqPath string, meta *model.Meta, storage *model.Storage, linkType string) string { + if url := common.GenerateDownProxyURL(storage, reqPath); url != "" { + return url + } + return signedFileURL(ctx, "/p", reqPath, meta, linkType) +} + +func signedFileURL(ctx context.Context, prefix, reqPath string, meta *model.Meta, linkType string) string { + query := url.Values{} + if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { + query.Set("sign", sign.Sign(reqPath)) + } + if linkType != "" { + query.Set("type", linkType) + } + rawQuery := "" + if encoded := query.Encode(); encoded != "" { + rawQuery = "?" + encoded + } + return fmt.Sprintf("%s%s%s%s", + common.GetApiUrl(ctx), + prefix, + utils.EncodePath(reqPath, true), + rawQuery, + ) +} diff --git a/server/mcp/fs_list.go b/server/mcp/fs_list.go new file mode 100644 index 000000000..ac28cf756 --- /dev/null +++ b/server/mcp/fs_list.go @@ -0,0 +1,182 @@ +package mcp + +import ( + "context" + "encoding/json" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/handles" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type fsListArgs struct { + Path string `json:"path"` + Password string `json:"password"` + Refresh bool `json:"refresh"` + Page int `json:"page"` + PerPage int `json:"per_page"` +} + +func (s *Server) callFSList(c *gin.Context, raw json.RawMessage) (any, *rpcError) { + args, mcpErr := parseFSListArgs(raw) + if mcpErr != nil { + return nil, mcpErr + } + + user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) + if !ok || user == nil { + return nil, &rpcError{Code: -32603, Message: "missing user context"} + } + if user.IsGuest() && user.Disabled { + return nil, &rpcError{Code: -32001, Message: "guest user is disabled"} + } + + reqPath, err := user.JoinPath(args.Path) + if err != nil { + return nil, &rpcError{Code: -32003, Message: err.Error()} + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if !common.CanAccess(user, meta, reqPath, args.Password) { + return nil, &rpcError{Code: -32003, Message: "password is incorrect or you have no permission"} + } + + write := common.CanWrite(user, meta, reqPath) + writeContentBypass := common.CanWriteContentBypassUserPerms(meta, reqPath) + canWriteContentAtPath := write && (user.CanWriteContent() || writeContentBypass) + if args.Refresh && !canWriteContentAtPath { + return nil, &rpcError{Code: -32003, Message: "refresh without permission"} + } + + ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) + objs, err := fs.List(ctx, reqPath, &fs.ListArgs{ + Refresh: args.Refresh, + WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), + }) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + total, paged := paginateObjs(objs, args.Page, args.PerPage) + return handles.FsListResp{ + Content: toObjResp(paged, reqPath, isEncrypt(meta, reqPath)), + Total: int64(total), + Write: write, + WriteContentBypass: writeContentBypass, + Provider: "unknown", + Readme: getReadme(meta, reqPath), + Header: getHeader(meta, reqPath), + }, nil +} + +func parseFSListArgs(raw json.RawMessage) (*fsListArgs, *rpcError) { + args := &fsListArgs{ + Page: 1, + PerPage: model.MaxInt, + } + if len(raw) == 0 || string(raw) == "null" { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} + } + + if err := json.Unmarshal(raw, args); err != nil { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} + } + if args.Path == "" { + return nil, &rpcError{Code: -32602, Message: "path is required"} + } + normalizeFSListArgs(args) + return args, nil +} + +func normalizeFSListArgs(args *fsListArgs) { + pageReq := model.PageReq{ + Page: args.Page, + PerPage: args.PerPage, + } + pageReq.Validate() + args.Page = pageReq.Page + args.PerPage = pageReq.PerPage +} + +func paginateObjs(objs []model.Obj, page, perPage int) (int, []model.Obj) { + total := len(objs) + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = model.MaxInt + } + offset := page - 1 + if offset > total/perPage { + return total, []model.Obj{} + } + start := offset * perPage + if start > total { + return total, []model.Obj{} + } + end := total + if perPage <= total-start { + end = start + perPage + } + if end > total { + end = total + } + return total, objs[start:end] +} + +func toObjResp(objs []model.Obj, parent string, encrypt bool) []handles.ObjResp { + resp := make([]handles.ObjResp, 0, len(objs)) + for _, obj := range objs { + thumb, _ := model.GetThumb(obj) + mountDetails, _ := model.GetStorageDetails(obj) + resp = append(resp, handles.ObjResp{ + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Created: obj.CreateTime(), + Sign: common.Sign(obj, parent, encrypt), + Thumb: thumb, + Type: utils.GetObjType(obj.GetName(), obj.IsDir()), + HashInfoStr: obj.GetHash().String(), + HashInfo: obj.GetHash().Export(), + MountDetails: mountDetails, + }) + } + return resp +} + +func getReadme(meta *model.Meta, path string) string { + if meta != nil && common.MetaCoversPath(meta.Path, path, meta.RSub) { + return meta.Readme + } + return "" +} + +func getHeader(meta *model.Meta, path string) string { + if meta != nil && common.MetaCoversPath(meta.Path, path, meta.HeaderSub) { + return meta.Header + } + return "" +} + +func isEncrypt(meta *model.Meta, path string) bool { + if common.IsStorageSignEnabled(path) { + return true + } + if meta == nil || meta.Password == "" { + return false + } + return common.MetaCoversPath(meta.Path, path, meta.PSub) +} diff --git a/server/mcp/get_test.go b/server/mcp/get_test.go new file mode 100644 index 000000000..81fbe3900 --- /dev/null +++ b/server/mcp/get_test.go @@ -0,0 +1,39 @@ +package mcp + +import ( + "encoding/json" + "testing" +) + +func TestParseFSGetArgsRequiresPath(t *testing.T) { + _, err := parseFSGetArgs(json.RawMessage(`{"password":"secret"}`)) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 { + t.Fatalf("unexpected error: %+v", err) + } +} + +func TestParseFSGetArgs(t *testing.T) { + args, err := parseFSGetArgs(json.RawMessage(`{"path":"/movie.mkv","password":"secret"}`)) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if args.Path != "/movie.mkv" { + t.Fatalf("unexpected path: %q", args.Path) + } + if args.Password != "secret" { + t.Fatalf("unexpected password: %q", args.Password) + } +} + +func TestParseFSGetArgsRejectsInvalidJSON(t *testing.T) { + _, err := parseFSGetArgs(json.RawMessage(`"bad"`)) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 { + t.Fatalf("unexpected error: %+v", err) + } +} diff --git a/server/mcp/handler.go b/server/mcp/handler.go new file mode 100644 index 000000000..ac08747a7 --- /dev/null +++ b/server/mcp/handler.go @@ -0,0 +1,532 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/utils/random" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/middlewares" + "github.com/gin-gonic/gin" +) + +const ( + ProtocolVersion = "2025-11-25" + ProtocolVersionHeader = "MCP-Protocol-Version" + SessionHeader = "MCP-Session-Id" + sessionTTL = 30 * time.Minute + maxSessions = 128 + maxUserSessions = 16 +) + +type session struct { + id string + userID uint + protocolVersion string + initialized bool + createdAt time.Time + lastUsedAt time.Time +} + +type Server struct { + mu sync.Mutex + sessions map[string]*session +} + +type request struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type response struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result any `json:"result,omitempty"` + Error *rpcError `json:"error,omitempty"` +} + +type rpcError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type initializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ClientInfo map[string]any `json:"clientInfo"` +} + +var defaultServer = &Server{ + sessions: map[string]*session{}, +} + +var supportedProtocolVersions = map[string]struct{}{ + "2025-11-25": {}, + "2025-06-18": {}, +} + +func Register(g *gin.RouterGroup) { + mcpGroup := g.Group("/mcp", middlewares.Auth(false), middlewares.AuthAdmin) + mcpGroup.GET("", defaultServer.handleGet) + mcpGroup.POST("", defaultServer.handlePost) + mcpGroup.DELETE("", defaultServer.handleDelete) +} + +func (s *Server) handleGet(c *gin.Context) { + if !validateOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + c.Header("Allow", "POST, DELETE") + c.Status(http.StatusMethodNotAllowed) +} + +func (s *Server) handlePost(c *gin.Context) { + if !validateOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + if !acceptsStreamableHTTP(c.GetHeader("Accept")) { + c.JSON(http.StatusNotAcceptable, response{ + JSONRPC: "2.0", + Error: &rpcError{ + Code: -32000, + Message: "Not Acceptable: client must accept both application/json and text/event-stream", + }, + }) + return + } + + body, err := io.ReadAll(io.LimitReader(c.Request.Body, 1<<20)) + if err != nil { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32700, Message: "failed to read request body"}, + }) + return + } + + var req request + if err := json.Unmarshal(body, &req); err != nil { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32700, Message: "parse error"}, + }) + return + } + if req.JSONRPC != "2.0" { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32600, Message: "invalid request"}, + }) + return + } + + if req.Method == "initialize" { + s.handleInitialize(c, req) + return + } + sessionID := c.GetHeader(SessionHeader) + if sessionID == "" { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32000, Message: "missing MCP session"}, + }) + return + } + + currentSession, ok := s.getSession(sessionID) + if !ok { + c.JSON(http.StatusNotFound, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32001, Message: "session not found"}, + }) + return + } + + user := c.Request.Context().Value(conf.UserKey).(*model.User) + if currentSession.userID != user.ID { + c.JSON(http.StatusNotFound, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32001, Message: "session not found"}, + }) + return + } + + if !s.validateRequestProtocolVersion(c.GetHeader(ProtocolVersionHeader), currentSession.protocolVersion) { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32000, Message: "missing or unsupported MCP protocol version"}, + }) + return + } + + switch req.Method { + case "ping": + c.JSON(http.StatusOK, response{JSONRPC: "2.0", ID: req.ID, Result: map[string]any{}}) + case "notifications/initialized": + s.markSessionInitialized(sessionID) + c.Status(http.StatusAccepted) + case "tools/list": + if !s.sessionInitialized(sessionID) { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32002, Message: "MCP session not initialized"}, + }) + return + } + c.JSON(http.StatusOK, s.handleToolsList(req)) + case "tools/call": + if !s.sessionInitialized(sessionID) { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32002, Message: "MCP session not initialized"}, + }) + return + } + status, resp := s.handleToolsCall(c, req) + c.JSON(status, resp) + default: + c.JSON(http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32601, Message: fmt.Sprintf("method %q not implemented yet", req.Method)}, + }) + } +} + +func (s *Server) handleInitialize(c *gin.Context, req request) { + var params initializeParams + if len(req.Params) > 0 { + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32602, Message: "invalid initialize params"}, + }) + return + } + } + + protocolVersion := negotiateProtocolVersion(params.ProtocolVersion) + currentSession := s.initializeSession( + c.Request.Context().Value(conf.UserKey).(*model.User).ID, + c.GetHeader(SessionHeader), + protocolVersion, + ) + c.Header(SessionHeader, currentSession.id) + c.JSON(http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "protocolVersion": protocolVersion, + "capabilities": map[string]any{ + "tools": map[string]any{ + "listChanged": false, + }, + }, + "serverInfo": map[string]any{ + "name": "OpenList MCP", + "version": conf.Version, + }, + "instructions": "Complete initialization with notifications/initialized, then use tools/list and tools/call. Available tools include openlist.fs.list, openlist.fs.get, and openlist.fs.link.", + }, + }) +} + +func (s *Server) initializeSession(userID uint, requestedID string, protocolVersion string) *session { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + s.pruneExpiredSessionsLocked(now) + if requestedID != "" { + currentSession, ok := s.sessions[requestedID] + if ok && currentSession != nil && currentSession.userID == userID { + currentSession.initialized = false + currentSession.protocolVersion = protocolVersion + currentSession.lastUsedAt = now + return currentSession + } + } + + return s.createSessionLocked(userID, protocolVersion, now) +} + +func (s *Server) handleDelete(c *gin.Context) { + if !validateOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + + currentSession, ok := s.getSession(c.GetHeader(SessionHeader)) + if !ok { + c.Status(http.StatusNotFound) + return + } + + user := c.Request.Context().Value(conf.UserKey).(*model.User) + if currentSession.userID != user.ID { + c.Status(http.StatusNotFound) + return + } + + s.deleteSession(currentSession.id) + c.Status(http.StatusNoContent) +} + +func (s *Server) createSession(userID uint) *session { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + s.pruneExpiredSessionsLocked(now) + return s.createSessionLocked(userID, ProtocolVersion, now) +} + +func (s *Server) createSessionLocked(userID uint, protocolVersion string, now time.Time) *session { + s.pruneLeastRecentlyUsedUserSessionsLocked(userID, max(0, s.countUserSessionsLocked(userID)-maxUserSessions+1)) + s.pruneLeastRecentlyUsedSessionsLocked(max(0, len(s.sessions)-maxSessions+1)) + currentSession := &session{ + id: random.Token(), + userID: userID, + protocolVersion: protocolVersion, + createdAt: now, + lastUsedAt: now, + } + s.sessions[currentSession.id] = currentSession + return currentSession +} + +func (s *Server) validateRequestProtocolVersion(requestedVersion string, negotiatedVersion string) bool { + if requestedVersion == "" { + return false + } + if _, ok := supportedProtocolVersions[requestedVersion]; !ok { + return false + } + return negotiatedVersion == "" || requestedVersion == negotiatedVersion +} + +func negotiateProtocolVersion(requestedVersion string) string { + if _, ok := supportedProtocolVersions[requestedVersion]; ok { + return requestedVersion + } + return ProtocolVersion +} + +func (s *Server) getSession(id string) (session, bool) { + if id == "" { + return session{}, false + } + s.mu.Lock() + defer s.mu.Unlock() + currentSession, ok := s.sessions[id] + if !ok || currentSession == nil { + return session{}, false + } + now := time.Now() + if sessionExpired(currentSession, now) { + delete(s.sessions, id) + return session{}, false + } + currentSession.lastUsedAt = now + return *currentSession, true +} + +func (s *Server) markSessionInitialized(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + currentSession, ok := s.sessions[id] + if !ok || currentSession == nil { + return false + } + now := time.Now() + if sessionExpired(currentSession, now) { + delete(s.sessions, id) + return false + } + currentSession.initialized = true + currentSession.lastUsedAt = now + return true +} + +func (s *Server) sessionInitialized(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + currentSession, ok := s.sessions[id] + if !ok || currentSession == nil { + return false + } + now := time.Now() + if sessionExpired(currentSession, now) { + delete(s.sessions, id) + return false + } + currentSession.lastUsedAt = now + return currentSession.initialized +} + +func (s *Server) deleteSession(id string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, id) +} + +func (s *Server) pruneExpiredSessionsLocked(now time.Time) { + for id, currentSession := range s.sessions { + if currentSession == nil || sessionExpired(currentSession, now) { + delete(s.sessions, id) + } + } +} + +func (s *Server) pruneLeastRecentlyUsedSessionsLocked(count int) { + for range count { + var ( + oldestID string + oldest time.Time + ) + for id, currentSession := range s.sessions { + if currentSession == nil { + oldestID = id + break + } + lastUsedAt := sessionLastUsedAt(currentSession) + if oldestID == "" || lastUsedAt.Before(oldest) { + oldestID = id + oldest = lastUsedAt + } + } + if oldestID == "" { + return + } + delete(s.sessions, oldestID) + } +} + +func (s *Server) countUserSessionsLocked(userID uint) int { + count := 0 + for _, currentSession := range s.sessions { + if currentSession != nil && currentSession.userID == userID { + count++ + } + } + return count +} + +func (s *Server) pruneLeastRecentlyUsedUserSessionsLocked(userID uint, count int) { + for range count { + var ( + oldestID string + oldest time.Time + ) + for id, currentSession := range s.sessions { + if currentSession == nil { + continue + } + if currentSession.userID != userID { + continue + } + lastUsedAt := sessionLastUsedAt(currentSession) + if oldestID == "" || lastUsedAt.Before(oldest) { + oldestID = id + oldest = lastUsedAt + } + } + if oldestID == "" { + return + } + delete(s.sessions, oldestID) + } +} + +func sessionExpired(currentSession *session, now time.Time) bool { + lastUsedAt := sessionLastUsedAt(currentSession) + return !lastUsedAt.IsZero() && now.Sub(lastUsedAt) > sessionTTL +} + +func sessionLastUsedAt(currentSession *session) time.Time { + if currentSession.lastUsedAt.IsZero() { + return currentSession.createdAt + } + return currentSession.lastUsedAt +} + +func acceptsStreamableHTTP(accept string) bool { + if accept == "" { + return false + } + hasJSON := false + hasSSE := false + for part := range strings.SplitSeq(accept, ",") { + mediaType, params, err := mime.ParseMediaType(strings.TrimSpace(part)) + if err != nil { + continue + } + if q, ok := params["q"]; ok { + quality, err := strconv.ParseFloat(q, 64) + if err == nil && quality == 0 { + continue + } + } + switch mediaType { + case "application/json": + hasJSON = true + case "text/event-stream": + hasSSE = true + } + } + return hasJSON && hasSSE +} + +func validateOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Host == "" { + return false + } + if strings.EqualFold(originURL.Host, r.Host) { + return strings.EqualFold(originURL.Scheme, requestScheme(r)) + } + + siteURL := common.GetApiUrlFromRequest(r) + if siteURL == "" { + return false + } + siteParsed, err := url.Parse(siteURL) + if err != nil { + return false + } + return strings.EqualFold(originURL.Host, siteParsed.Host) && strings.EqualFold(originURL.Scheme, siteParsed.Scheme) +} + +func requestScheme(r *http.Request) string { + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + return "https" + } + return "http" +} diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go new file mode 100644 index 000000000..ac7a79158 --- /dev/null +++ b/server/mcp/handler_test.go @@ -0,0 +1,509 @@ +package mcp + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" +) + +func TestInitializeCreatesSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2025-11-25", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + if got := w.Header().Get(SessionHeader); got == "" { + t.Fatal("expected session header to be set") + } + + var resp response + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != ProtocolVersion { + t.Fatalf("unexpected protocol version: %v", result["protocolVersion"]) + } +} + +func TestInitializeNegotiatesSupportedOlderProtocolVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2025-06-18", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != "2025-06-18" { + t.Fatalf("unexpected protocol version: %v", result["protocolVersion"]) + } +} + +func TestInitializeReusesExistingSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + currentSession := srv.createSession(1) + currentSession.initialized = true + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2025-06-18", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(SessionHeader, currentSession.id) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + if got := w.Header().Get(SessionHeader); got != currentSession.id { + t.Fatalf("unexpected session header: got %q want %q", got, currentSession.id) + } + if len(srv.sessions) != 1 { + t.Fatalf("unexpected session count: got %d want %d", len(srv.sessions), 1) + } + reusedSession, ok := srv.getSession(currentSession.id) + if !ok { + t.Fatal("expected existing session to be reused") + } + if reusedSession.initialized { + t.Fatal("expected reused session to require initialized notification again") + } + if reusedSession.protocolVersion != "2025-06-18" { + t.Fatalf("unexpected protocol version: got %q want %q", reusedSession.protocolVersion, "2025-06-18") + } +} + +func TestInitializeNegotiatesUnsupportedProtocolVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2026-01-01", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != ProtocolVersion { + t.Fatalf("unexpected protocol version: %v", result["protocolVersion"]) + } +} + +func TestCreateSessionPrunesUserSessionLimit(t *testing.T) { + srv := newTestServer(nil) + var firstSessionID string + for i := range maxUserSessions { + currentSession := srv.createSession(1) + if i == 0 { + firstSessionID = currentSession.id + } + } + srv.createSession(2) + srv.createSession(1) + + if _, ok := srv.sessions[firstSessionID]; ok { + t.Fatal("expected oldest user session to be pruned") + } + if got := countSessionsForUser(srv, 1); got != maxUserSessions { + t.Fatalf("unexpected user session count: got %d want %d", got, maxUserSessions) + } + if got := len(srv.sessions); got != maxUserSessions+1 { + t.Fatalf("unexpected total session count: got %d want %d", got, maxUserSessions+1) + } +} + +func TestCreateSessionPrunesGlobalSessionLimit(t *testing.T) { + srv := newTestServer(nil) + var firstSessionID string + for i := range maxSessions { + currentSession := srv.createSession(uint(i + 1)) + if i == 0 { + firstSessionID = currentSession.id + } + } + srv.createSession(uint(maxSessions + 1)) + + if _, ok := srv.sessions[firstSessionID]; ok { + t.Fatal("expected oldest global session to be pruned") + } + if got := len(srv.sessions); got != maxSessions { + t.Fatalf("unexpected session count: got %d want %d", got, maxSessions) + } +} + +func TestPostInvalidAcceptReturnsJSONRPCError(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize" + }`)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotAcceptable { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNotAcceptable) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRequiresProtocolVersionAfterInitialize(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRejectsUnsupportedProtocolVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, "2025-03-26") + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRejectsProtocolVersionMismatch(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1, protocolVersion: "2025-06-18", initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRejectsMissingSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRejectsUnknownSessionWithNotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "unknown") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNotFound) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32001 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestDeleteRemovesSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + currentSession := srv.createSession(1) + r := gin.New() + r.DELETE("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handleDelete(c) + }) + + req := httptest.NewRequest(http.MethodDelete, "http://example.com/mcp", nil) + req.Header.Set("Origin", "http://example.com") + req.Header.Set(SessionHeader, currentSession.id) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNoContent) + } + if _, ok := srv.getSession(currentSession.id); ok { + t.Fatal("expected session to be deleted") + } +} + +func TestGetReturnsMethodNotAllowed(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.GET("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + defaultServer.handleGet(c) + }) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/mcp", nil) + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusMethodNotAllowed) + } + if allow := w.Header().Get("Allow"); allow != "POST, DELETE" { + t.Fatalf("unexpected Allow header: got %q want %q", allow, "POST, DELETE") + } +} + +func newTestServer(sessions map[string]*session) *Server { + if sessions == nil { + sessions = map[string]*session{} + } + now := time.Now() + for _, currentSession := range sessions { + if currentSession.createdAt.IsZero() { + currentSession.createdAt = now + } + if currentSession.lastUsedAt.IsZero() { + currentSession.lastUsedAt = now + } + } + return &Server{sessions: sessions} +} + +func countSessionsForUser(srv *Server, userID uint) int { + count := 0 + for _, currentSession := range srv.sessions { + if currentSession != nil && currentSession.userID == userID { + count++ + } + } + return count +} diff --git a/server/mcp/link_test.go b/server/mcp/link_test.go new file mode 100644 index 000000000..10896d80c --- /dev/null +++ b/server/mcp/link_test.go @@ -0,0 +1,143 @@ +package mcp + +import ( + "context" + "encoding/json" + "net/http/httptest" + "sync" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/gin-gonic/gin" +) + +var settingCacheMu sync.Mutex + +func TestParseFSLinkArgs(t *testing.T) { + args, err := parseFSLinkArgs(json.RawMessage(`{"path":"/file.txt","password":"pw","type":"thumb"}`)) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if args.Path != "/file.txt" || args.Password != "pw" || args.Type != "thumb" { + t.Fatalf("unexpected args: %+v", args) + } +} + +func TestParseFSLinkArgsRequiresPath(t *testing.T) { + _, err := parseFSLinkArgs(json.RawMessage(`{"type":"thumb"}`)) + if err == nil || err.Code != -32602 { + t.Fatalf("unexpected error: %+v", err) + } +} + +func TestCanProxyFile(t *testing.T) { + storage := &fsLinkTestDriver{ + config: driver.Config{Name: "Test"}, + storage: model.Storage{ + MountPath: "/", + }, + } + if canProxyFile(storage, "file.bin") { + t.Fatal("unexpected proxy support") + } + storage.config.OnlyProxy = true + if !canProxyFile(storage, "file.bin") { + t.Fatal("expected proxy support") + } +} + +func TestCanProxyFileIgnoresWebdavProxyURLPolicy(t *testing.T) { + storage := &fsLinkTestDriver{ + config: driver.Config{Name: "Test"}, + storage: model.Storage{ + MountPath: "/", + Proxy: model.Proxy{ + WebdavPolicy: "use_proxy_url", + }, + }, + } + + if canProxyFile(storage, "file.bin") { + t.Fatal("webdav proxy url policy should not enable MCP proxy support") + } +} + +func TestBuildFSLinkInfoUsesProxyWhenStorageRequiresProxy(t *testing.T) { + gin.SetMode(gin.TestMode) + settingCacheMu.Lock() + t.Cleanup(func() { + op.Cache.ClearAll() + settingCacheMu.Unlock() + }) + op.Cache.SetSetting(conf.LinkExpiration, &model.SettingItem{Key: conf.LinkExpiration, Value: "0"}) + op.Cache.SetSetting(conf.Token, &model.SettingItem{Key: conf.Token, Value: "test-token"}) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("POST", "http://example.com/mcp", nil) + ctx := context.WithValue(c.Request.Context(), conf.ApiUrlKey, "http://openlist.test") + storage := &fsLinkTestDriver{ + config: driver.Config{Name: "Test", OnlyProxy: true}, + storage: model.Storage{ + MountPath: "/", + Proxy: model.Proxy{ + DownProxyURL: "http://proxy.test", + DisableProxySign: true, + }, + }, + } + obj := &model.ObjectURL{ + Object: model.Object{Name: "file.txt", Size: 12}, + Url: model.Url{Url: "http://direct.test/file.txt"}, + } + + meta := &model.Meta{Path: "/file.txt", Password: "secret"} + resp, err := buildFSLinkInfo(ctx, c, "/file.txt", &fsLinkArgs{Path: "/file.txt"}, obj, meta, storage) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if resp.URL != "http://proxy.test/file.txt" || resp.URLType != "proxy" { + t.Fatalf("unexpected selected link: %+v", resp) + } + if resp.DirectURL != "" { + t.Fatalf("direct link should not be resolved for proxy storage: %+v", resp) + } +} + +type fsLinkTestDriver struct { + config driver.Config + storage model.Storage +} + +func (d *fsLinkTestDriver) Config() driver.Config { + return d.config +} + +func (d *fsLinkTestDriver) GetStorage() *model.Storage { + return &d.storage +} + +func (d *fsLinkTestDriver) SetStorage(storage model.Storage) { + d.storage = storage +} + +func (d *fsLinkTestDriver) GetAddition() driver.Additional { + return nil +} + +func (d *fsLinkTestDriver) Init(context.Context) error { + return nil +} + +func (d *fsLinkTestDriver) Drop(context.Context) error { + return nil +} + +func (d *fsLinkTestDriver) List(context.Context, model.Obj, model.ListArgs) ([]model.Obj, error) { + return nil, nil +} + +func (d *fsLinkTestDriver) Link(context.Context, model.Obj, model.LinkArgs) (*model.Link, error) { + return nil, nil +} diff --git a/server/mcp/list_test.go b/server/mcp/list_test.go new file mode 100644 index 000000000..8079245aa --- /dev/null +++ b/server/mcp/list_test.go @@ -0,0 +1,54 @@ +package mcp + +import ( + "encoding/json" + "testing" +) + +func TestParseFSListArgsRequiresPath(t *testing.T) { + assertFSListPathRequired(t, json.RawMessage(`{"refresh":true}`)) +} + +func TestParseFSListArgsRejectsEmptyArguments(t *testing.T) { + for _, raw := range []json.RawMessage{nil, json.RawMessage(`null`)} { + assertFSListInvalidArguments(t, raw) + } +} + +func TestParseFSListArgsRejectsInvalidJSON(t *testing.T) { + assertFSListInvalidArguments(t, json.RawMessage(`"bad"`)) +} + +func TestParseFSListArgs(t *testing.T) { + args, err := parseFSListArgs(json.RawMessage(`{"path":"/movies","password":"secret","refresh":true,"page":2,"per_page":10}`)) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if args.Path != "/movies" || args.Password != "secret" || !args.Refresh || args.Page != 2 || args.PerPage != 10 { + t.Fatalf("unexpected args: %+v", args) + } +} + +func assertFSListPathRequired(t *testing.T, raw json.RawMessage) { + t.Helper() + + _, err := parseFSListArgs(raw) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 || err.Message != "path is required" { + t.Fatalf("unexpected error: %+v", err) + } +} + +func assertFSListInvalidArguments(t *testing.T, raw json.RawMessage) { + t.Helper() + + _, err := parseFSListArgs(raw) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 || err.Message != "invalid openlist.fs.list arguments" { + t.Fatalf("unexpected error: %+v", err) + } +} diff --git a/server/mcp/tools.go b/server/mcp/tools.go new file mode 100644 index 000000000..188c674ce --- /dev/null +++ b/server/mcp/tools.go @@ -0,0 +1,122 @@ +package mcp + +import "encoding/json" + +type tool struct { + Name string `json:"name"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + InputSchema toolInputSchema `json:"inputSchema"` +} + +type toolInputSchema struct { + Type string `json:"type"` + Properties map[string]schemaProperty `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type schemaProperty struct { + Type string `json:"type,omitempty"` + Description string `json:"description,omitempty"` +} + +type toolsListParams struct { + Cursor string `json:"cursor,omitempty"` +} + +var openListTools = []tool{ + { + Name: "openlist.fs.list", + Title: "OpenList FS List", + Description: "List files and directories under a mount path that the current user can access.", + InputSchema: toolInputSchema{ + Type: "object", + Properties: map[string]schemaProperty{ + "path": { + Type: "string", + Description: "Mount path to list, for example \"/\" or \"/movies\".", + }, + "refresh": { + Type: "boolean", + Description: "Refresh the directory listing before returning results.", + }, + "password": { + Type: "string", + Description: "Optional password for protected paths.", + }, + "page": { + Type: "integer", + Description: "1-based page number.", + }, + "per_page": { + Type: "integer", + Description: "Page size.", + }, + }, + Required: []string{"path"}, + }, + }, + { + Name: "openlist.fs.get", + Title: "OpenList FS Get", + Description: "Get file or directory details for a mount path that the current user can access.", + InputSchema: toolInputSchema{ + Type: "object", + Properties: map[string]schemaProperty{ + "path": { + Type: "string", + Description: "Mount path to inspect, for example \"/movies/demo.mp4\".", + }, + "password": { + Type: "string", + Description: "Optional password for protected paths.", + }, + }, + Required: []string{"path"}, + }, + }, + { + Name: "openlist.fs.link", + Title: "OpenList FS Link", + Description: "Return usable link information for a file path that the current user can access.", + InputSchema: toolInputSchema{ + Type: "object", + Properties: map[string]schemaProperty{ + "path": { + Type: "string", + Description: "File mount path, for example \"/movies/demo.mp4\".", + }, + "password": { + Type: "string", + Description: "Optional password for protected paths.", + }, + "type": { + Type: "string", + Description: "Optional link type forwarded to storage drivers.", + }, + }, + Required: []string{"path"}, + }, + }, +} + +func (s *Server) handleToolsList(req request) response { + var params toolsListParams + if len(req.Params) > 0 { + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32602, Message: "invalid tools/list params"}, + } + } + } + + return response{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "tools": openListTools, + }, + } +} diff --git a/server/router.go b/server/router.go index 57d1166ae..3c7b99adb 100644 --- a/server/router.go +++ b/server/router.go @@ -41,6 +41,7 @@ func Init(e *gin.Engine) { } WebDav(g.Group("/dav")) S3(g.Group("/s3")) + MCP(g) downloadLimiter := middlewares.DownloadRateLimiter(stream.ClientDownloadLimit) signCheck := middlewares.Down(sign.Verify)