diff --git a/.changeset/openai-reconnect-tool-call-anchors.md b/.changeset/openai-reconnect-tool-call-anchors.md new file mode 100644 index 000000000..d468e3aeb --- /dev/null +++ b/.changeset/openai-reconnect-tool-call-anchors.md @@ -0,0 +1,6 @@ +--- +'@livekit/agents': patch +'@livekit/agents-plugin-openai': patch +--- + +Discard stale OpenAI Realtime tool outputs after reconnect when their original function-call anchor is no longer present in the active realtime session. diff --git a/agents/src/voice/agent_activity.test.ts b/agents/src/voice/agent_activity.test.ts index 5600418d3..728f90e62 100644 --- a/agents/src/voice/agent_activity.test.ts +++ b/agents/src/voice/agent_activity.test.ts @@ -16,11 +16,11 @@ */ import { Heap } from 'heap-js'; import { describe, expect, it, vi } from 'vitest'; -import { ChatContext } from '../llm/chat_context.js'; +import { ChatContext, FunctionCall, FunctionCallOutput } from '../llm/chat_context.js'; import { LLM, type LLMStream } from '../llm/llm.js'; import { Future, Task } from '../utils.js'; import { _getActivityTaskInfo } from './agent.js'; -import { AgentActivity } from './agent_activity.js'; +import { AgentActivity, filterFunctionCallOutputsForRealtimeSession } from './agent_activity.js'; import type { PreemptiveGenerationInfo } from './audio_recognition.js'; import { SpeechHandle } from './speech_handle.js'; @@ -131,6 +131,36 @@ function buildMainTaskRunner() { }; } +describe('filterFunctionCallOutputsForRealtimeSession', () => { + it('keeps only outputs whose function call is still in the realtime session', () => { + const liveOutput = FunctionCallOutput.create({ + callId: 'call_live', + name: 'lookup', + output: '{"ok":true}', + isError: false, + }); + const staleOutput = FunctionCallOutput.create({ + callId: 'call_stale', + name: 'lookup', + output: '{"ok":false}', + isError: false, + }); + const chatCtx = new ChatContext([ + FunctionCall.create({ + callId: 'call_live', + name: 'lookup', + args: '{}', + }), + ]); + + const { currentFunctionCallOutputs, staleFunctionCallOutputs } = + filterFunctionCallOutputsForRealtimeSession(chatCtx, [liveOutput, staleOutput]); + + expect(currentFunctionCallOutputs).toEqual([liveOutput]); + expect(staleFunctionCallOutputs).toEqual([staleOutput]); + }); +}); + describe('AgentActivity - mainTask', () => { it('should recover when speech handle is interrupted after authorization', async () => { const { fakeActivity, mainTask, speechQueue, q_updated } = buildMainTaskRunner(); diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 21a6a9fac..1a20d5c17 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -203,6 +203,34 @@ interface PausedSpeechInfo { timeout: number; } +/** @internal */ +export function filterFunctionCallOutputsForRealtimeSession( + chatCtx: ChatContext, + functionCallOutputs: FunctionCallOutput[], +): { + currentFunctionCallOutputs: FunctionCallOutput[]; + staleFunctionCallOutputs: FunctionCallOutput[]; +} { + const currentFunctionCallIds = new Set(); + for (const item of chatCtx.items) { + if (item.type === 'function_call') { + currentFunctionCallIds.add(item.callId); + } + } + + const currentFunctionCallOutputs: FunctionCallOutput[] = []; + const staleFunctionCallOutputs: FunctionCallOutput[] = []; + for (const output of functionCallOutputs) { + if (currentFunctionCallIds.has(output.callId)) { + currentFunctionCallOutputs.push(output); + } else { + staleFunctionCallOutputs.push(output); + } + } + + return { currentFunctionCallOutputs, staleFunctionCallOutputs }; +} + export class AgentActivity implements RecognitionHooks { agent: Agent; agentSession: AgentSession; @@ -3601,8 +3629,12 @@ export class AgentActivity implements RecognitionHooks { return; } - const { functionToolsExecutedEvent, shouldGenerateToolReply, newAgentTask, ignoreTaskSwitch } = - this.summarizeToolExecutionOutput(toolOutput, speechHandle); + const { + functionToolsExecutedEvent, + replyRequiredFunctionCallIds, + newAgentTask, + ignoreTaskSwitch, + } = this.summarizeToolExecutionOutput(toolOutput, speechHandle); this.agentSession.emit( AgentSessionEventTypes.FunctionToolsExecuted, @@ -3615,6 +3647,8 @@ export class AgentActivity implements RecognitionHooks { schedulingPaused = true; } + let shouldGenerateToolReply = false; + if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { // wait all speeches played before updating the tool output and generating the response // most realtime models dont support generating multiple responses at the same time @@ -3630,13 +3664,40 @@ export class AgentActivity implements RecognitionHooks { await new ThrowsPromise((resolve) => setImmediate(resolve)); } } - const chatCtx = realtimeSession.chatCtx.copy(); - chatCtx.items.push(...functionToolsExecutedEvent.functionCallOutputs); + const { currentFunctionCallOutputs, staleFunctionCallOutputs } = + filterFunctionCallOutputsForRealtimeSession( + realtimeSession.chatCtx, + functionToolsExecutedEvent.functionCallOutputs as FunctionCallOutput[], + ); - this.agentSession._toolItemsAdded( - functionToolsExecutedEvent.functionCallOutputs as FunctionCallOutput[], + if (staleFunctionCallOutputs.length > 0) { + this.logger.warn( + { + callIds: staleFunctionCallOutputs.map((output) => output.callId), + }, + 'discarding tool outputs for function calls no longer present in the realtime session', + ); + } + + if (currentFunctionCallOutputs.length === 0) { + if (this.agentSession.agentState === 'thinking') { + this.agentSession._updateAgentState('listening'); + if (this.audioRecognition) { + this.audioRecognition.onEndOfAgentSpeech(Date.now()); + } + } + return; + } + + shouldGenerateToolReply = currentFunctionCallOutputs.some((output) => + replyRequiredFunctionCallIds.has(output.callId), ); + const chatCtx = realtimeSession.chatCtx.copy(); + chatCtx.items.push(...currentFunctionCallOutputs); + + this.agentSession._toolItemsAdded(currentFunctionCallOutputs); + // If the realtime model auto-generates the tool reply, install a // placeholder so the active RunResult waits for that reply. let fut: Future | undefined; @@ -3735,6 +3796,7 @@ export class AgentActivity implements RecognitionHooks { }); let shouldGenerateToolReply = false; + const replyRequiredFunctionCallIds = new Set(); let newAgentTask: Agent | null = null; let ignoreTaskSwitch = false; @@ -3745,6 +3807,7 @@ export class AgentActivity implements RecognitionHooks { functionToolsExecutedEvent.functionCallOutputs.push(sanitizedOut.toolCallOutput); if (sanitizedOut.replyRequired) { shouldGenerateToolReply = true; + replyRequiredFunctionCallIds.add(sanitizedOut.toolCall.callId); } } @@ -3770,6 +3833,7 @@ export class AgentActivity implements RecognitionHooks { return { functionToolsExecutedEvent, shouldGenerateToolReply, + replyRequiredFunctionCallIds, newAgentTask, ignoreTaskSwitch, }; diff --git a/plugins/openai/src/realtime/realtime_model.test.ts b/plugins/openai/src/realtime/realtime_model.test.ts index 6080cdbd6..d6ace309c 100644 --- a/plugins/openai/src/realtime/realtime_model.test.ts +++ b/plugins/openai/src/realtime/realtime_model.test.ts @@ -31,6 +31,16 @@ type ResponseDoneSessionInternals = { }; }; +type ReconnectSessionInternals = { + remoteChatCtx: llm.RemoteChatContext; + chatCtxForReconnect: () => llm.ChatContext; +}; + +type GenerationCleanupSessionInternals = { + currentGeneration?: ResponseDoneSessionInternals['currentGeneration']; + closeCurrentGeneration: (reason: string) => void; +}; + function createSessionForTest(): RealtimeSessionInternals { const session = Object.create(RealtimeSession.prototype) as RealtimeSessionInternals; session.responseCreatedFutures = {}; @@ -618,6 +628,73 @@ describe('livekitItemToOpenAIItem', () => { }); }); +describe('RealtimeSession reconnect chat context', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('excludes function call items when rebuilding context after reconnect', () => { + stubTaskRuntime(); + + const model = new RealtimeModel({ apiKey: 'test-key' }); + const session = model.session() as unknown as ReconnectSessionInternals; + session.remoteChatCtx = new llm.RemoteChatContext(); + + session.remoteChatCtx.insert( + undefined, + llm.ChatMessage.create({ + id: 'msg_user', + role: 'user', + content: ['what is the weather?'], + }), + ); + session.remoteChatCtx.insert( + 'msg_user', + new llm.FunctionCall({ + id: 'func_call', + callId: 'call_weather', + name: 'get_weather', + args: '{"location":"San Francisco"}', + }), + ); + session.remoteChatCtx.insert( + 'func_call', + new llm.FunctionCallOutput({ + id: 'func_output', + callId: 'call_weather', + name: 'get_weather', + output: '{"temperature":72}', + isError: false, + }), + ); + + const chatCtx = session.chatCtxForReconnect(); + + expect(chatCtx.items.map((item) => item.type)).toEqual(['message']); + }); + + it('closes the current generation when cleaning up a reconnected session', () => { + stubTaskRuntime(); + + const model = new RealtimeModel({ apiKey: 'test-key' }); + const session = model.session() as unknown as GenerationCleanupSessionInternals; + const doneFut = new Future(); + + session.currentGeneration = { + messageChannel: stream.createStreamChannel(), + functionChannel: stream.createStreamChannel(), + messages: new Map(), + _doneFut: doneFut, + _createdTimestamp: Date.now(), + }; + + session.closeCurrentGeneration('session reconnection'); + + expect(doneFut.done).toBe(true); + expect(session.currentGeneration).toBeUndefined(); + }); +}); + describe('RealtimeSession.updateOptions', () => { afterEach(() => { vi.restoreAllMocks(); diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index e8f88035d..e24c5b776 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -1045,6 +1045,7 @@ export class RealtimeSession extends llm.RealtimeSession { } } this.responseCreatedFutures = {}; + this.closeCurrentGeneration('session reconnection'); // Clear audio-capable item tracking - restored items are text-only on the server this.audioCapableItemIds.clear(); @@ -1061,11 +1062,7 @@ export class RealtimeSession extends llm.RealtimeSession { } // chat context - const chatCtx = this.chatCtx.copy({ - excludeFunctionCall: true, - excludeInstructions: true, - excludeEmptyMessage: true, - }); + const chatCtx = this.chatCtxForReconnect(); const oldChatCtx = this.remoteChatCtx; this.remoteChatCtx = new llm.RemoteChatContext(); @@ -1150,6 +1147,14 @@ export class RealtimeSession extends llm.RealtimeSession { } } + private chatCtxForReconnect(): llm.ChatContext { + return this.chatCtx.copy({ + excludeFunctionCall: true, + excludeInstructions: true, + excludeEmptyMessage: true, + }); + } + private async runWs(wsConn: WebSocket): Promise { const forwardEvents = async (signal: AbortSignal): Promise => { const abortFuture = new Future(); @@ -1336,29 +1341,41 @@ export class RealtimeSession extends llm.RealtimeSession { this.itemDeleteFutures = {}; this.inputTranscriptAccumulators.clear(); - - // Clean up current generation if exists - if (this.currentGeneration) { - for (const gen of this.currentGeneration.messages.values()) { - gen.textChannel.close(); - gen.audioChannel.close(); - if (!gen.modalities.done) { - gen.modalities.resolve(this._options.modalities); - } - } - this.currentGeneration.messages.clear(); - this.currentGeneration.messageChannel.close(); - this.currentGeneration.functionChannel.close(); - if (!this.currentGeneration._doneFut.done) { - this.currentGeneration._doneFut.resolve(); - } - this.currentGeneration = undefined; - } + this.closeCurrentGeneration('session close'); // Clear the message queue this.messageChannel.items.length = 0; } + private closeCurrentGeneration(reason: string): void { + if (!this.currentGeneration) { + return; + } + + this.#logger.debug( + { + reason, + messageCount: this.currentGeneration.messages.size, + }, + 'Closing current OpenAI Realtime generation', + ); + + for (const gen of this.currentGeneration.messages.values()) { + gen.textChannel.close(); + gen.audioChannel.close(); + if (!gen.modalities.done) { + gen.modalities.resolve(this._options.modalities); + } + } + this.currentGeneration.messages.clear(); + this.currentGeneration.messageChannel.close(); + this.currentGeneration.functionChannel.close(); + if (!this.currentGeneration._doneFut.done) { + this.currentGeneration._doneFut.resolve(); + } + this.currentGeneration = undefined; + } + private handleInputAudioBufferSpeechStarted( _event: api_proto.InputAudioBufferSpeechStartedEvent, ): void {