Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-malformed-tool-call-input.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@workflow/ai': patch
---

fix(ai): preserve malformed streamed tool-call input instead of crashing before repair hooks can run
82 changes: 81 additions & 1 deletion packages/ai/src/agent/do-stream-step.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import type { UIMessageChunk } from 'ai';
import { describe, expect, it } from 'vitest';
import { normalizeFinishReason } from './do-stream-step.js';
import {
doStreamStep,
normalizeFinishReason,
safeParseToolCallInput,
} from './do-stream-step.js';

describe('normalizeFinishReason', () => {
describe('string finish reasons', () => {
Expand Down Expand Up @@ -122,3 +127,78 @@ describe('normalizeFinishReason', () => {
});
});
});

describe('safeParseToolCallInput', () => {
it('should parse valid JSON input', () => {
expect(safeParseToolCallInput('{"city":"San Francisco"}')).toEqual({
city: 'San Francisco',
});
});

it('should return empty object for undefined input', () => {
expect(safeParseToolCallInput(undefined)).toEqual({});
});

it('should preserve malformed input as a string', () => {
expect(safeParseToolCallInput('{"city":"San Francisco"')).toBe(
'{"city":"San Francisco"'
);
});
});

describe('doStreamStep', () => {
it('should not throw when streamed tool-call input is malformed JSON', async () => {
const writtenChunks: UIMessageChunk[] = [];
const writable = new WritableStream<UIMessageChunk>({
write: async (chunk) => {
writtenChunks.push(chunk);
},
});

const model = {
provider: 'mock-provider',
modelId: 'mock-model',
doStream: async () => ({
stream: new ReadableStream({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'getWeather',
input: '{"city":"San Francisco"',
});
controller.enqueue({
type: 'finish',
finishReason: 'tool-calls',
usage: {
inputTokens: { total: 10, noCache: 10 },
outputTokens: { total: 5, text: 0, reasoning: 5 },
},
});
controller.close();
},
}),
}),
};

const result = await doStreamStep(
[{ role: 'user', content: [{ type: 'text', text: 'test' }] }],
async () => model as any,
writable,
undefined,
{ sendStart: false }
);

expect(result.step.toolCalls).toHaveLength(1);
expect(result.step.toolCalls[0]?.input).toBe('{"city":"San Francisco"');
expect(writtenChunks).toContainEqual(
expect.objectContaining({
type: 'tool-input-available',
toolCallId: 'call-1',
toolName: 'getWeather',
input: '{"city":"San Francisco"',
})
);
});
});
58 changes: 29 additions & 29 deletions packages/ai/src/agent/do-stream-step.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ function uint8ArrayToBase64(data: Uint8Array): string {
return btoa(binary);
}

/**
* Parse streamed tool-call input without crashing the workflow step when a
* provider emits malformed or truncated JSON.
*/
export function safeParseToolCallInput(input: string | undefined): unknown {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this also re-use safeParseInput from durable-agent.ts? Or is the undefined fallthrough intentional?

if (input == null || input === '') {
return {};
}

try {
return JSON.parse(input);
} catch {
return input;
}
}

