diff --git a/.changeset/fix-malformed-tool-call-input.md b/.changeset/fix-malformed-tool-call-input.md new file mode 100644 index 0000000000..eac8cac742 --- /dev/null +++ b/.changeset/fix-malformed-tool-call-input.md @@ -0,0 +1,5 @@ +--- +'@workflow/ai': patch +--- + +Preserve malformed streamed tool-call input until repair hooks can run diff --git a/packages/ai/src/agent/do-stream-step.test.ts b/packages/ai/src/agent/do-stream-step.test.ts index 0bdd2b991f..28f6122881 100644 --- a/packages/ai/src/agent/do-stream-step.test.ts +++ b/packages/ai/src/agent/do-stream-step.test.ts @@ -1,5 +1,7 @@ +import type { UIMessageChunk } from 'ai'; import { describe, expect, it } from 'vitest'; -import { normalizeFinishReason } from './do-stream-step.js'; +import { doStreamStep, normalizeFinishReason } from './do-stream-step.js'; +import { safeParseToolCallInput } from './safe-parse-tool-call-input.js'; describe('normalizeFinishReason', () => { describe('string finish reasons', () => { @@ -122,3 +124,78 @@ describe('normalizeFinishReason', () => { }); }); }); + +describe('safeParseToolCallInput', () => { + it('should parse valid JSON input', () => { + expect(safeParseToolCallInput('{"city":"San Francisco"}')).toEqual({ + city: 'San Francisco', + }); + }); + + it('should return empty object for undefined input', () => { + expect(safeParseToolCallInput(undefined)).toEqual({}); + }); + + it('should preserve malformed input as a string', () => { + expect(safeParseToolCallInput('{"city":"San Francisco"')).toBe( + '{"city":"San Francisco"' + ); + }); +}); + +describe('doStreamStep', () => { + it('should not throw when streamed tool-call input is malformed JSON', async () => { + const writtenChunks: UIMessageChunk[] = []; + const writable = new WritableStream({ + write: async (chunk) => { + writtenChunks.push(chunk); + }, + }); + + const model = { + provider: 'mock-provider', + modelId: 'mock-model', + doStream: async () => ({ + stream: new ReadableStream({ + start(controller) { + controller.enqueue({ type: 'stream-start', warnings: [] }); + controller.enqueue({ + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'getWeather', + input: '{"city":"San Francisco"', + }); + controller.enqueue({ + type: 'finish', + finishReason: 'tool-calls', + usage: { + inputTokens: { total: 10, noCache: 10 }, + outputTokens: { total: 5, text: 0, reasoning: 5 }, + }, + }); + controller.close(); + }, + }), + }), + }; + + const result = await doStreamStep( + [{ role: 'user', content: [{ type: 'text', text: 'test' }] }], + async () => model as any, + writable, + undefined, + { sendStart: false } + ); + + expect(result.step.toolCalls).toHaveLength(1); + expect(result.step.toolCalls[0]?.input).toBe('{"city":"San Francisco"'); + expect(writtenChunks).toContainEqual( + expect.objectContaining({ + type: 'tool-input-available', + toolCallId: 'call-1', + toolName: 'getWeather', + input: '{"city":"San Francisco"', + }) + ); + }); +}); diff --git a/packages/ai/src/agent/do-stream-step.ts b/packages/ai/src/agent/do-stream-step.ts index 45089fd226..15ba779dad 100644 --- a/packages/ai/src/agent/do-stream-step.ts +++ b/packages/ai/src/agent/do-stream-step.ts @@ -22,6 +22,7 @@ import type { TelemetrySettings, } from './durable-agent.js'; import { getErrorMessage } from '../get-error-message.js'; +import { safeParseToolCallInput } from './safe-parse-tool-call-input.js'; import { recordSpan } from './telemetry.js'; import type { CompatibleLanguageModel } from './types.js'; @@ -454,7 +455,7 @@ export async function doStreamStep( type: 'tool-input-available', toolCallId: part.toolCallId, toolName: part.toolName, - input: JSON.parse(part.input || '{}'), + input: safeParseToolCallInput(part.input), ...(part.providerExecuted != null ? { providerExecuted: part.providerExecuted } : {}), @@ -779,6 +780,14 @@ function chunksToStep( ? v3FinishReason : undefined; + const mapToolCall = (toolCall: LanguageModelV3ToolCall) => ({ + type: 'tool-call' as const, + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + input: safeParseToolCallInput(toolCall.input), + dynamic: true as const, + }); + const stepResult: StepResult = { stepNumber: 0, // Will be overridden by the caller model: { @@ -790,13 +799,7 @@ function chunksToStep( experimental_context: undefined, content: [ ...(text ? [{ type: 'text' as const, text }] : []), - ...toolCalls.map((toolCall) => ({ - type: 'tool-call' as const, - toolCallId: toolCall.toolCallId, - toolName: toolCall.toolName, - input: JSON.parse(toolCall.input), - dynamic: true as const, - })), + ...toolCalls.map(mapToolCall), ], text, reasoning: reasoning.map((r) => ({ @@ -809,21 +812,9 @@ function chunksToStep( reasoningText: reasoningText || undefined, files, sources, - toolCalls: toolCalls.map((toolCall) => ({ - type: 'tool-call' as const, - toolCallId: toolCall.toolCallId, - toolName: toolCall.toolName, - input: JSON.parse(toolCall.input), - dynamic: true as const, - })), + toolCalls: toolCalls.map(mapToolCall), staticToolCalls: [], - dynamicToolCalls: toolCalls.map((toolCall) => ({ - type: 'tool-call' as const, - toolCallId: toolCall.toolCallId, - toolName: toolCall.toolName, - input: JSON.parse(toolCall.input), - dynamic: true as const, - })), + dynamicToolCalls: toolCalls.map(mapToolCall), toolResults: [], staticToolResults: [], dynamicToolResults: [], @@ -864,13 +855,7 @@ function chunksToStep( request: { body: JSON.stringify({ prompt: conversationPrompt, - tools: toolCalls.map((toolCall) => ({ - type: 'tool-call' as const, - toolCallId: toolCall.toolCallId, - toolName: toolCall.toolName, - input: JSON.parse(toolCall.input), - dynamic: true as const, - })), + tools: toolCalls.map(mapToolCall), }), }, response: { diff --git a/packages/ai/src/agent/durable-agent.test.ts b/packages/ai/src/agent/durable-agent.test.ts index 2fc23e8f56..f78346ade2 100644 --- a/packages/ai/src/agent/durable-agent.test.ts +++ b/packages/ai/src/agent/durable-agent.test.ts @@ -2446,6 +2446,92 @@ describe('DurableAgent', () => { }) ); }); + + it('should patch repaired tool-call input back into the conversation prompt', async () => { + const repairFn: ToolCallRepairFunction = vi + .fn() + .mockReturnValue({ + toolCallId: 'test-call-id', + toolName: 'testTool', + input: '{"name":"repaired"}', + }); + + const tools: ToolSet = { + testTool: { + description: 'A test tool', + inputSchema: z.object({ name: z.string() }), + execute: async () => ({ result: 'success' }), + }, + }; + + const mockModel = createMockModel(); + + const agent = new DurableAgent({ + model: async () => mockModel, + tools, + }); + + const mockWritable = new WritableStream({ + write: vi.fn(), + close: vi.fn(), + }); + + const mockMessages: LanguageModelV3Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'test' }] }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'test-call-id', + toolName: 'testTool', + input: 'invalid json', + }, + ], + }, + ]; + + const { streamTextIterator } = await import('./stream-text-iterator.js'); + const mockIterator = { + next: vi + .fn() + .mockResolvedValueOnce({ + done: false, + value: { + toolCalls: [ + { + toolCallId: 'test-call-id', + toolName: 'testTool', + input: 'invalid json', + } as LanguageModelV3ToolCall, + ], + messages: mockMessages, + }, + }) + .mockResolvedValueOnce({ done: true, value: [] }), + }; + vi.mocked(streamTextIterator).mockReturnValue( + mockIterator as unknown as MockIterator + ); + + await agent.stream({ + messages: [{ role: 'user', content: 'test' }], + writable: mockWritable, + experimental_repairToolCall: repairFn, + }); + + expect(mockMessages[1]).toMatchObject({ + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'test-call-id', + toolName: 'testTool', + input: { name: 'repaired' }, + }, + ], + }); + }); }); describe('includeRawChunks', () => { diff --git a/packages/ai/src/agent/durable-agent.ts b/packages/ai/src/agent/durable-agent.ts index 0c5cc28953..a84d7214dc 100644 --- a/packages/ai/src/agent/durable-agent.ts +++ b/packages/ai/src/agent/durable-agent.ts @@ -27,6 +27,7 @@ import { } from 'ai'; import { convertToLanguageModelPrompt, standardizePrompt } from 'ai/internal'; import { getErrorMessage } from '../get-error-message.js'; +import { safeParseToolCallInput } from './safe-parse-tool-call-input.js'; import { streamTextIterator } from './stream-text-iterator.js'; import { recordSpan, runInContext } from './telemetry.js'; import type { CompatibleLanguageModel } from './types.js'; @@ -1133,7 +1134,7 @@ export class DurableAgent { type: 'tool-call' as const, toolCallId: tc.toolCallId, toolName: tc.toolName, - input: safeParseInput(tc.input), + input: safeParseToolCallInput(tc.input), })); // Build toolResults only for tools that were executed @@ -1141,7 +1142,7 @@ export class DurableAgent { type: 'tool-result' as const, toolCallId: r.toolCallId, toolName: r.toolName, - input: safeParseInput( + input: safeParseToolCallInput( toolCalls.find((tc) => tc.toolCallId === r.toolCallId)?.input ), output: 'value' in r.output ? r.output.value : undefined, @@ -1244,13 +1245,13 @@ export class DurableAgent { type: 'tool-call' as const, toolCallId: tc.toolCallId, toolName: tc.toolName, - input: safeParseInput(tc.input), + input: safeParseToolCallInput(tc.input), })); lastStepToolResults = toolResults.map((r) => ({ type: 'tool-result' as const, toolCallId: r.toolCallId, toolName: r.toolName, - input: safeParseInput( + input: safeParseToolCallInput( toolCalls.find((tc) => tc.toolCallId === r.toolCallId)?.input ), output: 'value' in r.output ? r.output.value : undefined, @@ -1478,10 +1479,6 @@ async function convertChunksToUIMessages( return messages; } -/** - * Safely parse tool call input JSON. Returns the parsed value or the raw string - * if parsing fails (e.g., for tool calls that were repaired). - */ /** * Valid `type` values for LanguageModelV3ToolResultOutput. * When a tool returns an object whose `type` matches one of these, @@ -1503,11 +1500,35 @@ function isToolResultOutput( return TOOL_RESULT_OUTPUT_TYPES.has((result as { type?: string }).type ?? ''); } -function safeParseInput(input: string | undefined): unknown { - try { - return JSON.parse(input || '{}'); - } catch { - return input; +function patchToolCallInMessages( + messages: LanguageModelV3Prompt, + toolCall: LanguageModelV3ToolCall +): void { + const repairedInput = safeParseToolCallInput(toolCall.input); + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i]; + + if (message.role !== 'assistant' || !Array.isArray(message.content)) { + continue; + } + + const toolCallPart = message.content.find( + ( + part + ): part is { + type: 'tool-call'; + toolCallId: string; + toolName: string; + input: unknown; + } => part.type === 'tool-call' && part.toolCallId === toolCall.toolCallId + ); + + if (toolCallPart) { + toolCallPart.toolName = toolCall.toolName; + toolCallPart.input = repairedInput; + return; + } } } @@ -1589,6 +1610,9 @@ async function executeTool( messages, }); if (repairedToolCall) { + toolCall.toolName = repairedToolCall.toolName; + toolCall.input = repairedToolCall.input; + patchToolCallInMessages(messages, repairedToolCall); // Retry with repaired tool call return executeTool( repairedToolCall, @@ -1614,6 +1638,9 @@ async function executeTool( messages, }); if (repairedToolCall) { + toolCall.toolName = repairedToolCall.toolName; + toolCall.input = repairedToolCall.input; + patchToolCallInMessages(messages, repairedToolCall); // Retry with repaired tool call return executeTool( repairedToolCall, diff --git a/packages/ai/src/agent/safe-parse-tool-call-input.ts b/packages/ai/src/agent/safe-parse-tool-call-input.ts new file mode 100644 index 0000000000..ec3666f5ff --- /dev/null +++ b/packages/ai/src/agent/safe-parse-tool-call-input.ts @@ -0,0 +1,15 @@ +/** + * Parse streamed tool-call input without crashing the workflow step when a + * provider emits malformed or truncated JSON. + */ +export function safeParseToolCallInput(input: string | undefined): unknown { + if (input == null || input === '') { + return {}; + } + + try { + return JSON.parse(input); + } catch { + return input; + } +} diff --git a/packages/ai/src/agent/stream-text-iterator.test.ts b/packages/ai/src/agent/stream-text-iterator.test.ts index 179a75b6a2..90eabc1446 100644 --- a/packages/ai/src/agent/stream-text-iterator.test.ts +++ b/packages/ai/src/agent/stream-text-iterator.test.ts @@ -1095,4 +1095,92 @@ describe('streamTextIterator', () => { }); }); }); + + describe('malformed tool-call input handling', () => { + it('should preserve malformed tool-call input instead of throwing', async () => { + const mockWritable = createMockWritable(); + const mockModel = vi.fn(); + + let capturedPrompt: LanguageModelV3Prompt | undefined; + + const malformedToolCall: LanguageModelV3ToolCall = { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'testTool', + input: '{"query":"test"', + }; + + vi.mocked(doStreamStep) + .mockResolvedValueOnce({ + toolCalls: [malformedToolCall], + finish: { finishReason: 'tool-calls' }, + step: createMockStepResult({ + finishReason: 'tool-calls', + toolCalls: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'testTool', + input: '{"query":"test"', + dynamic: true as const, + }, + ], + dynamicToolCalls: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'testTool', + input: '{"query":"test"', + dynamic: true as const, + }, + ], + }), + }) + .mockImplementationOnce(async (prompt) => { + capturedPrompt = prompt; + return { + toolCalls: [], + finish: { finishReason: 'stop' }, + step: createMockStepResult({ finishReason: 'stop' }), + }; + }); + + const iterator = streamTextIterator({ + prompt: [{ role: 'user', content: [{ type: 'text', text: 'test' }] }], + tools: { + testTool: { + description: 'A test tool', + execute: async () => ({ result: 'success' }), + }, + } as ToolSet, + writable: mockWritable, + model: mockModel as any, + }); + + const firstResult = await iterator.next(); + expect(firstResult.done).toBe(false); + expect(firstResult.value.toolCalls).toHaveLength(1); + + const toolResults: LanguageModelV3ToolResult[] = [ + { + type: 'tool-result', + toolCallId: 'call-1', + toolName: 'testTool', + output: { type: 'text', value: '{"result":"success"}' }, + }, + ]; + + await iterator.next(toolResults); + + const assistantMessage = capturedPrompt?.find( + (msg) => msg.role === 'assistant' + ); + const toolCallPart = (assistantMessage?.content as any[])?.find( + (part) => part.type === 'tool-call' + ); + + expect(toolCallPart).toBeDefined(); + expect(toolCallPart.input).toBe('{"query":"test"'); + }); + }); }); diff --git a/packages/ai/src/agent/stream-text-iterator.ts b/packages/ai/src/agent/stream-text-iterator.ts index 5ff2ae8696..07b7e87bd5 100644 --- a/packages/ai/src/agent/stream-text-iterator.ts +++ b/packages/ai/src/agent/stream-text-iterator.ts @@ -30,6 +30,7 @@ import { runInContext, type SpanHandle, } from './telemetry.js'; +import { safeParseToolCallInput } from './safe-parse-tool-call-input.js'; import { toolsToModelTools } from './tools-to-model-tools.js'; import type { CompatibleLanguageModel } from './types.js'; @@ -358,7 +359,7 @@ export async function* streamTextIterator({ type: 'tool-call' as const, toolCallId: toolCall.toolCallId, toolName: toolCall.toolName, - input: JSON.parse(toolCall.input), + input: safeParseToolCallInput(toolCall.input), ...(meta != null ? { providerOptions: meta } : {}), }; }),