Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tavern/internal/mcp/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package mcp

import (
"context"

mcpserver "github.com/mark3labs/mcp-go/server"
)

// Export internal tool handlers for testing purposes without polluting the public API.
var (
HandleListHosts = handleListHosts
HandleListQuests = handleListQuests
HandleListTomes = handleListTomes
HandleQuestOutput = handleQuestOutput
HandleWaitForQuest = handleWaitForQuest
)

// HandleCreateQuestForTest is an exported wrapper that provides an MCPServer for the handler.
func HandleCreateQuestForTest(srv *mcpserver.MCPServer) mcpserver.ToolHandlerFunc {
return handleCreateQuest(srv)
}

// ClientFromContextForTest exports the context key injection for testing.
func ClientFromContextForTest(ctx context.Context, client any) context.Context {
return context.WithValue(ctx, contextKey{}, client)
}
140 changes: 93 additions & 47 deletions tavern/internal/mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package mcp_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"realm.pub/tavern/internal/c2/c2pb"
Expand Down Expand Up @@ -88,6 +90,7 @@ func TestListQuestsHandler(t *testing.T) {
ctx := context.Background()
client := setupTestDB(t)
defer client.Close()
ctx = tavernmcp.ClientFromContextForTest(ctx, client)

// Create a tome
testTome := client.Tome.Create().
Expand All @@ -108,20 +111,25 @@ func TestListQuestsHandler(t *testing.T) {
SetEldritchAtCreation("print('hello')").
SaveX(ctx)

// Verify quest was created
quests, err := client.Quest.Query().WithTome().All(ctx)
// Call the tool handler
req := mcp.CallToolRequest{}
res, err := tavernmcp.HandleListQuests(ctx, req)
require.NoError(t, err)
assert.Len(t, quests, 1)
assert.Equal(t, "test-quest", quests[0].Name)
assert.NotNil(t, quests[0].Edges.Tome)
assert.Equal(t, "test-tome", quests[0].Edges.Tome.Name)
require.False(t, res.IsError)
require.Len(t, res.Content, 1)

text, ok := res.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Contains(t, text.Text, "test-quest")
assert.Contains(t, text.Text, "test-tome")
}

// TestListHostsHandler tests the list_hosts tool by creating test data.
func TestListHostsHandler(t *testing.T) {
ctx := context.Background()
client := setupTestDB(t)
defer client.Close()
ctx = tavernmcp.ClientFromContextForTest(ctx, client)

// Create a host
host := client.Host.Create().
Expand All @@ -145,20 +153,25 @@ func TestListHostsHandler(t *testing.T) {
AddHosts(host).
SaveX(ctx)

// Verify data was created
hosts, err := client.Host.Query().WithBeacons().WithTags().All(ctx)
// Call the tool handler
req := mcp.CallToolRequest{}
res, err := tavernmcp.HandleListHosts(ctx, req)
require.NoError(t, err)
assert.Len(t, hosts, 1)
assert.Equal(t, "test-host", hosts[0].Name)
assert.Len(t, hosts[0].Edges.Beacons, 1)
assert.Len(t, hosts[0].Edges.Tags, 1)
require.False(t, res.IsError)
require.Len(t, res.Content, 1)

text, ok := res.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Contains(t, text.Text, "test-host")
assert.Contains(t, text.Text, "test-tag")
}

// TestListTomesHandler tests the list_tomes tool by creating test data.
func TestListTomesHandler(t *testing.T) {
ctx := context.Background()
client := setupTestDB(t)
defer client.Close()
ctx = tavernmcp.ClientFromContextForTest(ctx, client)

// Create tomes
client.Tome.Create().
Expand All @@ -180,16 +193,25 @@ func TestListTomesHandler(t *testing.T) {
SetHash("hash2").
SaveX(ctx)

tomes, err := client.Tome.Query().All(ctx)
// Call the tool handler
req := mcp.CallToolRequest{}
res, err := tavernmcp.HandleListTomes(ctx, req)
require.NoError(t, err)
assert.Len(t, tomes, 2)
require.False(t, res.IsError)
require.Len(t, res.Content, 1)

text, ok := res.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Contains(t, text.Text, "tome-1")
assert.Contains(t, text.Text, "tome-2")
}

// TestCreateQuestHandler tests the create_quest tool by creating a quest.
func TestCreateQuestHandler(t *testing.T) {
ctx := context.Background()
client := setupTestDB(t)
defer client.Close()
ctx = tavernmcp.ClientFromContextForTest(ctx, client)

// Create a tome
testTome := client.Tome.Create().
Expand All @@ -213,22 +235,32 @@ func TestCreateQuestHandler(t *testing.T) {
SetTransport(c2pb.Transport_TRANSPORT_UNSPECIFIED).
SaveX(ctx)

// Create a quest directly to validate the pattern
q := client.Quest.Create().
SetName("mcp-quest").
SetParameters(`{"key":"value"}`).
SetTomeID(testTome.ID).
SetParamDefsAtCreation(testTome.ParamDefs).
SetEldritchAtCreation(testTome.Eldritch).
SaveX(ctx)
// Call the tool handler
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Arguments: map[string]any{
"name": "mcp-quest",
"beacon_ids": []any{fmt.Sprintf("%d", testBeacon.ID)},
"parameters": `{"key":"value"}`,
"tome_id": fmt.Sprintf("%d", testTome.ID),
},
},
}

// Create a task for the beacon
client.Task.Create().
SetQuestID(q.ID).
SetBeaconID(testBeacon.ID).
SaveX(ctx)
// Create an MCP server to provide elicitation (we won't actually get it since no active session, but it will fall through)
srv := mcpserver.NewMCPServer("test", "1.0")

handler := tavernmcp.HandleCreateQuestForTest(srv)
res, err := handler(ctx, req)
require.NoError(t, err)
require.False(t, res.IsError)
require.Len(t, res.Content, 1)

// Verify quest and tasks
text, ok := res.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Contains(t, text.Text, "mcp-quest")

// Verify quest and tasks in DB
createdQuest, err := client.Quest.Query().WithTasks().All(ctx)
require.NoError(t, err)
assert.Len(t, createdQuest, 1)
Expand All @@ -241,6 +273,7 @@ func TestQuestOutputHandler(t *testing.T) {
ctx := context.Background()
client := setupTestDB(t)
defer client.Close()
ctx = tavernmcp.ClientFromContextForTest(ctx, client)

// Create a tome
testTome := client.Tome.Create().
Expand Down Expand Up @@ -281,25 +314,31 @@ func TestQuestOutputHandler(t *testing.T) {
SetExecFinishedAt(time.Now()).
SaveX(ctx)

// Verify output is queryable
quests, err := client.Quest.Query().
WithTasks(func(tq *ent.TaskQuery) {
tq.WithBeacon(func(bq *ent.BeaconQuery) {
bq.WithHost()
})
}).
All(ctx)
// Call the tool handler
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Arguments: map[string]any{
"ids": []any{fmt.Sprintf("%d", q.ID)},
},
},
}
res, err := tavernmcp.HandleQuestOutput(ctx, req)
require.NoError(t, err)
assert.Len(t, quests, 1)
assert.Len(t, quests[0].Edges.Tasks, 1)
assert.Equal(t, "task output result", quests[0].Edges.Tasks[0].Output)
require.False(t, res.IsError)
require.Len(t, res.Content, 1)

text, ok := res.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Contains(t, text.Text, "output-quest")
assert.Contains(t, text.Text, "task output result")
}

// TestWaitForQuestHandler tests the wait_for_quest tool with already-finished tasks.
func TestWaitForQuestHandler(t *testing.T) {
ctx := context.Background()
client := setupTestDB(t)
defer client.Close()
ctx = tavernmcp.ClientFromContextForTest(ctx, client)

// Create a tome
testTome := client.Tome.Create().
Expand Down Expand Up @@ -339,15 +378,22 @@ func TestWaitForQuestHandler(t *testing.T) {
SetExecFinishedAt(time.Now()).
SaveX(ctx)

// Verify the quest tasks are finished
quest, err := client.Quest.Query().
WithTasks(func(tq *ent.TaskQuery) {
tq.WithBeacon()
}).
Only(ctx)
// Call the tool handler
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Arguments: map[string]any{
"quest_id": fmt.Sprintf("%d", q.ID),
},
},
}
res, err := tavernmcp.HandleWaitForQuest(ctx, req)
require.NoError(t, err)
assert.Len(t, quest.Edges.Tasks, 1)
assert.False(t, quest.Edges.Tasks[0].ExecFinishedAt.IsZero())
require.False(t, res.IsError)
require.Len(t, res.Content, 1)

text, ok := res.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Contains(t, text.Text, "finished-quest")
}

// TestParseIntIDs tests the ParseIntIDs helper.
Expand Down
Loading