diff --git a/tavern/internal/mcp/export_test.go b/tavern/internal/mcp/export_test.go new file mode 100644 index 000000000..05bc769db --- /dev/null +++ b/tavern/internal/mcp/export_test.go @@ -0,0 +1,28 @@ +package mcp + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Export for testing +var ( + HandleListHosts = handleListHosts + HandleListQuests = handleListQuests + HandleListTomes = handleListTomes + HandleQuestOutput = handleQuestOutput + HandleWaitForQuest = handleWaitForQuest + ClientFromContext = clientFromContext +) + +// Export internal context keys +type TestContextKey = contextKey +type TestGraphqlHandlerKey = graphqlHandlerKey + +// HandleCreateQuestForTest exposes handleCreateQuest which requires the server instance. +func HandleCreateQuestForTest(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // The real handler expects a server instance, but in unit tests we might + // pass a dummy or nil if the internal logic allows it. + return handleCreateQuest(nil)(ctx, req) +} diff --git a/tavern/internal/mcp/mcp_test.go b/tavern/internal/mcp/mcp_test.go index 15ca1ac34..916703c46 100644 --- a/tavern/internal/mcp/mcp_test.go +++ b/tavern/internal/mcp/mcp_test.go @@ -2,6 +2,7 @@ package mcp_test import ( "context" + "fmt" "net/http" "net/http/httptest" "testing" @@ -503,3 +504,548 @@ func TestGraphQLQueryValidation(t *testing.T) { }) } } + +func TestHandleListHosts(t *testing.T) { + ctx := context.Background() + client := setupTestDB(t) + defer client.Close() + + // Create test data + host := client.Host.Create(). + SetIdentifier("test-host-id-123"). + SetName("test-host"). + SetPlatform(c2pb.Host_PLATFORM_LINUX). + SetPrimaryIP("192.168.1.1"). + SaveX(ctx) + + client.Beacon.Create(). + SetHost(host). + SetName("test-beacon"). + SetTransport(c2pb.Transport_TRANSPORT_UNSPECIFIED). + SetPrincipal("root"). + SaveX(ctx) + + client.Tag.Create(). + SetName("test-tag"). + SetKind("group"). + AddHosts(host). + SaveX(ctx) + + // Setup context with client + testCtx := context.WithValue(ctx, tavernmcp.TestContextKey{}, client) + + req := mcp.CallToolRequest{} + + res, err := tavernmcp.HandleListHosts(testCtx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + require.Len(t, res.Content, 1) + + textContent, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + + // Verify the JSON string + assert.Contains(t, textContent.Text, "test-host-id-123") + assert.Contains(t, textContent.Text, "test-host") + assert.Contains(t, textContent.Text, "PLATFORM_LINUX") + assert.Contains(t, textContent.Text, "test-beacon") + assert.Contains(t, textContent.Text, "root") + assert.Contains(t, textContent.Text, "test-tag") + assert.Contains(t, textContent.Text, "group") + + // Test nil client error + resErr, err := tavernmcp.HandleListHosts(context.Background(), req) + require.NoError(t, err) + assert.True(t, resErr.IsError) + textContentErr, ok := resErr.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentErr.Text, "internal error: no database client") +} + +func TestHandleListQuests(t *testing.T) { + ctx := context.Background() + client := setupTestDB(t) + defer client.Close() + + // Create test data + testTome := client.Tome.Create(). + SetName("test-tome"). + SetDescription("A test tome"). + SetAuthor("test-author"). + SetSupportModel(tome.SupportModelCOMMUNITY). + SetEldritch("print('hello')"). + SetHash("abc123"). + SetParamDefs(`[{"name":"key","type":"string"}]`). + SaveX(ctx) + + client.Quest.Create(). + SetName("test-quest"). + SetParameters(`{"key":"value"}`). + SetTomeID(testTome.ID). + SetParamDefsAtCreation(testTome.ParamDefs). + SetEldritchAtCreation(testTome.Eldritch). + SaveX(ctx) + + // Setup context with client + testCtx := context.WithValue(ctx, tavernmcp.TestContextKey{}, client) + + req := mcp.CallToolRequest{} + + res, err := tavernmcp.HandleListQuests(testCtx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + require.Len(t, res.Content, 1) + + textContent, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + + // Verify the JSON string + assert.Contains(t, textContent.Text, "test-quest") + assert.Contains(t, textContent.Text, "test-tome") + assert.Contains(t, textContent.Text, "A test tome") + + // Test nil client error + resErr, err := tavernmcp.HandleListQuests(context.Background(), req) + require.NoError(t, err) + assert.True(t, resErr.IsError) + textContentErr, ok := resErr.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentErr.Text, "internal error: no database client") +} + +func TestHandleListTomes(t *testing.T) { + ctx := context.Background() + client := setupTestDB(t) + defer client.Close() + + // Create test data + client.Tome.Create(). + SetName("tome-handler-test"). + SetDescription("First tome for handler"). + SetAuthor("test-author"). + SetSupportModel(tome.SupportModelCOMMUNITY). + SetEldritch("print('one')"). + SetHash("hash1"). + SetParamDefs(`[{"name":"param1","type":"string"}]`). + SaveX(ctx) + + // Setup context with client + testCtx := context.WithValue(ctx, tavernmcp.TestContextKey{}, client) + + req := mcp.CallToolRequest{} + + res, err := tavernmcp.HandleListTomes(testCtx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + require.Len(t, res.Content, 1) + + textContent, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + + // Verify the JSON string + assert.Contains(t, textContent.Text, "tome-handler-test") + assert.Contains(t, textContent.Text, "First tome for handler") + assert.Contains(t, textContent.Text, "param1") + + // Test nil client error + resErr, err := tavernmcp.HandleListTomes(context.Background(), req) + require.NoError(t, err) + assert.True(t, resErr.IsError) + textContentErr, ok := resErr.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentErr.Text, "internal error: no database client") +} + +func TestHandleQuestOutput(t *testing.T) { + ctx := context.Background() + client := setupTestDB(t) + defer client.Close() + + // Create test data + testTome := client.Tome.Create(). + SetName("test-tome"). + SetDescription("A test tome"). + SetAuthor("test-author"). + SetSupportModel(tome.SupportModelCOMMUNITY). + SetEldritch("print('hello')"). + SetHash("abc123"). + SaveX(ctx) + + testHost := client.Host.Create(). + SetIdentifier("test-host-id"). + SetName("test-host"). + SetPlatform(c2pb.Host_PLATFORM_UNSPECIFIED). + SaveX(ctx) + + testBeacon := client.Beacon.Create(). + SetHost(testHost). + SetName("test-beacon"). + SetTransport(c2pb.Transport_TRANSPORT_UNSPECIFIED). + SaveX(ctx) + + q := client.Quest.Create(). + SetName("output-quest"). + SetParameters("{}"). + SetTome(testTome). + SetParamDefsAtCreation("[]"). + SetEldritchAtCreation("print('hello')"). + SaveX(ctx) + + client.Task.Create(). + SetQuest(q). + SetBeacon(testBeacon). + SetOutput("task output result"). + SetExecFinishedAt(time.Now()). + SaveX(ctx) + + // Setup context with client + testCtx := context.WithValue(ctx, tavernmcp.TestContextKey{}, client) + + // Valid request + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "ids": []interface{}{float64(q.ID)}, + }, + }, + } + + res, err := tavernmcp.HandleQuestOutput(testCtx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + require.Len(t, res.Content, 1) + + textContent, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + + // Verify the JSON string + assert.Contains(t, textContent.Text, "output-quest") + assert.Contains(t, textContent.Text, "task output result") + assert.Contains(t, textContent.Text, "test-beacon") + assert.Contains(t, textContent.Text, "test-host") + + // Test nil client error + resErr, err := tavernmcp.HandleQuestOutput(context.Background(), req) + require.NoError(t, err) + assert.True(t, resErr.IsError) + textContentErr, ok := resErr.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentErr.Text, "internal error: no database client") + + // Test invalid IDs + reqInvalid := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "ids": "not-an-array", + }, + }, + } + resInvalid, err := tavernmcp.HandleQuestOutput(testCtx, reqInvalid) + require.NoError(t, err) + assert.True(t, resInvalid.IsError) + textContentInvalid, ok := resInvalid.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentInvalid.Text, "invalid ids") +} + +func TestHandleWaitForQuest(t *testing.T) { + ctx := context.Background() + client := setupTestDB(t) + defer client.Close() + + // Create test data + testHost := client.Host.Create(). + SetIdentifier("test-host-id"). + SetName("test-host"). + SetPlatform(c2pb.Host_PLATFORM_UNSPECIFIED). + SaveX(ctx) + + testBeacon := client.Beacon.Create(). + SetHost(testHost). + SetName("test-beacon"). + SetTransport(c2pb.Transport_TRANSPORT_UNSPECIFIED). + SaveX(ctx) + + testTome := client.Tome.Create(). + SetName("test-tome"). + SetDescription("A test tome"). + SetAuthor("test-author"). + SetSupportModel(tome.SupportModelCOMMUNITY). + SetEldritch("print('hello')"). + SetHash("abc123"). + SaveX(ctx) + + q := client.Quest.Create(). + SetName("wait-quest"). + SetParameters("{}"). + SetTome(testTome). + SetParamDefsAtCreation("[]"). + SetEldritchAtCreation("print('hello')"). + SaveX(ctx) + + client.Task.Create(). + SetQuest(q). + SetBeacon(testBeacon). + SetExecFinishedAt(time.Now()). + SaveX(ctx) + + // Setup context with client + testCtx := context.WithValue(ctx, tavernmcp.TestContextKey{}, client) + + // Valid request + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "quest_id": fmt.Sprintf("%d", q.ID), + }, + }, + } + + res, err := tavernmcp.HandleWaitForQuest(testCtx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + require.Len(t, res.Content, 1) + + textContent, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + + // Verify the JSON string + assert.Contains(t, textContent.Text, "have finished") + assert.Contains(t, textContent.Text, "test-beacon") + + // Test nil client error + resErr, err := tavernmcp.HandleWaitForQuest(context.Background(), req) + require.NoError(t, err) + assert.True(t, resErr.IsError) + textContentErr, ok := resErr.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentErr.Text, "internal error: no database client") + + // Test invalid request (missing quest_id) + reqMissing := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{}, + }, + } + resMissing, err := tavernmcp.HandleWaitForQuest(testCtx, reqMissing) + require.NoError(t, err) + assert.True(t, resMissing.IsError) + + // Test invalid request (quest_id not a number) + reqInvalidStr := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "quest_id": "not-a-number", + }, + }, + } + resInvalidStr, err := tavernmcp.HandleWaitForQuest(testCtx, reqInvalidStr) + require.NoError(t, err) + assert.True(t, resInvalidStr.IsError) + + // Test quest not found + reqNotFound := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "quest_id": "99999", + }, + }, + } + resNotFound, err := tavernmcp.HandleWaitForQuest(testCtx, reqNotFound) + require.NoError(t, err) + assert.True(t, resNotFound.IsError) + + // Test quest with no tasks + qNoTasks := client.Quest.Create(). + SetName("no-tasks-quest"). + SetParameters("{}"). + SetTome(testTome). + SetParamDefsAtCreation("[]"). + SetEldritchAtCreation("print('hello')"). + SaveX(ctx) + + reqNoTasks := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "quest_id": fmt.Sprintf("%d", qNoTasks.ID), + }, + }, + } + resNoTasks, err := tavernmcp.HandleWaitForQuest(testCtx, reqNoTasks) + require.NoError(t, err) + assert.True(t, resNoTasks.IsError) + textContentNoTasks, ok := resNoTasks.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentNoTasks.Text, "has no tasks") +} + +func TestHandleCreateQuest(t *testing.T) { + ctx := context.Background() + client := setupTestDB(t) + defer client.Close() + + // Create test data + testHost := client.Host.Create(). + SetIdentifier("test-host-id"). + SetName("test-host"). + SetPlatform(c2pb.Host_PLATFORM_UNSPECIFIED). + SaveX(ctx) + + testBeacon := client.Beacon.Create(). + SetHost(testHost). + SetName("test-beacon"). + SetTransport(c2pb.Transport_TRANSPORT_UNSPECIFIED). + SaveX(ctx) + + testTome := client.Tome.Create(). + SetName("test-tome"). + SetDescription("A test tome"). + SetAuthor("test-author"). + SetSupportModel(tome.SupportModelCOMMUNITY). + SetEldritch("print('hello')"). + SetHash("abc123"). + SaveX(ctx) + + // Setup context with client + testCtx := context.WithValue(ctx, tavernmcp.TestContextKey{}, client) + + // Valid request + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest", + "beacon_ids": []interface{}{fmt.Sprintf("%d", testBeacon.ID)}, + "parameters": `{}`, + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } + + // Use HandleCreateQuestForTest which allows passing nil for MCPServer + res, err := tavernmcp.HandleCreateQuestForTest(testCtx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + require.Len(t, res.Content, 1) + + textContent, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + + // Verify the JSON string + assert.Contains(t, textContent.Text, "new-test-quest") + + // Verify quest was created in db + quests, err := client.Quest.Query().WithTasks().All(ctx) + require.NoError(t, err) + assert.Len(t, quests, 1) + assert.Equal(t, "new-test-quest", quests[0].Name) + assert.Len(t, quests[0].Edges.Tasks, 1) + + // Test nil client error + resErr, err := tavernmcp.HandleCreateQuestForTest(context.Background(), req) + require.NoError(t, err) + assert.True(t, resErr.IsError) + textContentErr, ok := resErr.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, textContentErr.Text, "internal error: no database client") + + // Test missing name + reqNoName := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "beacon_ids": []interface{}{fmt.Sprintf("%d", testBeacon.ID)}, + "parameters": `{"test": true}`, + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } + resNoName, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqNoName) + require.NoError(t, err) + assert.True(t, resNoName.IsError) + + // Test missing beacon_ids + reqNoBeacons := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest", + "parameters": `{"test": true}`, + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } + resNoBeacons, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqNoBeacons) + require.NoError(t, err) + assert.True(t, resNoBeacons.IsError) + + // Test empty beacon_ids + reqEmptyBeacons := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest", + "beacon_ids": []interface{}{}, + "parameters": `{"test": true}`, + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } + resEmptyBeacons, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqEmptyBeacons) + require.NoError(t, err) + assert.True(t, resEmptyBeacons.IsError) + + // Test missing parameters (should use {}) + reqNoParams := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest-noparams", + "beacon_ids": []interface{}{fmt.Sprintf("%d", testBeacon.ID)}, + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } + resNoParams, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqNoParams) + require.NoError(t, err) + assert.True(t, resNoParams.IsError) // Wait, actually RequireString will return an error if missing + + // Test empty parameters + reqEmptyParams := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest-emptyparams", + "beacon_ids": []interface{}{fmt.Sprintf("%d", testBeacon.ID)}, + "parameters": "", + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } + resEmptyParams, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqEmptyParams) + require.NoError(t, err) + assert.False(t, resEmptyParams.IsError) // Should succeed + + // Test missing tome_id + reqNoTome := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest", + "beacon_ids": []interface{}{fmt.Sprintf("%d", testBeacon.ID)}, + "parameters": `{"test": true}`, + }, + }, + } + resNoTome, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqNoTome) + require.NoError(t, err) + assert.True(t, resNoTome.IsError) + + // Test tome not found + reqBadTome := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]interface{}{ + "name": "new-test-quest", + "beacon_ids": []interface{}{fmt.Sprintf("%d", testBeacon.ID)}, + "parameters": `{"test": true}`, + "tome_id": "999999", + }, + }, + } + resBadTome, err := tavernmcp.HandleCreateQuestForTest(testCtx, reqBadTome) + require.NoError(t, err) + assert.True(t, resBadTome.IsError) +}