/**
* Options for the doStreamStep function.
*/
Expand Down Expand Up @@ -454,7 +470,7 @@ export async function doStreamStep(
type: 'tool-input-available',
toolCallId: part.toolCallId,
toolName: part.toolName,
input: JSON.parse(part.input || '{}'),
input: safeParseToolCallInput(part.input),
...(part.providerExecuted != null
? { providerExecuted: part.providerExecuted }
: {}),
Expand Down Expand Up @@ -779,6 +795,14 @@ function chunksToStep(
? v3FinishReason
: undefined;

const mapToolCall = (toolCall: LanguageModelV3ToolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: safeParseToolCallInput(toolCall.input),
dynamic: true as const,
});

const stepResult: StepResult<any> = {
stepNumber: 0, // Will be overridden by the caller
model: {
Expand All @@ -790,13 +814,7 @@ function chunksToStep(
experimental_context: undefined,
content: [
...(text ? [{ type: 'text' as const, text }] : []),
...toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
...toolCalls.map(mapToolCall),
],
text,
reasoning: reasoning.map((r) => ({
Expand All @@ -809,21 +827,9 @@ function chunksToStep(
reasoningText: reasoningText || undefined,
files,
sources,
toolCalls: toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
toolCalls: toolCalls.map(mapToolCall),
staticToolCalls: [],
dynamicToolCalls: toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
dynamicToolCalls: toolCalls.map(mapToolCall),
toolResults: [],
staticToolResults: [],
dynamicToolResults: [],
Expand Down Expand Up @@ -864,13 +870,7 @@ function chunksToStep(
request: {
body: JSON.stringify({
prompt: conversationPrompt,
tools: toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
tools: toolCalls.map(mapToolCall),
}),
},
response: {
Expand Down
99 changes: 99 additions & 0 deletions packages/ai/src/agent/stream-text-iterator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ import { beforeEach, describe, expect, it, vi } from 'vitest';
// Mock doStreamStep
vi.mock('./do-stream-step.js', () => ({
doStreamStep: vi.fn(),
safeParseToolCallInput: (input: string | undefined) => {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can just be original.safeParseToolCallInput instead of rewriting the function, to avoid drift

if (input == null || input === '') {
return {};
}

try {
return JSON.parse(input);
} catch {
return input;
}
},
}));

// Import after mocking
Expand Down Expand Up @@ -1095,4 +1106,92 @@ describe('streamTextIterator', () => {
});
});
});

describe('malformed tool-call input handling', () => {
it('should preserve malformed tool-call input instead of throwing', async () => {
const mockWritable = createMockWritable();
const mockModel = vi.fn();

let capturedPrompt: LanguageModelV3Prompt | undefined;

const malformedToolCall: LanguageModelV3ToolCall = {
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'testTool',
input: '{"query":"test"',
};

vi.mocked(doStreamStep)
.mockResolvedValueOnce({
toolCalls: [malformedToolCall],
finish: { finishReason: 'tool-calls' },
step: createMockStepResult({
finishReason: 'tool-calls',
toolCalls: [
{
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'testTool',
input: '{"query":"test"',
dynamic: true as const,
},
],
dynamicToolCalls: [
{
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'testTool',
input: '{"query":"test"',
dynamic: true as const,
},
],
}),
})
.mockImplementationOnce(async (prompt) => {
capturedPrompt = prompt;
return {
toolCalls: [],
finish: { finishReason: 'stop' },
step: createMockStepResult({ finishReason: 'stop' }),
};
});

const iterator = streamTextIterator({
prompt: [{ role: 'user', content: [{ type: 'text', text: 'test' }] }],
tools: {
testTool: {
description: 'A test tool',
execute: async () => ({ result: 'success' }),
},
} as ToolSet,
writable: mockWritable,
model: mockModel as any,
});

const firstResult = await iterator.next();
expect(firstResult.done).toBe(false);
expect(firstResult.value.toolCalls).toHaveLength(1);

const toolResults: LanguageModelV3ToolResult[] = [
{
type: 'tool-result',
toolCallId: 'call-1',
toolName: 'testTool',
output: { type: 'text', value: '{"result":"success"}' },
},
];

await iterator.next(toolResults);

const assistantMessage = capturedPrompt?.find(
(msg) => msg.role === 'assistant'
);
const toolCallPart = (assistantMessage?.content as any[])?.find(
(part) => part.type === 'tool-call'
);

expect(toolCallPart).toBeDefined();
expect(toolCallPart.input).toBe('{"query":"test"');
});
});
});
3 changes: 2 additions & 1 deletion packages/ai/src/agent/stream-text-iterator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
doStreamStep,
type ModelStopCondition,
type ProviderExecutedToolResult,
safeParseToolCallInput,
} from './do-stream-step.js';
import type {
GenerationSettings,
Expand Down Expand Up @@ -358,7 +359,7 @@ export async function* streamTextIterator({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
input: safeParseToolCallInput(toolCall.input),
...(meta != null ? { providerOptions: meta } : {}),
};
}),
Expand Down