From 0958d73a14995836a9b97f1269cd2bd9d206319f Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Mon, 18 May 2026 23:20:37 +0800 Subject: [PATCH 01/16] feat(mcp): add initial MCP tools support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 MCP 路由与 session 初始化流程 - 接入 tools/list 与 tools/call,并开放 openlist.fs.list - 补充 MCP 相关协议与路由测试 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp.go | 337 ++++++++++++++++++++++++++++++++++++++++ server/mcp_call.go | 84 ++++++++++ server/mcp_call_test.go | 181 +++++++++++++++++++++ server/mcp_fs_list.go | 198 +++++++++++++++++++++++ server/mcp_test.go | 110 +++++++++++++ server/mcp_tools.go | 80 ++++++++++ server/router.go | 1 + 7 files changed, 991 insertions(+) create mode 100644 server/mcp.go create mode 100644 server/mcp_call.go create mode 100644 server/mcp_call_test.go create mode 100644 server/mcp_fs_list.go create mode 100644 server/mcp_test.go create mode 100644 server/mcp_tools.go diff --git a/server/mcp.go b/server/mcp.go new file mode 100644 index 000000000..05c24b343 --- /dev/null +++ b/server/mcp.go @@ -0,0 +1,337 @@ +package server + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/utils/random" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/middlewares" + "github.com/gin-gonic/gin" +) + +const ( + mcpProtocolVersion = "2025-06-18" + mcpSessionHeader = "Mcp-Session-Id" +) + +type mcpSession struct { + id string + userID uint + initialized bool + createdAt time.Time +} + +type mcpServer struct { + mu sync.RWMutex + sessions map[string]*mcpSession +} + +type mcpRequest struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type mcpResponse struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result any `json:"result,omitempty"` + Error *mcpError `json:"error,omitempty"` +} + +type mcpError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type mcpInitializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ClientInfo map[string]any `json:"clientInfo"` +} + +var openListMCP = &mcpServer{ + sessions: map[string]*mcpSession{}, +} + +func MCP(g *gin.RouterGroup) { + mcp := g.Group("/mcp", middlewares.Auth(false), middlewares.AuthAdmin) + mcp.GET("", openListMCP.handleGet) + mcp.POST("", openListMCP.handlePost) + mcp.DELETE("", openListMCP.handleDelete) +} + +func (s *mcpServer) handleGet(c *gin.Context) { + if !validateMCPOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + c.Status(http.StatusMethodNotAllowed) +} + +func (s *mcpServer) handlePost(c *gin.Context) { + if !validateMCPOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + if !acceptsMCPJSON(c.GetHeader("Accept")) { + c.Status(http.StatusNotAcceptable) + return + } + + body, err := io.ReadAll(io.LimitReader(c.Request.Body, 1<<20)) + if err != nil { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + Error: &mcpError{Code: -32700, Message: "failed to read request body"}, + }) + return + } + + var req mcpRequest + if err := json.Unmarshal(body, &req); err != nil { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + Error: &mcpError{Code: -32700, Message: "parse error"}, + }) + return + } + if req.JSONRPC != "2.0" { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32600, Message: "invalid request"}, + }) + return + } + + if req.Method == "initialize" { + s.handleInitialize(c, req) + return + } + + sessionID := c.GetHeader(mcpSessionHeader) + session, ok := s.getSession(sessionID) + if !ok { + c.Status(http.StatusBadRequest) + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32000, Message: "missing or invalid MCP session"}, + }) + return + } + + user := c.Request.Context().Value(conf.UserKey).(*model.User) + if session.userID != user.ID { + c.Status(http.StatusNotFound) + c.JSON(http.StatusNotFound, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32001, Message: "session not found"}, + }) + return + } + + switch req.Method { + case "ping": + c.JSON(http.StatusOK, mcpResponse{JSONRPC: "2.0", ID: req.ID, Result: map[string]any{}}) + case "notifications/initialized": + s.markSessionInitialized(sessionID) + c.Status(http.StatusAccepted) + case "tools/list": + if !s.sessionInitialized(sessionID) { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32002, Message: "MCP session not initialized"}, + }) + return + } + c.JSON(http.StatusOK, s.handleToolsList(req)) + case "tools/call": + if !s.sessionInitialized(sessionID) { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32002, Message: "MCP session not initialized"}, + }) + return + } + status, resp := s.handleToolsCall(c, req) + c.JSON(status, resp) + default: + c.JSON(http.StatusOK, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32601, Message: fmt.Sprintf("method %q not implemented yet", req.Method)}, + }) + } +} + +func (s *mcpServer) handleInitialize(c *gin.Context, req mcpRequest) { + var params mcpInitializeParams + if len(req.Params) > 0 { + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32602, Message: "invalid initialize params"}, + }) + return + } + } + if params.ProtocolVersion != "" && params.ProtocolVersion != mcpProtocolVersion { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32602, Message: "unsupported protocol version"}, + }) + return + } + + session := s.createSession(c.Request.Context().Value(conf.UserKey).(*model.User).ID) + c.Header(mcpSessionHeader, session.id) + c.JSON(http.StatusOK, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "protocolVersion": mcpProtocolVersion, + "capabilities": map[string]any{ + "tools": map[string]any{ + "listChanged": false, + }, + }, + "serverInfo": map[string]any{ + "name": "OpenList MCP", + "version": conf.Version, + }, + "instructions": "Complete initialization with notifications/initialized, then use tools/list and tools/call. The first available tool is openlist.fs.list.", + }, + }) +} + +func (s *mcpServer) handleDelete(c *gin.Context) { + if !validateMCPOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + + session, ok := s.getSession(c.GetHeader(mcpSessionHeader)) + if !ok { + c.Status(http.StatusNotFound) + return + } + + user := c.Request.Context().Value(conf.UserKey).(*model.User) + if session.userID != user.ID { + c.Status(http.StatusNotFound) + return + } + + s.deleteSession(session.id) + c.Status(http.StatusNoContent) +} + +func (s *mcpServer) createSession(userID uint) *mcpSession { + s.mu.Lock() + defer s.mu.Unlock() + + session := &mcpSession{ + id: random.Token(), + userID: userID, + createdAt: time.Now(), + } + s.sessions[session.id] = session + return session +} + +func (s *mcpServer) getSession(id string) (mcpSession, bool) { + if id == "" { + return mcpSession{}, false + } + s.mu.RLock() + defer s.mu.RUnlock() + session, ok := s.sessions[id] + if !ok || session == nil { + return mcpSession{}, false + } + return *session, true +} + +func (s *mcpServer) markSessionInitialized(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + session, ok := s.sessions[id] + if !ok || session == nil { + return false + } + session.initialized = true + return true +} + +func (s *mcpServer) sessionInitialized(id string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + session, ok := s.sessions[id] + return ok && session != nil && session.initialized +} + +func (s *mcpServer) deleteSession(id string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, id) +} + +func acceptsMCPJSON(accept string) bool { + if accept == "" { + return false + } + return strings.Contains(accept, "application/json") || + strings.Contains(accept, "*/*") +} + +func validateMCPOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Host == "" { + return false + } + if strings.EqualFold(originURL.Host, r.Host) { + return strings.EqualFold(originURL.Scheme, requestScheme(r)) + } + + siteURL := common.GetApiUrlFromRequest(r) + if siteURL == "" { + return false + } + siteParsed, err := url.Parse(siteURL) + if err != nil { + return false + } + if strings.EqualFold(originURL.Host, siteParsed.Host) && strings.EqualFold(originURL.Scheme, siteParsed.Scheme) { + return true + } + return false +} + +func requestScheme(r *http.Request) string { + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + return "https" + } + return "http" +} diff --git a/server/mcp_call.go b/server/mcp_call.go new file mode 100644 index 000000000..76aafd7b9 --- /dev/null +++ b/server/mcp_call.go @@ -0,0 +1,84 @@ +package server + +import ( + "encoding/json" + "net/http" + + "github.com/gin-gonic/gin" +) + +type mcpToolCallParams struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +type mcpToolResultContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +func (s *mcpServer) handleToolsCall(c *gin.Context, req mcpRequest) (int, mcpResponse) { + var params mcpToolCallParams + if len(req.Params) == 0 { + return http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32602, Message: "invalid tools/call params"}, + } + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil || params.Name == "" { + return http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32602, Message: "invalid tools/call params"}, + } + } + + var ( + result any + err *mcpError + ) + switch params.Name { + case "openlist.fs.list": + result, err = s.callFSList(c, params.Arguments) + default: + return http.StatusOK, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32601, Message: "unknown tool"}, + } + } + + if err != nil { + return http.StatusOK, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "content": []mcpToolResultContent{ + {Type: "text", Text: err.Message}, + }, + "isError": true, + }, + } + } + + resultJSON, marshalErr := json.Marshal(result) + if marshalErr != nil { + return http.StatusInternalServerError, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32603, Message: "failed to encode tool result"}, + } + } + + return http.StatusOK, mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "content": []mcpToolResultContent{ + {Type: "text", Text: string(resultJSON)}, + }, + "structuredContent": result, + }, + } +} diff --git a/server/mcp_call_test.go b/server/mcp_call_test.go new file mode 100644 index 000000000..eaac49c52 --- /dev/null +++ b/server/mcp_call_test.go @@ -0,0 +1,181 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" +) + +func TestMCPToolsListRequiresInitializedSession(t *testing.T) { + gin.SetMode(gin.TestMode) + openListMCP.sessions = map[string]*mcpSession{ + "s1": {id: "s1", userID: 1}, + } + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(mcpSessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeMCPResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32002 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestMCPToolsListSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + openListMCP.sessions = map[string]*mcpSession{ + "s2": {id: "s2", userID: 1, initialized: true}, + } + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":2, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(mcpSessionHeader, "s2") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeMCPResponse(t, w) + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + tools, ok := result["tools"].([]any) + if !ok || len(tools) != 1 { + t.Fatalf("unexpected tools payload: %#v", result["tools"]) + } + tool, ok := tools[0].(map[string]any) + if !ok { + t.Fatalf("unexpected tool payload: %#v", tools[0]) + } + if tool["name"] != "openlist.fs.list" { + t.Fatalf("unexpected tool name: got %v", tool["name"]) + } +} + +func TestMCPToolsCallUnknownTool(t *testing.T) { + gin.SetMode(gin.TestMode) + openListMCP.sessions = map[string]*mcpSession{ + "s3": {id: "s3", userID: 1, initialized: true}, + } + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":3, + "method":"tools/call", + "params":{"name":"openlist.fs.unknown","arguments":{"path":"/"}} + }`)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(mcpSessionHeader, "s3") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeMCPResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32601 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestMCPToolsCallInvalidParams(t *testing.T) { + gin.SetMode(gin.TestMode) + openListMCP.sessions = map[string]*mcpSession{ + "s4": {id: "s4", userID: 1, initialized: true}, + } + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":4, + "method":"tools/call", + "params":{"name":"openlist.fs.list","arguments":"bad"} + }`)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(mcpSessionHeader, "s4") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeMCPResponse(t, w) + if resp.Error != nil { + t.Fatalf("expected tool error result, got protocol error: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if isError, ok := result["isError"].(bool); !ok || !isError { + t.Fatalf("unexpected tool error flag: %#v", result["isError"]) + } +} + +func decodeMCPResponse(t *testing.T, w *httptest.ResponseRecorder) mcpResponse { + t.Helper() + + var resp mcpResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + return resp +} diff --git a/server/mcp_fs_list.go b/server/mcp_fs_list.go new file mode 100644 index 000000000..a0ccf3feb --- /dev/null +++ b/server/mcp_fs_list.go @@ -0,0 +1,198 @@ +package server + +import ( + "context" + "encoding/json" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/handles" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type mcpFSListArgs struct { + Path string `json:"path"` + Password string `json:"password"` + Refresh bool `json:"refresh"` + Page int `json:"page"` + PerPage int `json:"per_page"` +} + +type mcpToolCallEnvelope struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +func (s *mcpServer) callFSList(c *gin.Context, raw json.RawMessage) (any, *mcpError) { + args, mcpErr := parseMCPFSListArgs(raw) + if mcpErr != nil { + return nil, mcpErr + } + + user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) + if !ok || user == nil { + return nil, &mcpError{Code: -32603, Message: "missing user context"} + } + if user.IsGuest() && user.Disabled { + return nil, &mcpError{Code: -32001, Message: "guest user is disabled"} + } + + reqPath, err := user.JoinPath(args.Path) + if err != nil { + return nil, &mcpError{Code: -32003, Message: err.Error()} + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + return nil, &mcpError{Code: -32603, Message: err.Error()} + } + if !common.CanAccess(user, meta, reqPath, args.Password) { + return nil, &mcpError{Code: -32003, Message: "password is incorrect or you have no permission"} + } + + write := common.CanWrite(user, meta, reqPath) + writeContentBypass := common.CanWriteContentBypassUserPerms(meta, reqPath) + canWriteContentAtPath := write && (user.CanWriteContent() || writeContentBypass) + if args.Refresh && !canWriteContentAtPath { + return nil, &mcpError{Code: -32003, Message: "refresh without permission"} + } + + ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) + objs, err := fs.List(ctx, reqPath, &fs.ListArgs{ + Refresh: args.Refresh, + WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), + }) + if err != nil { + return nil, &mcpError{Code: -32603, Message: err.Error()} + } + + total, paged := paginateMCPObjs(objs, args.Page, args.PerPage) + return handles.FsListResp{ + Content: toMCPObjResp(paged, reqPath, isEncryptMCP(meta, reqPath)), + Total: int64(total), + Write: write, + WriteContentBypass: writeContentBypass, + Provider: detectMCPProvider(reqPath, paged), + Readme: getMCPReadme(meta, reqPath), + Header: getMCPHeader(meta, reqPath), + }, nil +} + +func parseMCPFSListArgs(raw json.RawMessage) (*mcpFSListArgs, *mcpError) { + args := &mcpFSListArgs{ + Page: 1, + PerPage: model.MaxInt, + } + if len(raw) == 0 || string(raw) == "null" { + return args, nil + } + + if err := json.Unmarshal(raw, args); err == nil { + normalizeMCPFSListArgs(args) + return args, nil + } + + var envelope mcpToolCallEnvelope + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, &mcpError{Code: -32602, Message: "invalid openlist.fs.list arguments"} + } + if len(envelope.Arguments) > 0 { + if err := json.Unmarshal(envelope.Arguments, args); err != nil { + return nil, &mcpError{Code: -32602, Message: "invalid openlist.fs.list arguments"} + } + } + normalizeMCPFSListArgs(args) + return args, nil +} + +func normalizeMCPFSListArgs(args *mcpFSListArgs) { + pageReq := model.PageReq{ + Page: args.Page, + PerPage: args.PerPage, + } + pageReq.Validate() + args.Page = pageReq.Page + args.PerPage = pageReq.PerPage +} + +func paginateMCPObjs(objs []model.Obj, page, perPage int) (int, []model.Obj) { + total := len(objs) + start := (page - 1) * perPage + if start > total { + return total, []model.Obj{} + } + end := start + perPage + if end > total { + end = total + } + return total, objs[start:end] +} + +func toMCPObjResp(objs []model.Obj, parent string, encrypt bool) []handles.ObjResp { + resp := make([]handles.ObjResp, 0, len(objs)) + for _, obj := range objs { + thumb, _ := model.GetThumb(obj) + mountDetails, _ := model.GetStorageDetails(obj) + resp = append(resp, handles.ObjResp{ + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Created: obj.CreateTime(), + Sign: common.Sign(obj, parent, encrypt), + Thumb: thumb, + Type: utils.GetObjType(obj.GetName(), obj.IsDir()), + HashInfoStr: obj.GetHash().String(), + HashInfo: obj.GetHash().Export(), + MountDetails: mountDetails, + }) + } + return resp +} + +func detectMCPProvider(reqPath string, objs []model.Obj) string { + storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + if err == nil && storage != nil { + return storage.Config().Name + } + for _, obj := range objs { + if provider, ok := model.GetProvider(obj); ok && provider != "" { + return provider + } + } + return "unknown" +} + +func getMCPReadme(meta *model.Meta, path string) string { + if meta != nil && common.MetaCoversPath(meta.Path, path, meta.RSub) { + return meta.Readme + } + return "" +} + +func getMCPHeader(meta *model.Meta, path string) string { + if meta != nil && common.MetaCoversPath(meta.Path, path, meta.HeaderSub) { + return meta.Header + } + return "" +} + +func isEncryptMCP(meta *model.Meta, path string) bool { + if common.IsStorageSignEnabled(path) { + return true + } + if meta == nil || meta.Password == "" { + return false + } + if !common.MetaCoversPath(meta.Path, path, meta.PSub) { + return false + } + return true +} diff --git a/server/mcp_test.go b/server/mcp_test.go new file mode 100644 index 000000000..bc61e04de --- /dev/null +++ b/server/mcp_test.go @@ -0,0 +1,110 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" +) + +func TestMCPInitializeCreatesSession(t *testing.T) { + gin.SetMode(gin.TestMode) + openListMCP.sessions = map[string]*mcpSession{} + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2025-06-18", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + if got := w.Header().Get(mcpSessionHeader); got == "" { + t.Fatal("expected session header to be set") + } + + var resp mcpResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != mcpProtocolVersion { + t.Fatalf("unexpected protocol version: got %v want %s", result["protocolVersion"], mcpProtocolVersion) + } +} + +func TestMCPDeleteRemovesSession(t *testing.T) { + gin.SetMode(gin.TestMode) + openListMCP.sessions = map[string]*mcpSession{} + + session := openListMCP.createSession(1) + r := gin.New() + r.DELETE("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handleDelete(c) + }) + + req := httptest.NewRequest(http.MethodDelete, "http://example.com/mcp", nil) + req.Header.Set("Origin", "http://example.com") + req.Header.Set(mcpSessionHeader, session.id) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNoContent) + } + if _, ok := openListMCP.getSession(session.id); ok { + t.Fatal("expected session to be deleted") + } +} + +func TestMCPGetReturnsMethodNotAllowed(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.GET("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + openListMCP.handleGet(c) + }) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/mcp", nil) + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusMethodNotAllowed) + } +} diff --git a/server/mcp_tools.go b/server/mcp_tools.go new file mode 100644 index 000000000..afed29143 --- /dev/null +++ b/server/mcp_tools.go @@ -0,0 +1,80 @@ +package server + +import "encoding/json" + +type mcpTool struct { + Name string `json:"name"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + InputSchema mcpToolInputSchema `json:"inputSchema"` +} + +type mcpToolInputSchema struct { + Type string `json:"type"` + Properties map[string]mcpToolSchemaProperty `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type mcpToolSchemaProperty struct { + Type string `json:"type,omitempty"` + Description string `json:"description,omitempty"` +} + +type mcpToolsListParams struct { + Cursor string `json:"cursor,omitempty"` +} + +var openListMCPTools = []mcpTool{ + { + Name: "openlist.fs.list", + Title: "OpenList FS List", + Description: "List files and directories under a mount path that the current user can access.", + InputSchema: mcpToolInputSchema{ + Type: "object", + Properties: map[string]mcpToolSchemaProperty{ + "path": { + Type: "string", + Description: "Mount path to list, for example \"/\" or \"/movies\".", + }, + "refresh": { + Type: "boolean", + Description: "Refresh the directory listing before returning results.", + }, + "password": { + Type: "string", + Description: "Optional password for protected paths.", + }, + "page": { + Type: "integer", + Description: "1-based page number.", + }, + "per_page": { + Type: "integer", + Description: "Page size.", + }, + }, + Required: []string{"path"}, + }, + }, +} + +func (s *mcpServer) handleToolsList(req mcpRequest) mcpResponse { + var params mcpToolsListParams + if len(req.Params) > 0 { + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &mcpError{Code: -32602, Message: "invalid tools/list params"}, + } + } + } + + return mcpResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "tools": openListMCPTools, + }, + } +} diff --git a/server/router.go b/server/router.go index 57d1166ae..3c7b99adb 100644 --- a/server/router.go +++ b/server/router.go @@ -41,6 +41,7 @@ func Init(e *gin.Engine) { } WebDav(g.Group("/dav")) S3(g.Group("/s3")) + MCP(g) downloadLimiter := middlewares.DownloadRateLimiter(stream.ClientDownloadLimit) signCheck := middlewares.Down(sign.Verify) From 59351e8a2acd03387a1524801e86e8f300bc48c4 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Mon, 18 May 2026 23:27:55 +0800 Subject: [PATCH 02/16] refactor(mcp): move MCP implementation into server/mcp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 MCP 实现与测试迁移到 server/mcp 目录 - 保留 server/mcp.go 作为路由接入包装入口 - 对齐 webdav、s3、ftp、sftp 的目录组织风格 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp.go | 331 +----------------- server/{mcp_call.go => mcp/call.go} | 36 +- server/{mcp_call_test.go => mcp/call_test.go} | 60 ++-- server/{mcp_fs_list.go => mcp/fs_list.go} | 67 ++-- server/mcp/handler.go | 331 ++++++++++++++++++ server/{mcp_test.go => mcp/handler_test.go} | 38 +- server/{mcp_tools.go => mcp/tools.go} | 42 +-- 7 files changed, 442 insertions(+), 463 deletions(-) rename server/{mcp_call.go => mcp/call.go} (57%) rename server/{mcp_call_test.go => mcp/call_test.go} (72%) rename server/{mcp_fs_list.go => mcp/fs_list.go} (68%) create mode 100644 server/mcp/handler.go rename server/{mcp_test.go => mcp/handler_test.go} (70%) rename server/{mcp_tools.go => mcp/tools.go} (56%) diff --git a/server/mcp.go b/server/mcp.go index 05c24b343..9d85c02cc 100644 --- a/server/mcp.go +++ b/server/mcp.go @@ -1,337 +1,10 @@ package server import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/OpenListTeam/OpenList/v4/internal/conf" - "github.com/OpenListTeam/OpenList/v4/internal/model" - "github.com/OpenListTeam/OpenList/v4/pkg/utils/random" - "github.com/OpenListTeam/OpenList/v4/server/common" - "github.com/OpenListTeam/OpenList/v4/server/middlewares" + "github.com/OpenListTeam/OpenList/v4/server/mcp" "github.com/gin-gonic/gin" ) -const ( - mcpProtocolVersion = "2025-06-18" - mcpSessionHeader = "Mcp-Session-Id" -) - -type mcpSession struct { - id string - userID uint - initialized bool - createdAt time.Time -} - -type mcpServer struct { - mu sync.RWMutex - sessions map[string]*mcpSession -} - -type mcpRequest struct { - JSONRPC string `json:"jsonrpc"` - ID any `json:"id,omitempty"` - Method string `json:"method"` - Params json.RawMessage `json:"params,omitempty"` -} - -type mcpResponse struct { - JSONRPC string `json:"jsonrpc"` - ID any `json:"id"` - Result any `json:"result,omitempty"` - Error *mcpError `json:"error,omitempty"` -} - -type mcpError struct { - Code int `json:"code"` - Message string `json:"message"` -} - -type mcpInitializeParams struct { - ProtocolVersion string `json:"protocolVersion"` - Capabilities map[string]any `json:"capabilities"` - ClientInfo map[string]any `json:"clientInfo"` -} - -var openListMCP = &mcpServer{ - sessions: map[string]*mcpSession{}, -} - func MCP(g *gin.RouterGroup) { - mcp := g.Group("/mcp", middlewares.Auth(false), middlewares.AuthAdmin) - mcp.GET("", openListMCP.handleGet) - mcp.POST("", openListMCP.handlePost) - mcp.DELETE("", openListMCP.handleDelete) -} - -func (s *mcpServer) handleGet(c *gin.Context) { - if !validateMCPOrigin(c.Request) { - c.Status(http.StatusForbidden) - return - } - c.Status(http.StatusMethodNotAllowed) -} - -func (s *mcpServer) handlePost(c *gin.Context) { - if !validateMCPOrigin(c.Request) { - c.Status(http.StatusForbidden) - return - } - if !acceptsMCPJSON(c.GetHeader("Accept")) { - c.Status(http.StatusNotAcceptable) - return - } - - body, err := io.ReadAll(io.LimitReader(c.Request.Body, 1<<20)) - if err != nil { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - Error: &mcpError{Code: -32700, Message: "failed to read request body"}, - }) - return - } - - var req mcpRequest - if err := json.Unmarshal(body, &req); err != nil { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - Error: &mcpError{Code: -32700, Message: "parse error"}, - }) - return - } - if req.JSONRPC != "2.0" { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32600, Message: "invalid request"}, - }) - return - } - - if req.Method == "initialize" { - s.handleInitialize(c, req) - return - } - - sessionID := c.GetHeader(mcpSessionHeader) - session, ok := s.getSession(sessionID) - if !ok { - c.Status(http.StatusBadRequest) - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32000, Message: "missing or invalid MCP session"}, - }) - return - } - - user := c.Request.Context().Value(conf.UserKey).(*model.User) - if session.userID != user.ID { - c.Status(http.StatusNotFound) - c.JSON(http.StatusNotFound, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32001, Message: "session not found"}, - }) - return - } - - switch req.Method { - case "ping": - c.JSON(http.StatusOK, mcpResponse{JSONRPC: "2.0", ID: req.ID, Result: map[string]any{}}) - case "notifications/initialized": - s.markSessionInitialized(sessionID) - c.Status(http.StatusAccepted) - case "tools/list": - if !s.sessionInitialized(sessionID) { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32002, Message: "MCP session not initialized"}, - }) - return - } - c.JSON(http.StatusOK, s.handleToolsList(req)) - case "tools/call": - if !s.sessionInitialized(sessionID) { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32002, Message: "MCP session not initialized"}, - }) - return - } - status, resp := s.handleToolsCall(c, req) - c.JSON(status, resp) - default: - c.JSON(http.StatusOK, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32601, Message: fmt.Sprintf("method %q not implemented yet", req.Method)}, - }) - } -} - -func (s *mcpServer) handleInitialize(c *gin.Context, req mcpRequest) { - var params mcpInitializeParams - if len(req.Params) > 0 { - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32602, Message: "invalid initialize params"}, - }) - return - } - } - if params.ProtocolVersion != "" && params.ProtocolVersion != mcpProtocolVersion { - c.JSON(http.StatusBadRequest, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Error: &mcpError{Code: -32602, Message: "unsupported protocol version"}, - }) - return - } - - session := s.createSession(c.Request.Context().Value(conf.UserKey).(*model.User).ID) - c.Header(mcpSessionHeader, session.id) - c.JSON(http.StatusOK, mcpResponse{ - JSONRPC: "2.0", - ID: req.ID, - Result: map[string]any{ - "protocolVersion": mcpProtocolVersion, - "capabilities": map[string]any{ - "tools": map[string]any{ - "listChanged": false, - }, - }, - "serverInfo": map[string]any{ - "name": "OpenList MCP", - "version": conf.Version, - }, - "instructions": "Complete initialization with notifications/initialized, then use tools/list and tools/call. The first available tool is openlist.fs.list.", - }, - }) -} - -func (s *mcpServer) handleDelete(c *gin.Context) { - if !validateMCPOrigin(c.Request) { - c.Status(http.StatusForbidden) - return - } - - session, ok := s.getSession(c.GetHeader(mcpSessionHeader)) - if !ok { - c.Status(http.StatusNotFound) - return - } - - user := c.Request.Context().Value(conf.UserKey).(*model.User) - if session.userID != user.ID { - c.Status(http.StatusNotFound) - return - } - - s.deleteSession(session.id) - c.Status(http.StatusNoContent) -} - -func (s *mcpServer) createSession(userID uint) *mcpSession { - s.mu.Lock() - defer s.mu.Unlock() - - session := &mcpSession{ - id: random.Token(), - userID: userID, - createdAt: time.Now(), - } - s.sessions[session.id] = session - return session -} - -func (s *mcpServer) getSession(id string) (mcpSession, bool) { - if id == "" { - return mcpSession{}, false - } - s.mu.RLock() - defer s.mu.RUnlock() - session, ok := s.sessions[id] - if !ok || session == nil { - return mcpSession{}, false - } - return *session, true -} - -func (s *mcpServer) markSessionInitialized(id string) bool { - s.mu.Lock() - defer s.mu.Unlock() - session, ok := s.sessions[id] - if !ok || session == nil { - return false - } - session.initialized = true - return true -} - -func (s *mcpServer) sessionInitialized(id string) bool { - s.mu.RLock() - defer s.mu.RUnlock() - session, ok := s.sessions[id] - return ok && session != nil && session.initialized -} - -func (s *mcpServer) deleteSession(id string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.sessions, id) -} - -func acceptsMCPJSON(accept string) bool { - if accept == "" { - return false - } - return strings.Contains(accept, "application/json") || - strings.Contains(accept, "*/*") -} - -func validateMCPOrigin(r *http.Request) bool { - origin := r.Header.Get("Origin") - if origin == "" { - return true - } - - originURL, err := url.Parse(origin) - if err != nil || originURL.Host == "" { - return false - } - if strings.EqualFold(originURL.Host, r.Host) { - return strings.EqualFold(originURL.Scheme, requestScheme(r)) - } - - siteURL := common.GetApiUrlFromRequest(r) - if siteURL == "" { - return false - } - siteParsed, err := url.Parse(siteURL) - if err != nil { - return false - } - if strings.EqualFold(originURL.Host, siteParsed.Host) && strings.EqualFold(originURL.Scheme, siteParsed.Scheme) { - return true - } - return false -} - -func requestScheme(r *http.Request) string { - if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { - return "https" - } - return "http" + mcp.Register(g) } diff --git a/server/mcp_call.go b/server/mcp/call.go similarity index 57% rename from server/mcp_call.go rename to server/mcp/call.go index 76aafd7b9..7734e5786 100644 --- a/server/mcp_call.go +++ b/server/mcp/call.go @@ -1,4 +1,4 @@ -package server +package mcp import ( "encoding/json" @@ -7,54 +7,54 @@ import ( "github.com/gin-gonic/gin" ) -type mcpToolCallParams struct { +type toolCallParams struct { Name string `json:"name"` Arguments json.RawMessage `json:"arguments"` } -type mcpToolResultContent struct { +type toolResultContent struct { Type string `json:"type"` Text string `json:"text"` } -func (s *mcpServer) handleToolsCall(c *gin.Context, req mcpRequest) (int, mcpResponse) { - var params mcpToolCallParams +func (s *Server) handleToolsCall(c *gin.Context, req request) (int, response) { + var params toolCallParams if len(req.Params) == 0 { - return http.StatusBadRequest, mcpResponse{ + return http.StatusBadRequest, response{ JSONRPC: "2.0", ID: req.ID, - Error: &mcpError{Code: -32602, Message: "invalid tools/call params"}, + Error: &rpcError{Code: -32602, Message: "invalid tools/call params"}, } } if err := json.Unmarshal(req.Params, ¶ms); err != nil || params.Name == "" { - return http.StatusBadRequest, mcpResponse{ + return http.StatusBadRequest, response{ JSONRPC: "2.0", ID: req.ID, - Error: &mcpError{Code: -32602, Message: "invalid tools/call params"}, + Error: &rpcError{Code: -32602, Message: "invalid tools/call params"}, } } var ( result any - err *mcpError + err *rpcError ) switch params.Name { case "openlist.fs.list": result, err = s.callFSList(c, params.Arguments) default: - return http.StatusOK, mcpResponse{ + return http.StatusOK, response{ JSONRPC: "2.0", ID: req.ID, - Error: &mcpError{Code: -32601, Message: "unknown tool"}, + Error: &rpcError{Code: -32601, Message: "unknown tool"}, } } if err != nil { - return http.StatusOK, mcpResponse{ + return http.StatusOK, response{ JSONRPC: "2.0", ID: req.ID, Result: map[string]any{ - "content": []mcpToolResultContent{ + "content": []toolResultContent{ {Type: "text", Text: err.Message}, }, "isError": true, @@ -64,18 +64,18 @@ func (s *mcpServer) handleToolsCall(c *gin.Context, req mcpRequest) (int, mcpRes resultJSON, marshalErr := json.Marshal(result) if marshalErr != nil { - return http.StatusInternalServerError, mcpResponse{ + return http.StatusInternalServerError, response{ JSONRPC: "2.0", ID: req.ID, - Error: &mcpError{Code: -32603, Message: "failed to encode tool result"}, + Error: &rpcError{Code: -32603, Message: "failed to encode tool result"}, } } - return http.StatusOK, mcpResponse{ + return http.StatusOK, response{ JSONRPC: "2.0", ID: req.ID, Result: map[string]any{ - "content": []mcpToolResultContent{ + "content": []toolResultContent{ {Type: "text", Text: string(resultJSON)}, }, "structuredContent": result, diff --git a/server/mcp_call_test.go b/server/mcp/call_test.go similarity index 72% rename from server/mcp_call_test.go rename to server/mcp/call_test.go index eaac49c52..43a501524 100644 --- a/server/mcp_call_test.go +++ b/server/mcp/call_test.go @@ -1,4 +1,4 @@ -package server +package mcp import ( "encoding/json" @@ -13,16 +13,16 @@ import ( "github.com/gin-gonic/gin" ) -func TestMCPToolsListRequiresInitializedSession(t *testing.T) { +func TestToolsListRequiresInitializedSession(t *testing.T) { gin.SetMode(gin.TestMode) - openListMCP.sessions = map[string]*mcpSession{ + defaultServer.sessions = map[string]*session{ "s1": {id: "s1", userID: 1}, } r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handlePost(c) + defaultServer.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -32,7 +32,7 @@ func TestMCPToolsListRequiresInitializedSession(t *testing.T) { }`)) req.Header.Set("Accept", "application/json") req.Header.Set("Origin", "http://example.com") - req.Header.Set(mcpSessionHeader, "s1") + req.Header.Set(SessionHeader, "s1") w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -40,22 +40,22 @@ func TestMCPToolsListRequiresInitializedSession(t *testing.T) { if w.Code != http.StatusBadRequest { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) } - resp := decodeMCPResponse(t, w) + resp := decodeResponse(t, w) if resp.Error == nil || resp.Error.Code != -32002 { t.Fatalf("unexpected error response: %+v", resp.Error) } } -func TestMCPToolsListSuccess(t *testing.T) { +func TestToolsListSuccess(t *testing.T) { gin.SetMode(gin.TestMode) - openListMCP.sessions = map[string]*mcpSession{ + defaultServer.sessions = map[string]*session{ "s2": {id: "s2", userID: 1, initialized: true}, } r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handlePost(c) + defaultServer.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -65,7 +65,7 @@ func TestMCPToolsListSuccess(t *testing.T) { }`)) req.Header.Set("Accept", "application/json") req.Header.Set("Origin", "http://example.com") - req.Header.Set(mcpSessionHeader, "s2") + req.Header.Set(SessionHeader, "s2") w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -73,7 +73,7 @@ func TestMCPToolsListSuccess(t *testing.T) { if w.Code != http.StatusOK { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) } - resp := decodeMCPResponse(t, w) + resp := decodeResponse(t, w) if resp.Error != nil { t.Fatalf("unexpected error response: %+v", resp.Error) } @@ -86,25 +86,18 @@ func TestMCPToolsListSuccess(t *testing.T) { if !ok || len(tools) != 1 { t.Fatalf("unexpected tools payload: %#v", result["tools"]) } - tool, ok := tools[0].(map[string]any) - if !ok { - t.Fatalf("unexpected tool payload: %#v", tools[0]) - } - if tool["name"] != "openlist.fs.list" { - t.Fatalf("unexpected tool name: got %v", tool["name"]) - } } -func TestMCPToolsCallUnknownTool(t *testing.T) { +func TestToolsCallUnknownTool(t *testing.T) { gin.SetMode(gin.TestMode) - openListMCP.sessions = map[string]*mcpSession{ + defaultServer.sessions = map[string]*session{ "s3": {id: "s3", userID: 1, initialized: true}, } r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handlePost(c) + defaultServer.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -115,7 +108,7 @@ func TestMCPToolsCallUnknownTool(t *testing.T) { }`)) req.Header.Set("Accept", "application/json") req.Header.Set("Origin", "http://example.com") - req.Header.Set(mcpSessionHeader, "s3") + req.Header.Set(SessionHeader, "s3") w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -123,22 +116,22 @@ func TestMCPToolsCallUnknownTool(t *testing.T) { if w.Code != http.StatusOK { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) } - resp := decodeMCPResponse(t, w) + resp := decodeResponse(t, w) if resp.Error == nil || resp.Error.Code != -32601 { t.Fatalf("unexpected error response: %+v", resp.Error) } } -func TestMCPToolsCallInvalidParams(t *testing.T) { +func TestToolsCallInvalidParams(t *testing.T) { gin.SetMode(gin.TestMode) - openListMCP.sessions = map[string]*mcpSession{ + defaultServer.sessions = map[string]*session{ "s4": {id: "s4", userID: 1, initialized: true}, } r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handlePost(c) + defaultServer.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -149,7 +142,7 @@ func TestMCPToolsCallInvalidParams(t *testing.T) { }`)) req.Header.Set("Accept", "application/json") req.Header.Set("Origin", "http://example.com") - req.Header.Set(mcpSessionHeader, "s4") + req.Header.Set(SessionHeader, "s4") w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -157,23 +150,16 @@ func TestMCPToolsCallInvalidParams(t *testing.T) { if w.Code != http.StatusOK { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) } - resp := decodeMCPResponse(t, w) + resp := decodeResponse(t, w) if resp.Error != nil { t.Fatalf("expected tool error result, got protocol error: %+v", resp.Error) } - result, ok := resp.Result.(map[string]any) - if !ok { - t.Fatalf("unexpected result type: %T", resp.Result) - } - if isError, ok := result["isError"].(bool); !ok || !isError { - t.Fatalf("unexpected tool error flag: %#v", result["isError"]) - } } -func decodeMCPResponse(t *testing.T, w *httptest.ResponseRecorder) mcpResponse { +func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) response { t.Helper() - var resp mcpResponse + var resp response if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("failed to decode response: %v", err) } diff --git a/server/mcp_fs_list.go b/server/mcp/fs_list.go similarity index 68% rename from server/mcp_fs_list.go rename to server/mcp/fs_list.go index a0ccf3feb..7410bb14e 100644 --- a/server/mcp_fs_list.go +++ b/server/mcp/fs_list.go @@ -1,4 +1,4 @@ -package server +package mcp import ( "context" @@ -17,7 +17,7 @@ import ( "github.com/pkg/errors" ) -type mcpFSListArgs struct { +type fsListArgs struct { Path string `json:"path"` Password string `json:"password"` Refresh bool `json:"refresh"` @@ -25,43 +25,43 @@ type mcpFSListArgs struct { PerPage int `json:"per_page"` } -type mcpToolCallEnvelope struct { +type toolCallEnvelope struct { Name string `json:"name"` Arguments json.RawMessage `json:"arguments"` } -func (s *mcpServer) callFSList(c *gin.Context, raw json.RawMessage) (any, *mcpError) { - args, mcpErr := parseMCPFSListArgs(raw) +func (s *Server) callFSList(c *gin.Context, raw json.RawMessage) (any, *rpcError) { + args, mcpErr := parseFSListArgs(raw) if mcpErr != nil { return nil, mcpErr } user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) if !ok || user == nil { - return nil, &mcpError{Code: -32603, Message: "missing user context"} + return nil, &rpcError{Code: -32603, Message: "missing user context"} } if user.IsGuest() && user.Disabled { - return nil, &mcpError{Code: -32001, Message: "guest user is disabled"} + return nil, &rpcError{Code: -32001, Message: "guest user is disabled"} } reqPath, err := user.JoinPath(args.Path) if err != nil { - return nil, &mcpError{Code: -32003, Message: err.Error()} + return nil, &rpcError{Code: -32003, Message: err.Error()} } meta, err := op.GetNearestMeta(reqPath) if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { - return nil, &mcpError{Code: -32603, Message: err.Error()} + return nil, &rpcError{Code: -32603, Message: err.Error()} } if !common.CanAccess(user, meta, reqPath, args.Password) { - return nil, &mcpError{Code: -32003, Message: "password is incorrect or you have no permission"} + return nil, &rpcError{Code: -32003, Message: "password is incorrect or you have no permission"} } write := common.CanWrite(user, meta, reqPath) writeContentBypass := common.CanWriteContentBypassUserPerms(meta, reqPath) canWriteContentAtPath := write && (user.CanWriteContent() || writeContentBypass) if args.Refresh && !canWriteContentAtPath { - return nil, &mcpError{Code: -32003, Message: "refresh without permission"} + return nil, &rpcError{Code: -32003, Message: "refresh without permission"} } ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) @@ -70,23 +70,23 @@ func (s *mcpServer) callFSList(c *gin.Context, raw json.RawMessage) (any, *mcpEr WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), }) if err != nil { - return nil, &mcpError{Code: -32603, Message: err.Error()} + return nil, &rpcError{Code: -32603, Message: err.Error()} } - total, paged := paginateMCPObjs(objs, args.Page, args.PerPage) + total, paged := paginateObjs(objs, args.Page, args.PerPage) return handles.FsListResp{ - Content: toMCPObjResp(paged, reqPath, isEncryptMCP(meta, reqPath)), + Content: toObjResp(paged, reqPath, isEncrypt(meta, reqPath)), Total: int64(total), Write: write, WriteContentBypass: writeContentBypass, - Provider: detectMCPProvider(reqPath, paged), - Readme: getMCPReadme(meta, reqPath), - Header: getMCPHeader(meta, reqPath), + Provider: detectProvider(reqPath, paged), + Readme: getReadme(meta, reqPath), + Header: getHeader(meta, reqPath), }, nil } -func parseMCPFSListArgs(raw json.RawMessage) (*mcpFSListArgs, *mcpError) { - args := &mcpFSListArgs{ +func parseFSListArgs(raw json.RawMessage) (*fsListArgs, *rpcError) { + args := &fsListArgs{ Page: 1, PerPage: model.MaxInt, } @@ -95,24 +95,24 @@ func parseMCPFSListArgs(raw json.RawMessage) (*mcpFSListArgs, *mcpError) { } if err := json.Unmarshal(raw, args); err == nil { - normalizeMCPFSListArgs(args) + normalizeFSListArgs(args) return args, nil } - var envelope mcpToolCallEnvelope + var envelope toolCallEnvelope if err := json.Unmarshal(raw, &envelope); err != nil { - return nil, &mcpError{Code: -32602, Message: "invalid openlist.fs.list arguments"} + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} } if len(envelope.Arguments) > 0 { if err := json.Unmarshal(envelope.Arguments, args); err != nil { - return nil, &mcpError{Code: -32602, Message: "invalid openlist.fs.list arguments"} + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} } } - normalizeMCPFSListArgs(args) + normalizeFSListArgs(args) return args, nil } -func normalizeMCPFSListArgs(args *mcpFSListArgs) { +func normalizeFSListArgs(args *fsListArgs) { pageReq := model.PageReq{ Page: args.Page, PerPage: args.PerPage, @@ -122,7 +122,7 @@ func normalizeMCPFSListArgs(args *mcpFSListArgs) { args.PerPage = pageReq.PerPage } -func paginateMCPObjs(objs []model.Obj, page, perPage int) (int, []model.Obj) { +func paginateObjs(objs []model.Obj, page, perPage int) (int, []model.Obj) { total := len(objs) start := (page - 1) * perPage if start > total { @@ -135,7 +135,7 @@ func paginateMCPObjs(objs []model.Obj, page, perPage int) (int, []model.Obj) { return total, objs[start:end] } -func toMCPObjResp(objs []model.Obj, parent string, encrypt bool) []handles.ObjResp { +func toObjResp(objs []model.Obj, parent string, encrypt bool) []handles.ObjResp { resp := make([]handles.ObjResp, 0, len(objs)) for _, obj := range objs { thumb, _ := model.GetThumb(obj) @@ -157,7 +157,7 @@ func toMCPObjResp(objs []model.Obj, parent string, encrypt bool) []handles.ObjRe return resp } -func detectMCPProvider(reqPath string, objs []model.Obj) string { +func detectProvider(reqPath string, objs []model.Obj) string { storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) if err == nil && storage != nil { return storage.Config().Name @@ -170,29 +170,26 @@ func detectMCPProvider(reqPath string, objs []model.Obj) string { return "unknown" } -func getMCPReadme(meta *model.Meta, path string) string { +func getReadme(meta *model.Meta, path string) string { if meta != nil && common.MetaCoversPath(meta.Path, path, meta.RSub) { return meta.Readme } return "" } -func getMCPHeader(meta *model.Meta, path string) string { +func getHeader(meta *model.Meta, path string) string { if meta != nil && common.MetaCoversPath(meta.Path, path, meta.HeaderSub) { return meta.Header } return "" } -func isEncryptMCP(meta *model.Meta, path string) bool { +func isEncrypt(meta *model.Meta, path string) bool { if common.IsStorageSignEnabled(path) { return true } if meta == nil || meta.Password == "" { return false } - if !common.MetaCoversPath(meta.Path, path, meta.PSub) { - return false - } - return true + return common.MetaCoversPath(meta.Path, path, meta.PSub) } diff --git a/server/mcp/handler.go b/server/mcp/handler.go new file mode 100644 index 000000000..95c23130a --- /dev/null +++ b/server/mcp/handler.go @@ -0,0 +1,331 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/utils/random" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/middlewares" + "github.com/gin-gonic/gin" +) + +const ( + ProtocolVersion = "2025-06-18" + SessionHeader = "Mcp-Session-Id" +) + +type session struct { + id string + userID uint + initialized bool + createdAt time.Time +} + +type Server struct { + mu sync.RWMutex + sessions map[string]*session +} + +type request struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type response struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result any `json:"result,omitempty"` + Error *rpcError `json:"error,omitempty"` +} + +type rpcError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type initializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ClientInfo map[string]any `json:"clientInfo"` +} + +var defaultServer = &Server{ + sessions: map[string]*session{}, +} + +func Register(g *gin.RouterGroup) { + mcpGroup := g.Group("/mcp", middlewares.Auth(false), middlewares.AuthAdmin) + mcpGroup.GET("", defaultServer.handleGet) + mcpGroup.POST("", defaultServer.handlePost) + mcpGroup.DELETE("", defaultServer.handleDelete) +} + +func (s *Server) handleGet(c *gin.Context) { + if !validateOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + c.Status(http.StatusMethodNotAllowed) +} + +func (s *Server) handlePost(c *gin.Context) { + if !validateOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + if !acceptsJSON(c.GetHeader("Accept")) { + c.Status(http.StatusNotAcceptable) + return + } + + body, err := io.ReadAll(io.LimitReader(c.Request.Body, 1<<20)) + if err != nil { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32700, Message: "failed to read request body"}, + }) + return + } + + var req request + if err := json.Unmarshal(body, &req); err != nil { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32700, Message: "parse error"}, + }) + return + } + if req.JSONRPC != "2.0" { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32600, Message: "invalid request"}, + }) + return + } + + if req.Method == "initialize" { + s.handleInitialize(c, req) + return + } + + sessionID := c.GetHeader(SessionHeader) + currentSession, ok := s.getSession(sessionID) + if !ok { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32000, Message: "missing or invalid MCP session"}, + }) + return + } + + user := c.Request.Context().Value(conf.UserKey).(*model.User) + if currentSession.userID != user.ID { + c.JSON(http.StatusNotFound, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32001, Message: "session not found"}, + }) + return + } + + switch req.Method { + case "ping": + c.JSON(http.StatusOK, response{JSONRPC: "2.0", ID: req.ID, Result: map[string]any{}}) + case "notifications/initialized": + s.markSessionInitialized(sessionID) + c.Status(http.StatusAccepted) + case "tools/list": + if !s.sessionInitialized(sessionID) { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32002, Message: "MCP session not initialized"}, + }) + return + } + c.JSON(http.StatusOK, s.handleToolsList(req)) + case "tools/call": + if !s.sessionInitialized(sessionID) { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32002, Message: "MCP session not initialized"}, + }) + return + } + status, resp := s.handleToolsCall(c, req) + c.JSON(status, resp) + default: + c.JSON(http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32601, Message: fmt.Sprintf("method %q not implemented yet", req.Method)}, + }) + } +} + +func (s *Server) handleInitialize(c *gin.Context, req request) { + var params initializeParams + if len(req.Params) > 0 { + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32602, Message: "invalid initialize params"}, + }) + return + } + } + if params.ProtocolVersion != "" && params.ProtocolVersion != ProtocolVersion { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32602, Message: "unsupported protocol version"}, + }) + return + } + + currentSession := s.createSession(c.Request.Context().Value(conf.UserKey).(*model.User).ID) + c.Header(SessionHeader, currentSession.id) + c.JSON(http.StatusOK, response{ + JSONRPC: "2.0", + ID: req.ID, + Result: map[string]any{ + "protocolVersion": ProtocolVersion, + "capabilities": map[string]any{ + "tools": map[string]any{ + "listChanged": false, + }, + }, + "serverInfo": map[string]any{ + "name": "OpenList MCP", + "version": conf.Version, + }, + "instructions": "Complete initialization with notifications/initialized, then use tools/list and tools/call. The first available tool is openlist.fs.list.", + }, + }) +} + +func (s *Server) handleDelete(c *gin.Context) { + if !validateOrigin(c.Request) { + c.Status(http.StatusForbidden) + return + } + + currentSession, ok := s.getSession(c.GetHeader(SessionHeader)) + if !ok { + c.Status(http.StatusNotFound) + return + } + + user := c.Request.Context().Value(conf.UserKey).(*model.User) + if currentSession.userID != user.ID { + c.Status(http.StatusNotFound) + return + } + + s.deleteSession(currentSession.id) + c.Status(http.StatusNoContent) +} + +func (s *Server) createSession(userID uint) *session { + s.mu.Lock() + defer s.mu.Unlock() + + currentSession := &session{ + id: random.Token(), + userID: userID, + createdAt: time.Now(), + } + s.sessions[currentSession.id] = currentSession + return currentSession +} + +func (s *Server) getSession(id string) (session, bool) { + if id == "" { + return session{}, false + } + s.mu.RLock() + defer s.mu.RUnlock() + currentSession, ok := s.sessions[id] + if !ok || currentSession == nil { + return session{}, false + } + return *currentSession, true +} + +func (s *Server) markSessionInitialized(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + currentSession, ok := s.sessions[id] + if !ok || currentSession == nil { + return false + } + currentSession.initialized = true + return true +} + +func (s *Server) sessionInitialized(id string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + currentSession, ok := s.sessions[id] + return ok && currentSession != nil && currentSession.initialized +} + +func (s *Server) deleteSession(id string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, id) +} + +func acceptsJSON(accept string) bool { + if accept == "" { + return false + } + return strings.Contains(accept, "application/json") || strings.Contains(accept, "*/*") +} + +func validateOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Host == "" { + return false + } + if strings.EqualFold(originURL.Host, r.Host) { + return strings.EqualFold(originURL.Scheme, requestScheme(r)) + } + + siteURL := common.GetApiUrlFromRequest(r) + if siteURL == "" { + return false + } + siteParsed, err := url.Parse(siteURL) + if err != nil { + return false + } + return strings.EqualFold(originURL.Host, siteParsed.Host) && strings.EqualFold(originURL.Scheme, siteParsed.Scheme) +} + +func requestScheme(r *http.Request) string { + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + return "https" + } + return "http" +} diff --git a/server/mcp_test.go b/server/mcp/handler_test.go similarity index 70% rename from server/mcp_test.go rename to server/mcp/handler_test.go index bc61e04de..3ab64de20 100644 --- a/server/mcp_test.go +++ b/server/mcp/handler_test.go @@ -1,4 +1,4 @@ -package server +package mcp import ( "encoding/json" @@ -13,14 +13,14 @@ import ( "github.com/gin-gonic/gin" ) -func TestMCPInitializeCreatesSession(t *testing.T) { +func TestInitializeCreatesSession(t *testing.T) { gin.SetMode(gin.TestMode) - openListMCP.sessions = map[string]*mcpSession{} + defaultServer.sessions = map[string]*session{} r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handlePost(c) + defaultServer.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -33,7 +33,7 @@ func TestMCPInitializeCreatesSession(t *testing.T) { "clientInfo":{"name":"test-client","version":"1.0.0"} } }`)) - req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Accept", "application/json") req.Header.Set("Origin", "http://example.com") w := httptest.NewRecorder() @@ -42,41 +42,33 @@ func TestMCPInitializeCreatesSession(t *testing.T) { if w.Code != http.StatusOK { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) } - if got := w.Header().Get(mcpSessionHeader); got == "" { + if got := w.Header().Get(SessionHeader); got == "" { t.Fatal("expected session header to be set") } - var resp mcpResponse + var resp response if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("failed to decode response: %v", err) } if resp.Error != nil { t.Fatalf("unexpected error response: %+v", resp.Error) } - - result, ok := resp.Result.(map[string]any) - if !ok { - t.Fatalf("unexpected result type: %T", resp.Result) - } - if result["protocolVersion"] != mcpProtocolVersion { - t.Fatalf("unexpected protocol version: got %v want %s", result["protocolVersion"], mcpProtocolVersion) - } } -func TestMCPDeleteRemovesSession(t *testing.T) { +func TestDeleteRemovesSession(t *testing.T) { gin.SetMode(gin.TestMode) - openListMCP.sessions = map[string]*mcpSession{} + defaultServer.sessions = map[string]*session{} - session := openListMCP.createSession(1) + currentSession := defaultServer.createSession(1) r := gin.New() r.DELETE("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handleDelete(c) + defaultServer.handleDelete(c) }) req := httptest.NewRequest(http.MethodDelete, "http://example.com/mcp", nil) req.Header.Set("Origin", "http://example.com") - req.Header.Set(mcpSessionHeader, session.id) + req.Header.Set(SessionHeader, currentSession.id) w := httptest.NewRecorder() r.ServeHTTP(w, req) @@ -84,18 +76,18 @@ func TestMCPDeleteRemovesSession(t *testing.T) { if w.Code != http.StatusNoContent { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNoContent) } - if _, ok := openListMCP.getSession(session.id); ok { + if _, ok := defaultServer.getSession(currentSession.id); ok { t.Fatal("expected session to be deleted") } } -func TestMCPGetReturnsMethodNotAllowed(t *testing.T) { +func TestGetReturnsMethodNotAllowed(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() r.GET("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - openListMCP.handleGet(c) + defaultServer.handleGet(c) }) req := httptest.NewRequest(http.MethodGet, "http://example.com/mcp", nil) diff --git a/server/mcp_tools.go b/server/mcp/tools.go similarity index 56% rename from server/mcp_tools.go rename to server/mcp/tools.go index afed29143..22097fefd 100644 --- a/server/mcp_tools.go +++ b/server/mcp/tools.go @@ -1,37 +1,37 @@ -package server +package mcp import "encoding/json" -type mcpTool struct { - Name string `json:"name"` - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - InputSchema mcpToolInputSchema `json:"inputSchema"` +type tool struct { + Name string `json:"name"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + InputSchema toolInputSchema `json:"inputSchema"` } -type mcpToolInputSchema struct { - Type string `json:"type"` - Properties map[string]mcpToolSchemaProperty `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` +type toolInputSchema struct { + Type string `json:"type"` + Properties map[string]schemaProperty `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` } -type mcpToolSchemaProperty struct { +type schemaProperty struct { Type string `json:"type,omitempty"` Description string `json:"description,omitempty"` } -type mcpToolsListParams struct { +type toolsListParams struct { Cursor string `json:"cursor,omitempty"` } -var openListMCPTools = []mcpTool{ +var openListTools = []tool{ { Name: "openlist.fs.list", Title: "OpenList FS List", Description: "List files and directories under a mount path that the current user can access.", - InputSchema: mcpToolInputSchema{ + InputSchema: toolInputSchema{ Type: "object", - Properties: map[string]mcpToolSchemaProperty{ + Properties: map[string]schemaProperty{ "path": { Type: "string", Description: "Mount path to list, for example \"/\" or \"/movies\".", @@ -58,23 +58,23 @@ var openListMCPTools = []mcpTool{ }, } -func (s *mcpServer) handleToolsList(req mcpRequest) mcpResponse { - var params mcpToolsListParams +func (s *Server) handleToolsList(req request) response { + var params toolsListParams if len(req.Params) > 0 { if err := json.Unmarshal(req.Params, ¶ms); err != nil { - return mcpResponse{ + return response{ JSONRPC: "2.0", ID: req.ID, - Error: &mcpError{Code: -32602, Message: "invalid tools/list params"}, + Error: &rpcError{Code: -32602, Message: "invalid tools/list params"}, } } } - return mcpResponse{ + return response{ JSONRPC: "2.0", ID: req.ID, Result: map[string]any{ - "tools": openListMCPTools, + "tools": openListTools, }, } } From 06e1e9bba96ca711da8bcbc5b3294ab4dd16a491 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Tue, 19 May 2026 00:04:27 +0800 Subject: [PATCH 03/16] fix(mcp): address endpoint review issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 收紧 Streamable HTTP Accept 头校验 - 为 MCP session 增加过期清理和数量上限 - 简化 fs.list 参数解析并修复分页边界 - 使用独立 Server 实例隔离 MCP 测试状态 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/call_test.go | 32 +++++----- server/mcp/fs_list.go | 50 ++++++---------- server/mcp/handler.go | 117 +++++++++++++++++++++++++++++++++---- server/mcp/handler_test.go | 31 +++++++--- 4 files changed, 163 insertions(+), 67 deletions(-) diff --git a/server/mcp/call_test.go b/server/mcp/call_test.go index 43a501524..8c1661da4 100644 --- a/server/mcp/call_test.go +++ b/server/mcp/call_test.go @@ -15,14 +15,14 @@ import ( func TestToolsListRequiresInitializedSession(t *testing.T) { gin.SetMode(gin.TestMode) - defaultServer.sessions = map[string]*session{ + srv := newTestServer(map[string]*session{ "s1": {id: "s1", userID: 1}, - } + }) r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - defaultServer.handlePost(c) + srv.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -30,7 +30,7 @@ func TestToolsListRequiresInitializedSession(t *testing.T) { "id":1, "method":"tools/list" }`)) - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") req.Header.Set(SessionHeader, "s1") @@ -48,14 +48,14 @@ func TestToolsListRequiresInitializedSession(t *testing.T) { func TestToolsListSuccess(t *testing.T) { gin.SetMode(gin.TestMode) - defaultServer.sessions = map[string]*session{ + srv := newTestServer(map[string]*session{ "s2": {id: "s2", userID: 1, initialized: true}, - } + }) r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - defaultServer.handlePost(c) + srv.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -63,7 +63,7 @@ func TestToolsListSuccess(t *testing.T) { "id":2, "method":"tools/list" }`)) - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") req.Header.Set(SessionHeader, "s2") @@ -90,14 +90,14 @@ func TestToolsListSuccess(t *testing.T) { func TestToolsCallUnknownTool(t *testing.T) { gin.SetMode(gin.TestMode) - defaultServer.sessions = map[string]*session{ + srv := newTestServer(map[string]*session{ "s3": {id: "s3", userID: 1, initialized: true}, - } + }) r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - defaultServer.handlePost(c) + srv.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -106,7 +106,7 @@ func TestToolsCallUnknownTool(t *testing.T) { "method":"tools/call", "params":{"name":"openlist.fs.unknown","arguments":{"path":"/"}} }`)) - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") req.Header.Set(SessionHeader, "s3") @@ -124,14 +124,14 @@ func TestToolsCallUnknownTool(t *testing.T) { func TestToolsCallInvalidParams(t *testing.T) { gin.SetMode(gin.TestMode) - defaultServer.sessions = map[string]*session{ + srv := newTestServer(map[string]*session{ "s4": {id: "s4", userID: 1, initialized: true}, - } + }) r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - defaultServer.handlePost(c) + srv.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -140,7 +140,7 @@ func TestToolsCallInvalidParams(t *testing.T) { "method":"tools/call", "params":{"name":"openlist.fs.list","arguments":"bad"} }`)) - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") req.Header.Set(SessionHeader, "s4") diff --git a/server/mcp/fs_list.go b/server/mcp/fs_list.go index 7410bb14e..6e2c7c72e 100644 --- a/server/mcp/fs_list.go +++ b/server/mcp/fs_list.go @@ -25,11 +25,6 @@ type fsListArgs struct { PerPage int `json:"per_page"` } -type toolCallEnvelope struct { - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` -} - func (s *Server) callFSList(c *gin.Context, raw json.RawMessage) (any, *rpcError) { args, mcpErr := parseFSListArgs(raw) if mcpErr != nil { @@ -79,7 +74,7 @@ func (s *Server) callFSList(c *gin.Context, raw json.RawMessage) (any, *rpcError Total: int64(total), Write: write, WriteContentBypass: writeContentBypass, - Provider: detectProvider(reqPath, paged), + Provider: "unknown", Readme: getReadme(meta, reqPath), Header: getHeader(meta, reqPath), }, nil @@ -94,20 +89,9 @@ func parseFSListArgs(raw json.RawMessage) (*fsListArgs, *rpcError) { return args, nil } - if err := json.Unmarshal(raw, args); err == nil { - normalizeFSListArgs(args) - return args, nil - } - - var envelope toolCallEnvelope - if err := json.Unmarshal(raw, &envelope); err != nil { + if err := json.Unmarshal(raw, args); err != nil { return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} } - if len(envelope.Arguments) > 0 { - if err := json.Unmarshal(envelope.Arguments, args); err != nil { - return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} - } - } normalizeFSListArgs(args) return args, nil } @@ -124,11 +108,24 @@ func normalizeFSListArgs(args *fsListArgs) { func paginateObjs(objs []model.Obj, page, perPage int) (int, []model.Obj) { total := len(objs) - start := (page - 1) * perPage + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = model.MaxInt + } + offset := page - 1 + if offset > total/perPage { + return total, []model.Obj{} + } + start := offset * perPage if start > total { return total, []model.Obj{} } - end := start + perPage + end := total + if perPage <= total-start { + end = start + perPage + } if end > total { end = total } @@ -157,19 +154,6 @@ func toObjResp(objs []model.Obj, parent string, encrypt bool) []handles.ObjResp return resp } -func detectProvider(reqPath string, objs []model.Obj) string { - storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) - if err == nil && storage != nil { - return storage.Config().Name - } - for _, obj := range objs { - if provider, ok := model.GetProvider(obj); ok && provider != "" { - return provider - } - } - return "unknown" -} - func getReadme(meta *model.Meta, path string) string { if meta != nil && common.MetaCoversPath(meta.Path, path, meta.RSub) { return meta.Readme diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 95c23130a..2338b93bc 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -4,8 +4,10 @@ import ( "encoding/json" "fmt" "io" + "mime" "net/http" "net/url" + "strconv" "strings" "sync" "time" @@ -21,6 +23,8 @@ import ( const ( ProtocolVersion = "2025-06-18" SessionHeader = "Mcp-Session-Id" + sessionTTL = 30 * time.Minute + maxSessions = 1024 ) type session struct { @@ -28,6 +32,7 @@ type session struct { userID uint initialized bool createdAt time.Time + lastUsedAt time.Time } type Server struct { @@ -84,7 +89,7 @@ func (s *Server) handlePost(c *gin.Context) { c.Status(http.StatusForbidden) return } - if !acceptsJSON(c.GetHeader("Accept")) { + if !acceptsStreamableHTTP(c.GetHeader("Accept")) { c.Status(http.StatusNotAcceptable) return } @@ -245,10 +250,15 @@ func (s *Server) createSession(userID uint) *session { s.mu.Lock() defer s.mu.Unlock() + now := time.Now() + s.pruneExpiredSessionsLocked(now) + s.pruneLeastRecentlyUsedSessionsLocked(max(0, len(s.sessions)-maxSessions+1)) + currentSession := &session{ - id: random.Token(), - userID: userID, - createdAt: time.Now(), + id: random.Token(), + userID: userID, + createdAt: now, + lastUsedAt: now, } s.sessions[currentSession.id] = currentSession return currentSession @@ -258,12 +268,18 @@ func (s *Server) getSession(id string) (session, bool) { if id == "" { return session{}, false } - s.mu.RLock() - defer s.mu.RUnlock() + s.mu.Lock() + defer s.mu.Unlock() currentSession, ok := s.sessions[id] if !ok || currentSession == nil { return session{}, false } + now := time.Now() + if sessionExpired(currentSession, now) { + delete(s.sessions, id) + return session{}, false + } + currentSession.lastUsedAt = now return *currentSession, true } @@ -274,15 +290,30 @@ func (s *Server) markSessionInitialized(id string) bool { if !ok || currentSession == nil { return false } + now := time.Now() + if sessionExpired(currentSession, now) { + delete(s.sessions, id) + return false + } currentSession.initialized = true + currentSession.lastUsedAt = now return true } func (s *Server) sessionInitialized(id string) bool { - s.mu.RLock() - defer s.mu.RUnlock() + s.mu.Lock() + defer s.mu.Unlock() currentSession, ok := s.sessions[id] - return ok && currentSession != nil && currentSession.initialized + if !ok || currentSession == nil { + return false + } + now := time.Now() + if sessionExpired(currentSession, now) { + delete(s.sessions, id) + return false + } + currentSession.lastUsedAt = now + return currentSession.initialized } func (s *Server) deleteSession(id string) { @@ -291,11 +322,75 @@ func (s *Server) deleteSession(id string) { delete(s.sessions, id) } -func acceptsJSON(accept string) bool { +func (s *Server) pruneExpiredSessionsLocked(now time.Time) { + for id, currentSession := range s.sessions { + if currentSession == nil || sessionExpired(currentSession, now) { + delete(s.sessions, id) + } + } +} + +func (s *Server) pruneLeastRecentlyUsedSessionsLocked(count int) { + for range count { + var ( + oldestID string + oldest time.Time + ) + for id, currentSession := range s.sessions { + if currentSession == nil { + oldestID = id + break + } + lastUsedAt := sessionLastUsedAt(currentSession) + if oldestID == "" || lastUsedAt.Before(oldest) { + oldestID = id + oldest = lastUsedAt + } + } + if oldestID == "" { + return + } + delete(s.sessions, oldestID) + } +} + +func sessionExpired(currentSession *session, now time.Time) bool { + lastUsedAt := sessionLastUsedAt(currentSession) + return !lastUsedAt.IsZero() && now.Sub(lastUsedAt) > sessionTTL +} + +func sessionLastUsedAt(currentSession *session) time.Time { + if currentSession.lastUsedAt.IsZero() { + return currentSession.createdAt + } + return currentSession.lastUsedAt +} + +func acceptsStreamableHTTP(accept string) bool { if accept == "" { return false } - return strings.Contains(accept, "application/json") || strings.Contains(accept, "*/*") + hasJSON := false + hasSSE := false + for part := range strings.SplitSeq(accept, ",") { + mediaType, params, err := mime.ParseMediaType(strings.TrimSpace(part)) + if err != nil { + continue + } + if q, ok := params["q"]; ok { + quality, err := strconv.ParseFloat(q, 64) + if err == nil && quality == 0 { + continue + } + } + switch mediaType { + case "application/json": + hasJSON = true + case "text/event-stream": + hasSSE = true + } + } + return hasJSON && hasSSE } func validateOrigin(r *http.Request) bool { diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index 3ab64de20..3edfed8d4 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -15,12 +16,12 @@ import ( func TestInitializeCreatesSession(t *testing.T) { gin.SetMode(gin.TestMode) - defaultServer.sessions = map[string]*session{} + srv := newTestServer(nil) r := gin.New() r.POST("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - defaultServer.handlePost(c) + srv.handlePost(c) }) req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ @@ -33,7 +34,7 @@ func TestInitializeCreatesSession(t *testing.T) { "clientInfo":{"name":"test-client","version":"1.0.0"} } }`)) - req.Header.Set("Accept", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") w := httptest.NewRecorder() @@ -57,13 +58,13 @@ func TestInitializeCreatesSession(t *testing.T) { func TestDeleteRemovesSession(t *testing.T) { gin.SetMode(gin.TestMode) - defaultServer.sessions = map[string]*session{} + srv := newTestServer(nil) - currentSession := defaultServer.createSession(1) + currentSession := srv.createSession(1) r := gin.New() r.DELETE("/mcp", func(c *gin.Context) { common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) - defaultServer.handleDelete(c) + srv.handleDelete(c) }) req := httptest.NewRequest(http.MethodDelete, "http://example.com/mcp", nil) @@ -76,7 +77,7 @@ func TestDeleteRemovesSession(t *testing.T) { if w.Code != http.StatusNoContent { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNoContent) } - if _, ok := defaultServer.getSession(currentSession.id); ok { + if _, ok := srv.getSession(currentSession.id); ok { t.Fatal("expected session to be deleted") } } @@ -100,3 +101,19 @@ func TestGetReturnsMethodNotAllowed(t *testing.T) { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusMethodNotAllowed) } } + +func newTestServer(sessions map[string]*session) *Server { + if sessions == nil { + sessions = map[string]*session{} + } + now := time.Now() + for _, currentSession := range sessions { + if currentSession.createdAt.IsZero() { + currentSession.createdAt = now + } + if currentSession.lastUsedAt.IsZero() { + currentSession.lastUsedAt = now + } + } + return &Server{sessions: sessions} +} From c824ca088e19d5bfbf68f1903b808720730e4f83 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Tue, 19 May 2026 10:08:10 +0800 Subject: [PATCH 04/16] feat(mcp): add fs get and link tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 openlist.fs.get 和 openlist.fs.link 工具 - 复用文件详情、代理链接和权限校验逻辑 - 补充工具列表和参数解析测试 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/call.go | 4 + server/mcp/call_test.go | 14 ++- server/mcp/fs_get.go | 168 +++++++++++++++++++++++++++++++++ server/mcp/fs_link.go | 200 ++++++++++++++++++++++++++++++++++++++++ server/mcp/get_test.go | 39 ++++++++ server/mcp/handler.go | 2 +- server/mcp/link_test.go | 120 ++++++++++++++++++++++++ server/mcp/tools.go | 42 +++++++++ 8 files changed, 587 insertions(+), 2 deletions(-) create mode 100644 server/mcp/fs_get.go create mode 100644 server/mcp/fs_link.go create mode 100644 server/mcp/get_test.go create mode 100644 server/mcp/link_test.go diff --git a/server/mcp/call.go b/server/mcp/call.go index 7734e5786..5a19b9b4b 100644 --- a/server/mcp/call.go +++ b/server/mcp/call.go @@ -41,6 +41,10 @@ func (s *Server) handleToolsCall(c *gin.Context, req request) (int, response) { switch params.Name { case "openlist.fs.list": result, err = s.callFSList(c, params.Arguments) + case "openlist.fs.get": + result, err = s.callFSGet(c, params.Arguments) + case "openlist.fs.link": + result, err = s.callFSLink(c, params.Arguments) default: return http.StatusOK, response{ JSONRPC: "2.0", diff --git a/server/mcp/call_test.go b/server/mcp/call_test.go index 8c1661da4..d39b260af 100644 --- a/server/mcp/call_test.go +++ b/server/mcp/call_test.go @@ -83,9 +83,21 @@ func TestToolsListSuccess(t *testing.T) { t.Fatalf("unexpected result type: %T", resp.Result) } tools, ok := result["tools"].([]any) - if !ok || len(tools) != 1 { + if !ok || len(tools) != 3 { t.Fatalf("unexpected tools payload: %#v", result["tools"]) } + names := map[string]bool{} + for _, rawTool := range tools { + currentTool, ok := rawTool.(map[string]any) + if !ok { + t.Fatalf("unexpected tool payload: %#v", rawTool) + } + name, _ := currentTool["name"].(string) + names[name] = true + } + if !names["openlist.fs.list"] || !names["openlist.fs.get"] || !names["openlist.fs.link"] { + t.Fatalf("unexpected tool names: %#v", names) + } } func TestToolsCallUnknownTool(t *testing.T) { diff --git a/server/mcp/fs_get.go b/server/mcp/fs_get.go new file mode 100644 index 000000000..550cc036f --- /dev/null +++ b/server/mcp/fs_get.go @@ -0,0 +1,168 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + stdpath "path" + "strings" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/internal/sign" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/OpenListTeam/OpenList/v4/server/handles" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type fsGetArgs struct { + Path string `json:"path"` + Password string `json:"password"` +} + +func (s *Server) callFSGet(c *gin.Context, raw json.RawMessage) (any, *rpcError) { + args, mcpErr := parseFSGetArgs(raw) + if mcpErr != nil { + return nil, mcpErr + } + + user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) + if !ok || user == nil { + return nil, &rpcError{Code: -32603, Message: "missing user context"} + } + if user.IsGuest() && user.Disabled { + return nil, &rpcError{Code: -32001, Message: "guest user is disabled"} + } + + reqPath, err := user.JoinPath(args.Path) + if err != nil { + return nil, &rpcError{Code: -32003, Message: err.Error()} + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if !common.CanAccess(user, meta, reqPath, args.Password) { + return nil, &rpcError{Code: -32003, Message: "password is incorrect or you have no permission"} + } + + ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) + obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{ + WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), + }) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + rawURL, provider, err := buildFSGetRawURL(ctx, c, reqPath, obj, meta) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + parentPath := stdpath.Dir(reqPath) + var related []model.Obj + sameLevelFiles, err := fs.List(ctx, parentPath, &fs.ListArgs{}) + if err == nil { + related = filterRelatedObjs(sameLevelFiles, obj) + } + + parentMeta, _ := op.GetNearestMeta(parentPath) + thumb, _ := model.GetThumb(obj) + mountDetails, _ := model.GetStorageDetails(obj) + return handles.FsGetResp{ + ObjResp: handles.ObjResp{ + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Created: obj.CreateTime(), + Sign: common.Sign(obj, parentPath, isEncrypt(meta, reqPath)), + Thumb: thumb, + Type: utils.GetFileType(obj.GetName()), + HashInfoStr: obj.GetHash().String(), + HashInfo: obj.GetHash().Export(), + MountDetails: mountDetails, + }, + RawURL: rawURL, + Readme: getReadme(meta, reqPath), + Header: getHeader(meta, reqPath), + Provider: provider, + Related: toObjResp(related, parentPath, isEncrypt(parentMeta, parentPath)), + }, nil +} + +func parseFSGetArgs(raw json.RawMessage) (*fsGetArgs, *rpcError) { + args := &fsGetArgs{} + if len(raw) == 0 || string(raw) == "null" { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.get arguments"} + } + + if err := json.Unmarshal(raw, args); err != nil { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.get arguments"} + } + if args.Path == "" { + return nil, &rpcError{Code: -32602, Message: "path is required"} + } + return args, nil +} + +func buildFSGetRawURL(ctx context.Context, c *gin.Context, reqPath string, obj model.Obj, meta *model.Meta) (string, string, error) { + storage, storageErr := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + provider, ok := model.GetProvider(obj) + if !ok && storageErr == nil { + provider = storage.Config().Name + } + if obj.IsDir() { + return "", provider, nil + } + if storageErr != nil { + return "", provider, storageErr + } + + if storage.Config().MustProxy() || storage.GetStorage().WebProxy { + rawURL := common.GenerateDownProxyURL(storage.GetStorage(), reqPath) + if rawURL != "" { + return rawURL, provider, nil + } + query := "" + if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { + query = "?sign=" + sign.Sign(reqPath) + } + return fmt.Sprintf("%s/p%s%s", common.GetApiUrl(ctx), utils.EncodePath(reqPath, true), query), provider, nil + } + + if url, ok := model.GetUrl(obj); ok { + return url, provider, nil + } + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + Redirect: true, + }) + if err != nil { + return "", provider, err + } + defer link.Close() + return link.URL, provider, nil +} + +func filterRelatedObjs(objs []model.Obj, obj model.Obj) []model.Obj { + related := make([]model.Obj, 0) + nameWithoutExt := strings.TrimSuffix(obj.GetName(), stdpath.Ext(obj.GetName())) + for _, current := range objs { + if current.GetName() == obj.GetName() { + continue + } + if strings.HasPrefix(current.GetName(), nameWithoutExt) { + related = append(related, current) + } + } + return related +} diff --git a/server/mcp/fs_link.go b/server/mcp/fs_link.go new file mode 100644 index 000000000..8c854672c --- /dev/null +++ b/server/mcp/fs_link.go @@ -0,0 +1,200 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + stdpath "path" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/internal/sign" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type fsLinkArgs struct { + Path string `json:"path"` + Password string `json:"password"` + Type string `json:"type"` +} + +type fsLinkResp struct { + Path string `json:"path"` + Name string `json:"name"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir"` + Modified time.Time `json:"modified"` + Provider string `json:"provider"` + URL string `json:"url"` + URLType string `json:"url_type"` + DirectURL string `json:"direct_url,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + DownloadURL string `json:"download_url,omitempty"` + Header http.Header `json:"header,omitempty"` + ContentLength int64 `json:"content_length,omitempty"` + Concurrency int `json:"concurrency,omitempty"` + PartSize int `json:"part_size,omitempty"` +} + +func (s *Server) callFSLink(c *gin.Context, raw json.RawMessage) (any, *rpcError) { + args, mcpErr := parseFSLinkArgs(raw) + if mcpErr != nil { + return nil, mcpErr + } + + user, ok := c.Request.Context().Value(conf.UserKey).(*model.User) + if !ok || user == nil { + return nil, &rpcError{Code: -32603, Message: "missing user context"} + } + if user.IsGuest() && user.Disabled { + return nil, &rpcError{Code: -32001, Message: "guest user is disabled"} + } + + reqPath, err := user.JoinPath(args.Path) + if err != nil { + return nil, &rpcError{Code: -32003, Message: err.Error()} + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if !common.CanAccess(user, meta, reqPath, args.Password) { + return nil, &rpcError{Code: -32003, Message: "password is incorrect or you have no permission"} + } + + ctx := context.WithValue(c.Request.Context(), conf.MetaKey, meta) + obj, err := fs.Get(ctx, reqPath, &fs.GetArgs{ + WithStorageDetails: !user.IsGuest() && !setting.GetBool(conf.HideStorageDetails), + }) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + if obj.IsDir() { + return nil, &rpcError{Code: -32003, Message: "path is a directory"} + } + + storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + + linkInfo, err := buildFSLinkInfo(ctx, c, reqPath, args, obj, meta, storage) + if err != nil { + return nil, &rpcError{Code: -32603, Message: err.Error()} + } + return linkInfo, nil +} + +func parseFSLinkArgs(raw json.RawMessage) (*fsLinkArgs, *rpcError) { + args := &fsLinkArgs{} + if len(raw) == 0 || string(raw) == "null" { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.link arguments"} + } + if err := json.Unmarshal(raw, args); err != nil { + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.link arguments"} + } + if args.Path == "" { + return nil, &rpcError{Code: -32602, Message: "path is required"} + } + return args, nil +} + +func buildFSLinkInfo(ctx context.Context, c *gin.Context, reqPath string, args *fsLinkArgs, obj model.Obj, meta *model.Meta, storage driver.Driver) (*fsLinkResp, error) { + provider, ok := model.GetProvider(obj) + if !ok { + provider = storage.Config().Name + } + + resp := &fsLinkResp{ + Path: reqPath, + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Provider: provider, + DownloadURL: signedFileURL(ctx, "/d", reqPath, meta, args.Type), + } + + if canProxyFile(storage, stdpath.Base(reqPath)) { + proxyURL := proxyFileURL(ctx, reqPath, meta, storage.GetStorage(), args.Type) + resp.ProxyURL = proxyURL + } + + if common.ShouldProxy(storage, stdpath.Base(reqPath)) { + resp.URL = resp.ProxyURL + resp.URLType = "proxy" + return resp, nil + } + + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + Type: args.Type, + Redirect: true, + }) + if err != nil { + return nil, err + } + defer link.Close() + + resp.DirectURL = link.URL + resp.URL = link.URL + resp.URLType = "direct" + resp.Header = link.Header + resp.ContentLength = link.ContentLength + resp.Concurrency = link.Concurrency + resp.PartSize = link.PartSize + return resp, nil +} + +func canProxyFile(storage driver.Driver, filename string) bool { + if storage.Config().MustProxy() || storage.GetStorage().WebProxy || storage.GetStorage().WebdavProxyURL() { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.ProxyTypes], utils.Ext(filename)) { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.TextTypes], utils.Ext(filename)) { + return true + } + return false +} + +func proxyFileURL(ctx context.Context, reqPath string, meta *model.Meta, storage *model.Storage, linkType string) string { + if url := common.GenerateDownProxyURL(storage, reqPath); url != "" { + return url + } + return signedFileURL(ctx, "/p", reqPath, meta, linkType) +} + +func signedFileURL(ctx context.Context, prefix, reqPath string, meta *model.Meta, linkType string) string { + query := url.Values{} + if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { + query.Set("sign", sign.Sign(reqPath)) + } + if linkType != "" { + query.Set("type", linkType) + } + rawQuery := "" + if encoded := query.Encode(); encoded != "" { + rawQuery = "?" + encoded + } + return fmt.Sprintf("%s%s%s%s", + common.GetApiUrl(ctx), + prefix, + utils.EncodePath(reqPath, true), + rawQuery, + ) +} diff --git a/server/mcp/get_test.go b/server/mcp/get_test.go new file mode 100644 index 000000000..81fbe3900 --- /dev/null +++ b/server/mcp/get_test.go @@ -0,0 +1,39 @@ +package mcp + +import ( + "encoding/json" + "testing" +) + +func TestParseFSGetArgsRequiresPath(t *testing.T) { + _, err := parseFSGetArgs(json.RawMessage(`{"password":"secret"}`)) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 { + t.Fatalf("unexpected error: %+v", err) + } +} + +func TestParseFSGetArgs(t *testing.T) { + args, err := parseFSGetArgs(json.RawMessage(`{"path":"/movie.mkv","password":"secret"}`)) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if args.Path != "/movie.mkv" { + t.Fatalf("unexpected path: %q", args.Path) + } + if args.Password != "secret" { + t.Fatalf("unexpected password: %q", args.Password) + } +} + +func TestParseFSGetArgsRejectsInvalidJSON(t *testing.T) { + _, err := parseFSGetArgs(json.RawMessage(`"bad"`)) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 { + t.Fatalf("unexpected error: %+v", err) + } +} diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 2338b93bc..d5c17666f 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -219,7 +219,7 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { "name": "OpenList MCP", "version": conf.Version, }, - "instructions": "Complete initialization with notifications/initialized, then use tools/list and tools/call. The first available tool is openlist.fs.list.", + "instructions": "Complete initialization with notifications/initialized, then use tools/list and tools/call. Available tools include openlist.fs.list, openlist.fs.get, and openlist.fs.link.", }, }) } diff --git a/server/mcp/link_test.go b/server/mcp/link_test.go new file mode 100644 index 000000000..2bb0189c7 --- /dev/null +++ b/server/mcp/link_test.go @@ -0,0 +1,120 @@ +package mcp + +import ( + "context" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/gin-gonic/gin" +) + +func TestParseFSLinkArgs(t *testing.T) { + args, err := parseFSLinkArgs(json.RawMessage(`{"path":"/file.txt","password":"pw","type":"thumb"}`)) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if args.Path != "/file.txt" || args.Password != "pw" || args.Type != "thumb" { + t.Fatalf("unexpected args: %+v", args) + } +} + +func TestParseFSLinkArgsRequiresPath(t *testing.T) { + _, err := parseFSLinkArgs(json.RawMessage(`{"type":"thumb"}`)) + if err == nil || err.Code != -32602 { + t.Fatalf("unexpected error: %+v", err) + } +} + +func TestCanProxyFile(t *testing.T) { + storage := &fsLinkTestDriver{ + config: driver.Config{Name: "Test"}, + storage: model.Storage{ + MountPath: "/", + }, + } + if canProxyFile(storage, "file.bin") { + t.Fatal("unexpected proxy support") + } + storage.config.OnlyProxy = true + if !canProxyFile(storage, "file.bin") { + t.Fatal("expected proxy support") + } +} + +func TestBuildFSLinkInfoUsesProxyWhenStorageRequiresProxy(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Cleanup(op.Cache.ClearAll) + op.Cache.SetSetting(conf.LinkExpiration, &model.SettingItem{Key: conf.LinkExpiration, Value: "0"}) + op.Cache.SetSetting(conf.Token, &model.SettingItem{Key: conf.Token, Value: "test-token"}) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("POST", "http://example.com/mcp", nil) + ctx := context.WithValue(c.Request.Context(), conf.ApiUrlKey, "http://openlist.test") + storage := &fsLinkTestDriver{ + config: driver.Config{Name: "Test", OnlyProxy: true}, + storage: model.Storage{ + MountPath: "/", + Proxy: model.Proxy{ + DownProxyURL: "http://proxy.test", + DisableProxySign: true, + }, + }, + } + obj := &model.ObjectURL{ + Object: model.Object{Name: "file.txt", Size: 12}, + Url: model.Url{Url: "http://direct.test/file.txt"}, + } + + meta := &model.Meta{Path: "/file.txt", Password: "secret"} + resp, err := buildFSLinkInfo(ctx, c, "/file.txt", &fsLinkArgs{Path: "/file.txt"}, obj, meta, storage) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if resp.URL != "http://proxy.test/file.txt" || resp.URLType != "proxy" { + t.Fatalf("unexpected selected link: %+v", resp) + } + if resp.DirectURL != "" { + t.Fatalf("direct link should not be resolved for proxy storage: %+v", resp) + } +} + +type fsLinkTestDriver struct { + config driver.Config + storage model.Storage +} + +func (d *fsLinkTestDriver) Config() driver.Config { + return d.config +} + +func (d *fsLinkTestDriver) GetStorage() *model.Storage { + return &d.storage +} + +func (d *fsLinkTestDriver) SetStorage(storage model.Storage) { + d.storage = storage +} + +func (d *fsLinkTestDriver) GetAddition() driver.Additional { + return nil +} + +func (d *fsLinkTestDriver) Init(context.Context) error { + return nil +} + +func (d *fsLinkTestDriver) Drop(context.Context) error { + return nil +} + +func (d *fsLinkTestDriver) List(context.Context, model.Obj, model.ListArgs) ([]model.Obj, error) { + return nil, nil +} + +func (d *fsLinkTestDriver) Link(context.Context, model.Obj, model.LinkArgs) (*model.Link, error) { + return nil, nil +} diff --git a/server/mcp/tools.go b/server/mcp/tools.go index 22097fefd..188c674ce 100644 --- a/server/mcp/tools.go +++ b/server/mcp/tools.go @@ -56,6 +56,48 @@ var openListTools = []tool{ Required: []string{"path"}, }, }, + { + Name: "openlist.fs.get", + Title: "OpenList FS Get", + Description: "Get file or directory details for a mount path that the current user can access.", + InputSchema: toolInputSchema{ + Type: "object", + Properties: map[string]schemaProperty{ + "path": { + Type: "string", + Description: "Mount path to inspect, for example \"/movies/demo.mp4\".", + }, + "password": { + Type: "string", + Description: "Optional password for protected paths.", + }, + }, + Required: []string{"path"}, + }, + }, + { + Name: "openlist.fs.link", + Title: "OpenList FS Link", + Description: "Return usable link information for a file path that the current user can access.", + InputSchema: toolInputSchema{ + Type: "object", + Properties: map[string]schemaProperty{ + "path": { + Type: "string", + Description: "File mount path, for example \"/movies/demo.mp4\".", + }, + "password": { + Type: "string", + Description: "Optional password for protected paths.", + }, + "type": { + Type: "string", + Description: "Optional link type forwarded to storage drivers.", + }, + }, + Required: []string{"path"}, + }, + }, } func (s *Server) handleToolsList(req request) response { From 72115d0b9311bcf1057b53d766e921c5ed23c022 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Tue, 19 May 2026 10:25:19 +0800 Subject: [PATCH 05/16] fix(mcp): improve protocol negotiation and proxy handling - Remove WebDAV proxy URL policy from MCP proxy link detection - Return server protocol version during initialize negotiation - Return JSON-RPC error body for invalid MCP Accept header - Add tests for MCP proxy and HTTP negotiation behavior Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/fs_link.go | 2 +- server/mcp/handler.go | 16 ++++----- server/mcp/handler_test.go | 72 ++++++++++++++++++++++++++++++++++++++ server/mcp/link_test.go | 16 +++++++++ 4 files changed, 96 insertions(+), 10 deletions(-) diff --git a/server/mcp/fs_link.go b/server/mcp/fs_link.go index 8c854672c..43ddad852 100644 --- a/server/mcp/fs_link.go +++ b/server/mcp/fs_link.go @@ -160,7 +160,7 @@ func buildFSLinkInfo(ctx context.Context, c *gin.Context, reqPath string, args * } func canProxyFile(storage driver.Driver, filename string) bool { - if storage.Config().MustProxy() || storage.GetStorage().WebProxy || storage.GetStorage().WebdavProxyURL() { + if storage.Config().MustProxy() || storage.GetStorage().WebProxy { return true } if utils.SliceContains(conf.SlicesMap[conf.ProxyTypes], utils.Ext(filename)) { diff --git a/server/mcp/handler.go b/server/mcp/handler.go index d5c17666f..395e4f567 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -90,7 +90,13 @@ func (s *Server) handlePost(c *gin.Context) { return } if !acceptsStreamableHTTP(c.GetHeader("Accept")) { - c.Status(http.StatusNotAcceptable) + c.JSON(http.StatusNotAcceptable, response{ + JSONRPC: "2.0", + Error: &rpcError{ + Code: -32000, + Message: "Not Acceptable: client must accept both application/json and text/event-stream", + }, + }) return } @@ -194,14 +200,6 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { return } } - if params.ProtocolVersion != "" && params.ProtocolVersion != ProtocolVersion { - c.JSON(http.StatusBadRequest, response{ - JSONRPC: "2.0", - ID: req.ID, - Error: &rpcError{Code: -32602, Message: "unsupported protocol version"}, - }) - return - } currentSession := s.createSession(c.Request.Context().Value(conf.UserKey).(*model.User).ID) c.Header(SessionHeader, currentSession.id) diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index 3edfed8d4..eb18376be 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -56,6 +56,78 @@ func TestInitializeCreatesSession(t *testing.T) { } } +func TestInitializeNegotiatesUnsupportedProtocolVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2026-01-01", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != ProtocolVersion { + t.Fatalf("unexpected protocol version: %v", result["protocolVersion"]) + } +} + +func TestPostInvalidAcceptReturnsJSONRPCError(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize" + }`)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotAcceptable { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNotAcceptable) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + func TestDeleteRemovesSession(t *testing.T) { gin.SetMode(gin.TestMode) srv := newTestServer(nil) diff --git a/server/mcp/link_test.go b/server/mcp/link_test.go index 2bb0189c7..66c820ef6 100644 --- a/server/mcp/link_test.go +++ b/server/mcp/link_test.go @@ -46,6 +46,22 @@ func TestCanProxyFile(t *testing.T) { } } +func TestCanProxyFileIgnoresWebdavProxyURLPolicy(t *testing.T) { + storage := &fsLinkTestDriver{ + config: driver.Config{Name: "Test"}, + storage: model.Storage{ + MountPath: "/", + Proxy: model.Proxy{ + WebdavPolicy: "use_proxy_url", + }, + }, + } + + if canProxyFile(storage, "file.bin") { + t.Fatal("webdav proxy url policy should not enable MCP proxy support") + } +} + func TestBuildFSLinkInfoUsesProxyWhenStorageRequiresProxy(t *testing.T) { gin.SetMode(gin.TestMode) t.Cleanup(op.Cache.ClearAll) From 4780d582b5bc2686408c04301f947d2ab6b4ded0 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 09:09:54 +0800 Subject: [PATCH 06/16] fix(mcp): set Allow header for GET requests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 MCP GET 405 响应添加 Allow 头 - 补充 GET 405 响应头断言 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 1 + server/mcp/handler_test.go | 3 +++ 2 files changed, 4 insertions(+) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 395e4f567..31349096a 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -81,6 +81,7 @@ func (s *Server) handleGet(c *gin.Context) { c.Status(http.StatusForbidden) return } + c.Header("Allow", "POST, DELETE") c.Status(http.StatusMethodNotAllowed) } diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index eb18376be..b6a42230e 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -172,6 +172,9 @@ func TestGetReturnsMethodNotAllowed(t *testing.T) { if w.Code != http.StatusMethodNotAllowed { t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusMethodNotAllowed) } + if allow := w.Header().Get("Allow"); allow != "POST, DELETE" { + t.Fatalf("unexpected Allow header: got %q want %q", allow, "POST, DELETE") + } } func newTestServer(sessions map[string]*session) *Server { From b0badb7735097243f653c10dae8acb9f3a99c9e6 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 09:10:41 +0800 Subject: [PATCH 07/16] refactor(mcp): use mutex for session store MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 MCP session 锁改为 Mutex - 让锁类型匹配会更新 lastUsedAt 的访问路径 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 31349096a..7234015c4 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -36,7 +36,7 @@ type session struct { } type Server struct { - mu sync.RWMutex + mu sync.Mutex sessions map[string]*session } From aeef265cebdbfcae0a87875657f12a0d83eda99c Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 09:15:09 +0800 Subject: [PATCH 08/16] fix(mcp): negotiate initialize protocol version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 MCP 协议版本协商函数 - 不支持客户端版本时返回服务端支持版本 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 7234015c4..a9336134a 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -27,6 +27,8 @@ const ( maxSessions = 1024 ) +var supportedProtocolVersions = []string{ProtocolVersion} + type session struct { id string userID uint @@ -208,7 +210,7 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { JSONRPC: "2.0", ID: req.ID, Result: map[string]any{ - "protocolVersion": ProtocolVersion, + "protocolVersion": negotiateProtocolVersion(params.ProtocolVersion), "capabilities": map[string]any{ "tools": map[string]any{ "listChanged": false, @@ -223,6 +225,15 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { }) } +func negotiateProtocolVersion(clientVersion string) string { + for _, supportedVersion := range supportedProtocolVersions { + if clientVersion == supportedVersion { + return supportedVersion + } + } + return ProtocolVersion +} + func (s *Server) handleDelete(c *gin.Context) { if !validateOrigin(c.Request) { c.Status(http.StatusForbidden) From a32810e655be59411402d6b7041837dce49826e4 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 09:23:50 +0800 Subject: [PATCH 09/16] revert(mcp): drop protocol version negotiation wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 撤回 MCP 协议版本协商包装函数 - 恢复 initialize 直接返回服务端协议版本 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index a9336134a..7234015c4 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -27,8 +27,6 @@ const ( maxSessions = 1024 ) -var supportedProtocolVersions = []string{ProtocolVersion} - type session struct { id string userID uint @@ -210,7 +208,7 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { JSONRPC: "2.0", ID: req.ID, Result: map[string]any{ - "protocolVersion": negotiateProtocolVersion(params.ProtocolVersion), + "protocolVersion": ProtocolVersion, "capabilities": map[string]any{ "tools": map[string]any{ "listChanged": false, @@ -225,15 +223,6 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { }) } -func negotiateProtocolVersion(clientVersion string) string { - for _, supportedVersion := range supportedProtocolVersions { - if clientVersion == supportedVersion { - return supportedVersion - } - } - return ProtocolVersion -} - func (s *Server) handleDelete(c *gin.Context) { if !validateOrigin(c.Request) { c.Status(http.StatusForbidden) From 5ba832affb8f8cfef8c3b8653117289007e832fa Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 09:26:57 +0800 Subject: [PATCH 10/16] fix(mcp): limit and reuse sessions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 MCP 全局 session 上限调整为 128 - 添加单用户 16 个 session 上限 - initialize 复用同用户已有 session 标识 - 补充 session 复用和上限裁剪测试 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 68 ++++++++++++++++++++++++-- server/mcp/handler_test.go | 99 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 3 deletions(-) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 7234015c4..98df91491 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -24,7 +24,8 @@ const ( ProtocolVersion = "2025-06-18" SessionHeader = "Mcp-Session-Id" sessionTTL = 30 * time.Minute - maxSessions = 1024 + maxSessions = 128 + maxUserSessions = 16 ) type session struct { @@ -202,7 +203,10 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { } } - currentSession := s.createSession(c.Request.Context().Value(conf.UserKey).(*model.User).ID) + currentSession := s.initializeSession( + c.Request.Context().Value(conf.UserKey).(*model.User).ID, + c.GetHeader(SessionHeader), + ) c.Header(SessionHeader, currentSession.id) c.JSON(http.StatusOK, response{ JSONRPC: "2.0", @@ -223,6 +227,24 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { }) } +func (s *Server) initializeSession(userID uint, requestedID string) *session { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + s.pruneExpiredSessionsLocked(now) + if requestedID != "" { + currentSession, ok := s.sessions[requestedID] + if ok && currentSession != nil && currentSession.userID == userID { + currentSession.initialized = false + currentSession.lastUsedAt = now + return currentSession + } + } + + return s.createSessionLocked(userID, now) +} + func (s *Server) handleDelete(c *gin.Context) { if !validateOrigin(c.Request) { c.Status(http.StatusForbidden) @@ -251,8 +273,12 @@ func (s *Server) createSession(userID uint) *session { now := time.Now() s.pruneExpiredSessionsLocked(now) - s.pruneLeastRecentlyUsedSessionsLocked(max(0, len(s.sessions)-maxSessions+1)) + return s.createSessionLocked(userID, now) +} +func (s *Server) createSessionLocked(userID uint, now time.Time) *session { + s.pruneLeastRecentlyUsedUserSessionsLocked(userID, max(0, s.countUserSessionsLocked(userID)-maxUserSessions+1)) + s.pruneLeastRecentlyUsedSessionsLocked(max(0, len(s.sessions)-maxSessions+1)) currentSession := &session{ id: random.Token(), userID: userID, @@ -353,6 +379,42 @@ func (s *Server) pruneLeastRecentlyUsedSessionsLocked(count int) { } } +func (s *Server) countUserSessionsLocked(userID uint) int { + count := 0 + for _, currentSession := range s.sessions { + if currentSession != nil && currentSession.userID == userID { + count++ + } + } + return count +} + +func (s *Server) pruneLeastRecentlyUsedUserSessionsLocked(userID uint, count int) { + for range count { + var ( + oldestID string + oldest time.Time + ) + for id, currentSession := range s.sessions { + if currentSession == nil { + continue + } + if currentSession.userID != userID { + continue + } + lastUsedAt := sessionLastUsedAt(currentSession) + if oldestID == "" || lastUsedAt.Before(oldest) { + oldestID = id + oldest = lastUsedAt + } + } + if oldestID == "" { + return + } + delete(s.sessions, oldestID) + } +} + func sessionExpired(currentSession *session, now time.Time) bool { lastUsedAt := sessionLastUsedAt(currentSession) return !lastUsedAt.IsZero() && now.Sub(lastUsedAt) > sessionTTL diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index b6a42230e..70bbf22b3 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -56,6 +56,53 @@ func TestInitializeCreatesSession(t *testing.T) { } } +func TestInitializeReusesExistingSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + currentSession := srv.createSession(1) + currentSession.initialized = true + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2025-06-18", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(SessionHeader, currentSession.id) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + if got := w.Header().Get(SessionHeader); got != currentSession.id { + t.Fatalf("unexpected session header: got %q want %q", got, currentSession.id) + } + if len(srv.sessions) != 1 { + t.Fatalf("unexpected session count: got %d want %d", len(srv.sessions), 1) + } + reusedSession, ok := srv.getSession(currentSession.id) + if !ok { + t.Fatal("expected existing session to be reused") + } + if reusedSession.initialized { + t.Fatal("expected reused session to require initialized notification again") + } +} + func TestInitializeNegotiatesUnsupportedProtocolVersion(t *testing.T) { gin.SetMode(gin.TestMode) srv := newTestServer(nil) @@ -98,6 +145,48 @@ func TestInitializeNegotiatesUnsupportedProtocolVersion(t *testing.T) { } } +func TestCreateSessionPrunesUserSessionLimit(t *testing.T) { + srv := newTestServer(nil) + var firstSessionID string + for i := range maxUserSessions { + currentSession := srv.createSession(1) + if i == 0 { + firstSessionID = currentSession.id + } + } + srv.createSession(2) + srv.createSession(1) + + if _, ok := srv.sessions[firstSessionID]; ok { + t.Fatal("expected oldest user session to be pruned") + } + if got := countSessionsForUser(srv, 1); got != maxUserSessions { + t.Fatalf("unexpected user session count: got %d want %d", got, maxUserSessions) + } + if got := len(srv.sessions); got != maxUserSessions+1 { + t.Fatalf("unexpected total session count: got %d want %d", got, maxUserSessions+1) + } +} + +func TestCreateSessionPrunesGlobalSessionLimit(t *testing.T) { + srv := newTestServer(nil) + var firstSessionID string + for i := range maxSessions { + currentSession := srv.createSession(uint(i + 1)) + if i == 0 { + firstSessionID = currentSession.id + } + } + srv.createSession(uint(maxSessions + 1)) + + if _, ok := srv.sessions[firstSessionID]; ok { + t.Fatal("expected oldest global session to be pruned") + } + if got := len(srv.sessions); got != maxSessions { + t.Fatalf("unexpected session count: got %d want %d", got, maxSessions) + } +} + func TestPostInvalidAcceptReturnsJSONRPCError(t *testing.T) { gin.SetMode(gin.TestMode) srv := newTestServer(nil) @@ -192,3 +281,13 @@ func newTestServer(sessions map[string]*session) *Server { } return &Server{sessions: sessions} } + +func countSessionsForUser(srv *Server, userID uint) int { + count := 0 + for _, currentSession := range srv.sessions { + if currentSession != nil && currentSession.userID == userID { + count++ + } + } + return count +} From b6fc2f5c773cd0168a5c83f6b02a161ab55e5108 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 11:37:31 +0800 Subject: [PATCH 11/16] fix(mcp): require path for fs list MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 拒绝 openlist.fs.list 的空参数和 null 参数 - 校验 fs.list 缺失 path 时返回 -32602 - 添加 fs.list 参数解析测试覆盖错误文案 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/fs_list.go | 5 +++- server/mcp/list_test.go | 54 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 server/mcp/list_test.go diff --git a/server/mcp/fs_list.go b/server/mcp/fs_list.go index 6e2c7c72e..ac28cf756 100644 --- a/server/mcp/fs_list.go +++ b/server/mcp/fs_list.go @@ -86,12 +86,15 @@ func parseFSListArgs(raw json.RawMessage) (*fsListArgs, *rpcError) { PerPage: model.MaxInt, } if len(raw) == 0 || string(raw) == "null" { - return args, nil + return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} } if err := json.Unmarshal(raw, args); err != nil { return nil, &rpcError{Code: -32602, Message: "invalid openlist.fs.list arguments"} } + if args.Path == "" { + return nil, &rpcError{Code: -32602, Message: "path is required"} + } normalizeFSListArgs(args) return args, nil } diff --git a/server/mcp/list_test.go b/server/mcp/list_test.go new file mode 100644 index 000000000..8079245aa --- /dev/null +++ b/server/mcp/list_test.go @@ -0,0 +1,54 @@ +package mcp + +import ( + "encoding/json" + "testing" +) + +func TestParseFSListArgsRequiresPath(t *testing.T) { + assertFSListPathRequired(t, json.RawMessage(`{"refresh":true}`)) +} + +func TestParseFSListArgsRejectsEmptyArguments(t *testing.T) { + for _, raw := range []json.RawMessage{nil, json.RawMessage(`null`)} { + assertFSListInvalidArguments(t, raw) + } +} + +func TestParseFSListArgsRejectsInvalidJSON(t *testing.T) { + assertFSListInvalidArguments(t, json.RawMessage(`"bad"`)) +} + +func TestParseFSListArgs(t *testing.T) { + args, err := parseFSListArgs(json.RawMessage(`{"path":"/movies","password":"secret","refresh":true,"page":2,"per_page":10}`)) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + if args.Path != "/movies" || args.Password != "secret" || !args.Refresh || args.Page != 2 || args.PerPage != 10 { + t.Fatalf("unexpected args: %+v", args) + } +} + +func assertFSListPathRequired(t *testing.T, raw json.RawMessage) { + t.Helper() + + _, err := parseFSListArgs(raw) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 || err.Message != "path is required" { + t.Fatalf("unexpected error: %+v", err) + } +} + +func assertFSListInvalidArguments(t *testing.T, raw json.RawMessage) { + t.Helper() + + _, err := parseFSListArgs(raw) + if err == nil { + t.Fatal("expected error") + } + if err.Code != -32602 || err.Message != "invalid openlist.fs.list arguments" { + t.Fatalf("unexpected error: %+v", err) + } +} From 210d912454ba32a0409c248f3dad0a1bade92a18 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 11:41:17 +0800 Subject: [PATCH 12/16] fix(mcp): enforce protocol version header MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 校验非 initialize 请求的 MCP-Protocol-Version 头 - 拒绝缺失或不支持协议版本的后续 POST 请求 - 添加协议版本头缺失和不匹配测试 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/call_test.go | 4 +++ server/mcp/handler.go | 19 ++++++++--- server/mcp/handler_test.go | 67 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/server/mcp/call_test.go b/server/mcp/call_test.go index d39b260af..b922ffd52 100644 --- a/server/mcp/call_test.go +++ b/server/mcp/call_test.go @@ -32,6 +32,7 @@ func TestToolsListRequiresInitializedSession(t *testing.T) { }`)) req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) req.Header.Set(SessionHeader, "s1") w := httptest.NewRecorder() @@ -65,6 +66,7 @@ func TestToolsListSuccess(t *testing.T) { }`)) req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) req.Header.Set(SessionHeader, "s2") w := httptest.NewRecorder() @@ -120,6 +122,7 @@ func TestToolsCallUnknownTool(t *testing.T) { }`)) req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) req.Header.Set(SessionHeader, "s3") w := httptest.NewRecorder() @@ -154,6 +157,7 @@ func TestToolsCallInvalidParams(t *testing.T) { }`)) req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) req.Header.Set(SessionHeader, "s4") w := httptest.NewRecorder() diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 98df91491..1ed8ae926 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -21,11 +21,12 @@ import ( ) const ( - ProtocolVersion = "2025-06-18" - SessionHeader = "Mcp-Session-Id" - sessionTTL = 30 * time.Minute - maxSessions = 128 - maxUserSessions = 16 + ProtocolVersion = "2025-06-18" + ProtocolVersionHeader = "MCP-Protocol-Version" + SessionHeader = "Mcp-Session-Id" + sessionTTL = 30 * time.Minute + maxSessions = 128 + maxUserSessions = 16 ) type session struct { @@ -132,6 +133,14 @@ func (s *Server) handlePost(c *gin.Context) { s.handleInitialize(c, req) return } + if c.GetHeader(ProtocolVersionHeader) != ProtocolVersion { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32000, Message: "missing or unsupported MCP protocol version"}, + }) + return + } sessionID := c.GetHeader(SessionHeader) currentSession, ok := s.getSession(sessionID) diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index 70bbf22b3..5439f5b87 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -217,6 +217,73 @@ func TestPostInvalidAcceptReturnsJSONRPCError(t *testing.T) { } } +func TestPostRequiresProtocolVersionAfterInitialize(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRejectsUnsupportedProtocolVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1, initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, "2025-03-26") + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + func TestDeleteRemovesSession(t *testing.T) { gin.SetMode(gin.TestMode) srv := newTestServer(nil) From ba81f6be0d7e7e0ca99d9f21443f2049d4999fef Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 20 May 2026 11:49:49 +0800 Subject: [PATCH 13/16] test(mcp): serialize setting cache mutation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为修改全局设置缓存的测试添加包级互斥锁 - 保留 ClearAll 清理逻辑,避免并发测试互相影响 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/link_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/server/mcp/link_test.go b/server/mcp/link_test.go index 66c820ef6..10896d80c 100644 --- a/server/mcp/link_test.go +++ b/server/mcp/link_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net/http/httptest" + "sync" "testing" "github.com/OpenListTeam/OpenList/v4/internal/conf" @@ -13,6 +14,8 @@ import ( "github.com/gin-gonic/gin" ) +var settingCacheMu sync.Mutex + func TestParseFSLinkArgs(t *testing.T) { args, err := parseFSLinkArgs(json.RawMessage(`{"path":"/file.txt","password":"pw","type":"thumb"}`)) if err != nil { @@ -64,7 +67,11 @@ func TestCanProxyFileIgnoresWebdavProxyURLPolicy(t *testing.T) { func TestBuildFSLinkInfoUsesProxyWhenStorageRequiresProxy(t *testing.T) { gin.SetMode(gin.TestMode) - t.Cleanup(op.Cache.ClearAll) + settingCacheMu.Lock() + t.Cleanup(func() { + op.Cache.ClearAll() + settingCacheMu.Unlock() + }) op.Cache.SetSetting(conf.LinkExpiration, &model.SettingItem{Key: conf.LinkExpiration, Value: "0"}) op.Cache.SetSetting(conf.Token, &model.SettingItem{Key: conf.Token, Value: "test-token"}) c, _ := gin.CreateTestContext(httptest.NewRecorder()) From f9b47b3b4f572614076400f96a55959fbd2c949b Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 27 May 2026 09:33:09 +0800 Subject: [PATCH 14/16] feat(mcp): add config switch - Add MCP config section with disabled default - Register MCP routes only when enabled Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- internal/conf/config.go | 8 ++++++++ server/mcp.go | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/internal/conf/config.go b/internal/conf/config.go index 4a1172815..0f6676116 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -106,6 +106,10 @@ type SFTP struct { Listen string `json:"listen" env:"LISTEN"` } +type MCP struct { + Enable bool `json:"enable" env:"ENABLE"` +} + type Config struct { Force bool `json:"force" env:"FORCE"` SiteURL string `json:"site_url" env:"SITE_URL"` @@ -131,6 +135,7 @@ type Config struct { S3 S3 `json:"s3" envPrefix:"S3_"` FTP FTP `json:"ftp" envPrefix:"FTP_"` SFTP SFTP `json:"sftp" envPrefix:"SFTP_"` + MCP MCP `json:"mcp" envPrefix:"MCP_"` LastLaunchedVersion string `json:"last_launched_version"` ProxyAddress string `json:"proxy_address" env:"PROXY_ADDRESS"` } @@ -244,6 +249,9 @@ func DefaultConfig(dataDir string) *Config { Enable: false, Listen: ":5222", }, + MCP: MCP{ + Enable: false, + }, LastLaunchedVersion: "", ProxyAddress: "", } diff --git a/server/mcp.go b/server/mcp.go index 9d85c02cc..c036ce389 100644 --- a/server/mcp.go +++ b/server/mcp.go @@ -1,10 +1,18 @@ package server import ( + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/OpenList/v4/server/mcp" "github.com/gin-gonic/gin" ) func MCP(g *gin.RouterGroup) { + if !conf.Conf.MCP.Enable { + g.Any("/mcp", func(c *gin.Context) { + common.ErrorStrResp(c, "MCP server is not enabled", 403) + }) + return + } mcp.Register(g) } From eb40fad3bfcda80985f8b9f4d516254e18fd0e5c Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 27 May 2026 10:16:57 +0800 Subject: [PATCH 15/16] fix(mcp): handle missing session explicitly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 区分缺失 MCP session 与未知 MCP session - 为未知 session 返回 not found 错误 - 添加 MCP session 错误处理测试 Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 13 ++++++-- server/mcp/handler_test.go | 63 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 1ed8ae926..986743e0e 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -143,12 +143,21 @@ func (s *Server) handlePost(c *gin.Context) { } sessionID := c.GetHeader(SessionHeader) + if sessionID == "" { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32000, Message: "missing MCP session"}, + }) + return + } + currentSession, ok := s.getSession(sessionID) if !ok { - c.JSON(http.StatusBadRequest, response{ + c.JSON(http.StatusNotFound, response{ JSONRPC: "2.0", ID: req.ID, - Error: &rpcError{Code: -32000, Message: "missing or invalid MCP session"}, + Error: &rpcError{Code: -32001, Message: "session not found"}, }) return } diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index 5439f5b87..dbc6078e7 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -284,6 +284,69 @@ func TestPostRejectsUnsupportedProtocolVersion(t *testing.T) { } } +func TestPostRejectsMissingSession(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + +func TestPostRejectsUnknownSessionWithNotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "unknown") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusNotFound) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32001 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + func TestDeleteRemovesSession(t *testing.T) { gin.SetMode(gin.TestMode) srv := newTestServer(nil) From 0753d8ffc1d0513d969de35fb442f711697e04b8 Mon Sep 17 00:00:00 2001 From: jyxjjj <773933146@qq.com> Date: Wed, 27 May 2026 10:26:11 +0800 Subject: [PATCH 16/16] feat(mcp): support latest protocol negotiation - Upgrade default MCP protocol version to 2025-11-25 - Preserve 2025-06-18 compatibility through initialize negotiation - Validate subsequent request protocol version against the negotiated session version - Add tests for latest and older protocol negotiation paths Co-authored-by: Codex <267193182+codex@users.noreply.github.com> Signed-off-by: jyxjjj <16695261+jyxjjj@users.noreply.github.com> --- server/mcp/handler.go | 77 ++++++++++++++++++++++----------- server/mcp/handler_test.go | 88 +++++++++++++++++++++++++++++++++++++- 2 files changed, 139 insertions(+), 26 deletions(-) diff --git a/server/mcp/handler.go b/server/mcp/handler.go index 986743e0e..ac08747a7 100644 --- a/server/mcp/handler.go +++ b/server/mcp/handler.go @@ -21,20 +21,21 @@ import ( ) const ( - ProtocolVersion = "2025-06-18" + ProtocolVersion = "2025-11-25" ProtocolVersionHeader = "MCP-Protocol-Version" - SessionHeader = "Mcp-Session-Id" + SessionHeader = "MCP-Session-Id" sessionTTL = 30 * time.Minute maxSessions = 128 maxUserSessions = 16 ) type session struct { - id string - userID uint - initialized bool - createdAt time.Time - lastUsedAt time.Time + id string + userID uint + protocolVersion string + initialized bool + createdAt time.Time + lastUsedAt time.Time } type Server struct { @@ -71,6 +72,11 @@ var defaultServer = &Server{ sessions: map[string]*session{}, } +var supportedProtocolVersions = map[string]struct{}{ + "2025-11-25": {}, + "2025-06-18": {}, +} + func Register(g *gin.RouterGroup) { mcpGroup := g.Group("/mcp", middlewares.Auth(false), middlewares.AuthAdmin) mcpGroup.GET("", defaultServer.handleGet) @@ -133,15 +139,6 @@ func (s *Server) handlePost(c *gin.Context) { s.handleInitialize(c, req) return } - if c.GetHeader(ProtocolVersionHeader) != ProtocolVersion { - c.JSON(http.StatusBadRequest, response{ - JSONRPC: "2.0", - ID: req.ID, - Error: &rpcError{Code: -32000, Message: "missing or unsupported MCP protocol version"}, - }) - return - } - sessionID := c.GetHeader(SessionHeader) if sessionID == "" { c.JSON(http.StatusBadRequest, response{ @@ -172,6 +169,15 @@ func (s *Server) handlePost(c *gin.Context) { return } + if !s.validateRequestProtocolVersion(c.GetHeader(ProtocolVersionHeader), currentSession.protocolVersion) { + c.JSON(http.StatusBadRequest, response{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32000, Message: "missing or unsupported MCP protocol version"}, + }) + return + } + switch req.Method { case "ping": c.JSON(http.StatusOK, response{JSONRPC: "2.0", ID: req.ID, Result: map[string]any{}}) @@ -221,16 +227,18 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { } } + protocolVersion := negotiateProtocolVersion(params.ProtocolVersion) currentSession := s.initializeSession( c.Request.Context().Value(conf.UserKey).(*model.User).ID, c.GetHeader(SessionHeader), + protocolVersion, ) c.Header(SessionHeader, currentSession.id) c.JSON(http.StatusOK, response{ JSONRPC: "2.0", ID: req.ID, Result: map[string]any{ - "protocolVersion": ProtocolVersion, + "protocolVersion": protocolVersion, "capabilities": map[string]any{ "tools": map[string]any{ "listChanged": false, @@ -245,7 +253,7 @@ func (s *Server) handleInitialize(c *gin.Context, req request) { }) } -func (s *Server) initializeSession(userID uint, requestedID string) *session { +func (s *Server) initializeSession(userID uint, requestedID string, protocolVersion string) *session { s.mu.Lock() defer s.mu.Unlock() @@ -255,12 +263,13 @@ func (s *Server) initializeSession(userID uint, requestedID string) *session { currentSession, ok := s.sessions[requestedID] if ok && currentSession != nil && currentSession.userID == userID { currentSession.initialized = false + currentSession.protocolVersion = protocolVersion currentSession.lastUsedAt = now return currentSession } } - return s.createSessionLocked(userID, now) + return s.createSessionLocked(userID, protocolVersion, now) } func (s *Server) handleDelete(c *gin.Context) { @@ -291,22 +300,40 @@ func (s *Server) createSession(userID uint) *session { now := time.Now() s.pruneExpiredSessionsLocked(now) - return s.createSessionLocked(userID, now) + return s.createSessionLocked(userID, ProtocolVersion, now) } -func (s *Server) createSessionLocked(userID uint, now time.Time) *session { +func (s *Server) createSessionLocked(userID uint, protocolVersion string, now time.Time) *session { s.pruneLeastRecentlyUsedUserSessionsLocked(userID, max(0, s.countUserSessionsLocked(userID)-maxUserSessions+1)) s.pruneLeastRecentlyUsedSessionsLocked(max(0, len(s.sessions)-maxSessions+1)) currentSession := &session{ - id: random.Token(), - userID: userID, - createdAt: now, - lastUsedAt: now, + id: random.Token(), + userID: userID, + protocolVersion: protocolVersion, + createdAt: now, + lastUsedAt: now, } s.sessions[currentSession.id] = currentSession return currentSession } +func (s *Server) validateRequestProtocolVersion(requestedVersion string, negotiatedVersion string) bool { + if requestedVersion == "" { + return false + } + if _, ok := supportedProtocolVersions[requestedVersion]; !ok { + return false + } + return negotiatedVersion == "" || requestedVersion == negotiatedVersion +} + +func negotiateProtocolVersion(requestedVersion string) string { + if _, ok := supportedProtocolVersions[requestedVersion]; ok { + return requestedVersion + } + return ProtocolVersion +} + func (s *Server) getSession(id string) (session, bool) { if id == "" { return session{}, false diff --git a/server/mcp/handler_test.go b/server/mcp/handler_test.go index dbc6078e7..ac7a79158 100644 --- a/server/mcp/handler_test.go +++ b/server/mcp/handler_test.go @@ -29,7 +29,7 @@ func TestInitializeCreatesSession(t *testing.T) { "id":1, "method":"initialize", "params":{ - "protocolVersion":"2025-06-18", + "protocolVersion":"2025-11-25", "capabilities":{}, "clientInfo":{"name":"test-client","version":"1.0.0"} } @@ -54,6 +54,55 @@ func TestInitializeCreatesSession(t *testing.T) { if resp.Error != nil { t.Fatalf("unexpected error response: %+v", resp.Error) } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != ProtocolVersion { + t.Fatalf("unexpected protocol version: %v", result["protocolVersion"]) + } +} + +func TestInitializeNegotiatesSupportedOlderProtocolVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(nil) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"initialize", + "params":{ + "protocolVersion":"2025-06-18", + "capabilities":{}, + "clientInfo":{"name":"test-client","version":"1.0.0"} + } + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusOK) + } + resp := decodeResponse(t, w) + if resp.Error != nil { + t.Fatalf("unexpected error response: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]any) + if !ok { + t.Fatalf("unexpected result type: %T", resp.Result) + } + if result["protocolVersion"] != "2025-06-18" { + t.Fatalf("unexpected protocol version: %v", result["protocolVersion"]) + } } func TestInitializeReusesExistingSession(t *testing.T) { @@ -101,6 +150,9 @@ func TestInitializeReusesExistingSession(t *testing.T) { if reusedSession.initialized { t.Fatal("expected reused session to require initialized notification again") } + if reusedSession.protocolVersion != "2025-06-18" { + t.Fatalf("unexpected protocol version: got %q want %q", reusedSession.protocolVersion, "2025-06-18") + } } func TestInitializeNegotiatesUnsupportedProtocolVersion(t *testing.T) { @@ -284,6 +336,40 @@ func TestPostRejectsUnsupportedProtocolVersion(t *testing.T) { } } +func TestPostRejectsProtocolVersionMismatch(t *testing.T) { + gin.SetMode(gin.TestMode) + srv := newTestServer(map[string]*session{ + "s1": {id: "s1", userID: 1, protocolVersion: "2025-06-18", initialized: true}, + }) + + r := gin.New() + r.POST("/mcp", func(c *gin.Context) { + common.GinAppendValues(c, conf.UserKey, &model.User{ID: 1, Role: model.ADMIN}) + srv.handlePost(c) + }) + + req := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", strings.NewReader(`{ + "jsonrpc":"2.0", + "id":1, + "method":"tools/list" + }`)) + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Origin", "http://example.com") + req.Header.Set(ProtocolVersionHeader, ProtocolVersion) + req.Header.Set(SessionHeader, "s1") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d want %d", w.Code, http.StatusBadRequest) + } + resp := decodeResponse(t, w) + if resp.Error == nil || resp.Error.Code != -32000 { + t.Fatalf("unexpected error response: %+v", resp.Error) + } +} + func TestPostRejectsMissingSession(t *testing.T) { gin.SetMode(gin.TestMode) srv := newTestServer(nil)