diff --git a/tavern/internal/mcp/export_test.go b/tavern/internal/mcp/export_test.go new file mode 100644 index 000000000..2acff2415 --- /dev/null +++ b/tavern/internal/mcp/export_test.go @@ -0,0 +1,26 @@ +package mcp + +import ( + "context" + + mcpserver "github.com/mark3labs/mcp-go/server" +) + +// Export internal tool handlers for testing purposes without polluting the public API. +var ( + HandleListHosts = handleListHosts + HandleListQuests = handleListQuests + HandleListTomes = handleListTomes + HandleQuestOutput = handleQuestOutput + HandleWaitForQuest = handleWaitForQuest +) + +// HandleCreateQuestForTest is an exported wrapper that provides an MCPServer for the handler. +func HandleCreateQuestForTest(srv *mcpserver.MCPServer) mcpserver.ToolHandlerFunc { + return handleCreateQuest(srv) +} + +// ClientFromContextForTest exports the context key injection for testing. +func ClientFromContextForTest(ctx context.Context, client any) context.Context { + return context.WithValue(ctx, contextKey{}, client) +} diff --git a/tavern/internal/mcp/mcp_test.go b/tavern/internal/mcp/mcp_test.go index 15ca1ac34..fd4e85c23 100644 --- a/tavern/internal/mcp/mcp_test.go +++ b/tavern/internal/mcp/mcp_test.go @@ -2,12 +2,14 @@ package mcp_test import ( "context" + "fmt" "net/http" "net/http/httptest" "testing" "time" "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "realm.pub/tavern/internal/c2/c2pb" @@ -88,6 +90,7 @@ func TestListQuestsHandler(t *testing.T) { ctx := context.Background() client := setupTestDB(t) defer client.Close() + ctx = tavernmcp.ClientFromContextForTest(ctx, client) // Create a tome testTome := client.Tome.Create(). @@ -108,13 +111,17 @@ func TestListQuestsHandler(t *testing.T) { SetEldritchAtCreation("print('hello')"). SaveX(ctx) - // Verify quest was created - quests, err := client.Quest.Query().WithTome().All(ctx) + // Call the tool handler + req := mcp.CallToolRequest{} + res, err := tavernmcp.HandleListQuests(ctx, req) require.NoError(t, err) - assert.Len(t, quests, 1) - assert.Equal(t, "test-quest", quests[0].Name) - assert.NotNil(t, quests[0].Edges.Tome) - assert.Equal(t, "test-tome", quests[0].Edges.Tome.Name) + require.False(t, res.IsError) + require.Len(t, res.Content, 1) + + text, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, text.Text, "test-quest") + assert.Contains(t, text.Text, "test-tome") } // TestListHostsHandler tests the list_hosts tool by creating test data. @@ -122,6 +129,7 @@ func TestListHostsHandler(t *testing.T) { ctx := context.Background() client := setupTestDB(t) defer client.Close() + ctx = tavernmcp.ClientFromContextForTest(ctx, client) // Create a host host := client.Host.Create(). @@ -145,13 +153,17 @@ func TestListHostsHandler(t *testing.T) { AddHosts(host). SaveX(ctx) - // Verify data was created - hosts, err := client.Host.Query().WithBeacons().WithTags().All(ctx) + // Call the tool handler + req := mcp.CallToolRequest{} + res, err := tavernmcp.HandleListHosts(ctx, req) require.NoError(t, err) - assert.Len(t, hosts, 1) - assert.Equal(t, "test-host", hosts[0].Name) - assert.Len(t, hosts[0].Edges.Beacons, 1) - assert.Len(t, hosts[0].Edges.Tags, 1) + require.False(t, res.IsError) + require.Len(t, res.Content, 1) + + text, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, text.Text, "test-host") + assert.Contains(t, text.Text, "test-tag") } // TestListTomesHandler tests the list_tomes tool by creating test data. @@ -159,6 +171,7 @@ func TestListTomesHandler(t *testing.T) { ctx := context.Background() client := setupTestDB(t) defer client.Close() + ctx = tavernmcp.ClientFromContextForTest(ctx, client) // Create tomes client.Tome.Create(). @@ -180,9 +193,17 @@ func TestListTomesHandler(t *testing.T) { SetHash("hash2"). SaveX(ctx) - tomes, err := client.Tome.Query().All(ctx) + // Call the tool handler + req := mcp.CallToolRequest{} + res, err := tavernmcp.HandleListTomes(ctx, req) require.NoError(t, err) - assert.Len(t, tomes, 2) + require.False(t, res.IsError) + require.Len(t, res.Content, 1) + + text, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, text.Text, "tome-1") + assert.Contains(t, text.Text, "tome-2") } // TestCreateQuestHandler tests the create_quest tool by creating a quest. @@ -190,6 +211,7 @@ func TestCreateQuestHandler(t *testing.T) { ctx := context.Background() client := setupTestDB(t) defer client.Close() + ctx = tavernmcp.ClientFromContextForTest(ctx, client) // Create a tome testTome := client.Tome.Create(). @@ -213,22 +235,32 @@ func TestCreateQuestHandler(t *testing.T) { SetTransport(c2pb.Transport_TRANSPORT_UNSPECIFIED). SaveX(ctx) - // Create a quest directly to validate the pattern - q := client.Quest.Create(). - SetName("mcp-quest"). - SetParameters(`{"key":"value"}`). - SetTomeID(testTome.ID). - SetParamDefsAtCreation(testTome.ParamDefs). - SetEldritchAtCreation(testTome.Eldritch). - SaveX(ctx) + // Call the tool handler + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "name": "mcp-quest", + "beacon_ids": []any{fmt.Sprintf("%d", testBeacon.ID)}, + "parameters": `{"key":"value"}`, + "tome_id": fmt.Sprintf("%d", testTome.ID), + }, + }, + } - // Create a task for the beacon - client.Task.Create(). - SetQuestID(q.ID). - SetBeaconID(testBeacon.ID). - SaveX(ctx) + // Create an MCP server to provide elicitation (we won't actually get it since no active session, but it will fall through) + srv := mcpserver.NewMCPServer("test", "1.0") + + handler := tavernmcp.HandleCreateQuestForTest(srv) + res, err := handler(ctx, req) + require.NoError(t, err) + require.False(t, res.IsError) + require.Len(t, res.Content, 1) - // Verify quest and tasks + text, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, text.Text, "mcp-quest") + + // Verify quest and tasks in DB createdQuest, err := client.Quest.Query().WithTasks().All(ctx) require.NoError(t, err) assert.Len(t, createdQuest, 1) @@ -241,6 +273,7 @@ func TestQuestOutputHandler(t *testing.T) { ctx := context.Background() client := setupTestDB(t) defer client.Close() + ctx = tavernmcp.ClientFromContextForTest(ctx, client) // Create a tome testTome := client.Tome.Create(). @@ -281,18 +314,23 @@ func TestQuestOutputHandler(t *testing.T) { SetExecFinishedAt(time.Now()). SaveX(ctx) - // Verify output is queryable - quests, err := client.Quest.Query(). - WithTasks(func(tq *ent.TaskQuery) { - tq.WithBeacon(func(bq *ent.BeaconQuery) { - bq.WithHost() - }) - }). - All(ctx) + // Call the tool handler + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "ids": []any{fmt.Sprintf("%d", q.ID)}, + }, + }, + } + res, err := tavernmcp.HandleQuestOutput(ctx, req) require.NoError(t, err) - assert.Len(t, quests, 1) - assert.Len(t, quests[0].Edges.Tasks, 1) - assert.Equal(t, "task output result", quests[0].Edges.Tasks[0].Output) + require.False(t, res.IsError) + require.Len(t, res.Content, 1) + + text, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, text.Text, "output-quest") + assert.Contains(t, text.Text, "task output result") } // TestWaitForQuestHandler tests the wait_for_quest tool with already-finished tasks. @@ -300,6 +338,7 @@ func TestWaitForQuestHandler(t *testing.T) { ctx := context.Background() client := setupTestDB(t) defer client.Close() + ctx = tavernmcp.ClientFromContextForTest(ctx, client) // Create a tome testTome := client.Tome.Create(). @@ -339,15 +378,22 @@ func TestWaitForQuestHandler(t *testing.T) { SetExecFinishedAt(time.Now()). SaveX(ctx) - // Verify the quest tasks are finished - quest, err := client.Quest.Query(). - WithTasks(func(tq *ent.TaskQuery) { - tq.WithBeacon() - }). - Only(ctx) + // Call the tool handler + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "quest_id": fmt.Sprintf("%d", q.ID), + }, + }, + } + res, err := tavernmcp.HandleWaitForQuest(ctx, req) require.NoError(t, err) - assert.Len(t, quest.Edges.Tasks, 1) - assert.False(t, quest.Edges.Tasks[0].ExecFinishedAt.IsZero()) + require.False(t, res.IsError) + require.Len(t, res.Content, 1) + + text, ok := res.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Contains(t, text.Text, "finished-quest") } // TestParseIntIDs tests the ParseIntIDs helper.