diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts index 7fb3d6484..fd060ea0d 100644 --- a/strands-ts/src/agent/agent.ts +++ b/strands-ts/src/agent/agent.ts @@ -44,6 +44,8 @@ import { PluginRegistry } from '../plugins/registry.js' import { SlidingWindowConversationManager } from '../conversation-manager/sliding-window-conversation-manager.js' import { NullConversationManager } from '../conversation-manager/null-conversation-manager.js' import { ConversationManager } from '../conversation-manager/conversation-manager.js' +import type { ContextManagerParam } from '../context-manager/context-manager.js' +import { resolveContextManager } from '../context-manager/context-manager.js' import { HookRegistryImplementation } from '../hooks/registry.js' import type { HookableEventConstructor, HookCallback, HookCallbackOptions, HookCleanup } from '../hooks/types.js' import { @@ -167,9 +169,24 @@ export type AgentConfig = { * Defaults to true. */ printer?: boolean + /** + * Pre-composed context management strategy. + * + * - `"auto"`: enables tool result caching and proactive compression with defaults. + * - Object: fine-grained control over strategy, storage, caching, and compression settings. + * - `undefined` (default): no context management facade; use `conversationManager` + * and `plugins` directly. + * + * When set, takes priority over `conversationManager` — `NullConversationManager` is used. + */ + contextManager?: ContextManagerParam /** * Conversation manager for handling message history and context overflow. * Defaults to SlidingWindowConversationManager with windowSize of 40. + * + * @remarks Pending deprecation — use `contextManager` instead. The `contextManager` parameter + * composes compression, tool result caching, and token estimation into a single + * configuration surface. This field will be deprecated in a future version. */ conversationManager?: ConversationManager /** @@ -331,15 +348,29 @@ export class Agent implements LocalAgent, InvokableAgent { this.model = config?.model ?? new BedrockModel() } - // Validate and assign conversation manager + let contextManagerPlugin: Plugin | undefined + if (config?.contextManager) { + contextManagerPlugin = resolveContextManager(config.contextManager, config.plugins) + } + + // Validate and assign conversation manager. + // When contextManager is set, ContextCompression owns compression — use NullConversationManager. if (this.model.stateful) { - if (config?.conversationManager) { + if (config?.conversationManager || config?.contextManager) { throw new Error( - 'Cannot use a conversationManager with a stateful model. The model manages conversation state server-side.' + 'Cannot use a conversationManager or contextManager with a stateful model. The model manages conversation state server-side.' ) } this._conversationManager = new NullConversationManager() + } else if (contextManagerPlugin) { + if (config?.conversationManager) { + logger.warn('contextManager takes priority over conversationManager — conversationManager will be ignored') + } + this._conversationManager = new NullConversationManager() } else { + if (config?.conversationManager) { + logger.warn('conversationManager is deprecated and will be removed in v2. Use contextManager instead.') + } this._conversationManager = config?.conversationManager ?? new SlidingWindowConversationManager({ windowSize: 40 }) } @@ -372,9 +403,12 @@ export class Agent implements LocalAgent, InvokableAgent { // - Retry-strategy ordering is not load-bearing for correctness: `DefaultModelRetryStrategy` // guards on `event.retry`, so a user hook that already set it short-circuits // the strategy regardless of registration order. + // - contextManager plugin goes before user plugins so the offloader's AfterToolCallEvent + // hook fires first, ensuring large results are cached before user hooks see the event. this._pluginRegistry = new PluginRegistry([ this._conversationManager, ...retryStrategies, + ...(contextManagerPlugin ? [contextManagerPlugin] : []), ...(config?.plugins ?? []), ...(config?.sessionManager ? [config.sessionManager] : []), new ModelPlugin(this.model), @@ -1397,7 +1431,8 @@ export class Agent implements LocalAgent, InvokableAgent { let attemptCount = 1 while (true) { - // Estimate input tokens for the upcoming model call (non-fatal if estimation fails) + // Pending deprecation: token estimation will move fully to ContextManager. + // This remains for backward compat with standalone ConversationManager.proactiveCompression. let projectedInputTokens: number | undefined try { projectedInputTokens = await this._estimateInputTokens(streamOptions) diff --git a/strands-ts/src/context-manager/__tests__/context-manager.test.ts b/strands-ts/src/context-manager/__tests__/context-manager.test.ts new file mode 100644 index 000000000..e03944b58 --- /dev/null +++ b/strands-ts/src/context-manager/__tests__/context-manager.test.ts @@ -0,0 +1,164 @@ +import { describe, it, expect } from 'vitest' +import { ContextManager, resolveContextManager } from '../context-manager.js' +import { ContextCompression } from '../compression/context-compression.js' +import { InMemoryStorage } from '../../vended-plugins/context-offloader/storage.js' +import { createMockAgent } from '../../__fixtures__/agent-helpers.js' +import type { Plugin } from '../../plugins/plugin.js' + +describe('resolveContextManager', () => { + it('with "auto" enables both compression and offloader', () => { + const cm = resolveContextManager('auto') + const subPlugins = (cm as any)._subPlugins as Plugin[] + + expect(subPlugins).toHaveLength(2) + const names = subPlugins.map((p) => p.name) + expect(names).toContain('strands:context-compression') + expect(names).toContain('strands:context-offloader') + }) + + it('with config object is additive (omitted = disabled)', () => { + const cm = resolveContextManager({ compression: true }) + const subPlugins = (cm as any)._subPlugins as Plugin[] + + const names = subPlugins.map((p) => p.name) + expect(names).toContain('strands:context-compression') + expect(names).not.toContain('strands:context-offloader') + }) + + it('with strategy: "auto" applies override semantics (omitted features stay enabled)', () => { + const cm = resolveContextManager({ strategy: 'auto', compression: 'summarize' }) + const subPlugins = (cm as any)._subPlugins as Plugin[] + + const names = subPlugins.map((p) => p.name) + // Both should be enabled because strategy: 'auto' defaults include offloader: true + expect(names).toContain('strands:context-compression') + expect(names).toContain('strands:context-offloader') + + // Compression should use summarize method + const compression = subPlugins.find((p) => p.name === 'strands:context-compression') as ContextCompression + expect((compression as any)._method).toBe('summarize') + }) + + it('with offloader: true uses default thresholds', () => { + const cm = resolveContextManager({ offloader: true }) + const subPlugins = (cm as any)._subPlugins as Plugin[] + + const offloader = subPlugins.find((p) => p.name === 'strands:context-offloader') + expect(offloader).toBeDefined() + }) + + it('with offloader config applies custom settings', () => { + const cm = resolveContextManager({ offloader: { threshold: 5000, previewTokens: 1000 } }) + const subPlugins = (cm as any)._subPlugins as Plugin[] + + const offloader = subPlugins.find((p) => p.name === 'strands:context-offloader') + expect(offloader).toBeDefined() + }) +}) + +describe('ContextManager._buildSubPlugins', () => { + it('skips compression plugin when user already provides one', () => { + const userCompression: Plugin = { + name: 'strands:context-compression', + initAgent: () => {}, + getTools: () => [], + } + + const cm = resolveContextManager('auto', [userCompression]) + const subPlugins = (cm as any)._subPlugins as Plugin[] + + const compressionPlugins = subPlugins.filter((p) => p.name === 'strands:context-compression') + expect(compressionPlugins).toHaveLength(0) + // Offloader should still be present + const offloaderPlugins = subPlugins.filter((p) => p.name === 'strands:context-offloader') + expect(offloaderPlugins).toHaveLength(1) + }) + + it('skips offloader plugin when user already provides one', () => { + const userOffloader: Plugin = { + name: 'strands:context-offloader', + initAgent: () => {}, + getTools: () => [], + } + + const cm = resolveContextManager('auto', [userOffloader]) + const subPlugins = (cm as any)._subPlugins as Plugin[] + + const offloaderPlugins = subPlugins.filter((p) => p.name === 'strands:context-offloader') + expect(offloaderPlugins).toHaveLength(0) + // Compression should still be present + const compressionPlugins = subPlugins.filter((p) => p.name === 'strands:context-compression') + expect(compressionPlugins).toHaveLength(1) + }) +}) + +describe('ContextManager', () => { + describe('constructor', () => { + it('uses InMemoryStorage by default', () => { + const cm = new ContextManager() + expect(cm.storage).toBeInstanceOf(InMemoryStorage) + }) + + it('accepts custom storage', () => { + const storage = new InMemoryStorage() + const cm = new ContextManager({ storage }) + expect(cm.storage).toBe(storage) + }) + + it('has correct plugin name', () => { + const cm = new ContextManager() + expect(cm.name).toBe('strands:context-manager') + }) + }) + + describe('initAgent', () => { + it('initializes sub-plugins', () => { + const cm = new ContextManager({ compression: true, offloader: true }) + const agent = createMockAgent() + + cm.initAgent(agent) + + // Should have registered hooks from both sub-plugins + expect(agent.trackedHooks.length).toBeGreaterThan(0) + }) + + it('builds sub-plugins if not already resolved', () => { + const cm = new ContextManager({ compression: true }) + const agent = createMockAgent() + + // Don't call _resolveSubPlugins first + cm.initAgent(agent) + + // Should still work and register hooks + expect(agent.trackedHooks.length).toBeGreaterThan(0) + }) + }) + + describe('getTools', () => { + it('returns tools from sub-plugins', () => { + const cm = new ContextManager({ offloader: true }) + cm._resolveSubPlugins() + + const tools = cm.getTools() + // ContextOffloader provides retrieval tool by default + expect(tools.length).toBeGreaterThan(0) + expect(tools[0]!.name).toBe('retrieve_offloaded_content') + }) + + it('returns empty array when no sub-plugins configured', () => { + const cm = new ContextManager({}) + cm._resolveSubPlugins() + + const tools = cm.getTools() + expect(tools).toHaveLength(0) + }) + + it('returns empty array when sub-plugins are not resolved yet', () => { + const cm = new ContextManager({ offloader: true }) + // Don't resolve sub-plugins + + const tools = cm.getTools() + expect(tools).toHaveLength(0) + }) + }) +}) diff --git a/strands-ts/src/context-manager/__tests__/token-estimation.test.ts b/strands-ts/src/context-manager/__tests__/token-estimation.test.ts new file mode 100644 index 000000000..c82026ef1 --- /dev/null +++ b/strands-ts/src/context-manager/__tests__/token-estimation.test.ts @@ -0,0 +1,104 @@ +import { describe, it, expect, vi } from 'vitest' +import { estimateInputTokens } from '../token-estimation.js' +import { Message, TextBlock } from '../../types/messages.js' +import type { Model } from '../../models/model.js' + +function userMsg(text: string): Message { + return new Message({ role: 'user', content: [new TextBlock(text)] }) +} + +function assistantMsg( + text: string, + usage?: { inputTokens: number; outputTokens: number; totalTokens: number } +): Message { + return new Message({ + role: 'assistant', + content: [new TextBlock(text)], + ...(usage && { metadata: { usage } }), + }) +} + +function mockModel(countTokens?: (messages: Message[]) => Promise): Model { + return { + countTokens: countTokens ?? vi.fn().mockResolvedValue(100), + } as unknown as Model +} + +describe('estimateInputTokens', () => { + it('returns baseline from last assistant message usage metadata', async () => { + const messages = [ + userMsg('hello'), + assistantMsg('response', { inputTokens: 50, outputTokens: 20, totalTokens: 70 }), + ] + const model = mockModel() + + const result = await estimateInputTokens(messages, model) + + expect(result).toBe(70) // 50 + 20 + }) + + it('adds new message tokens to baseline when messages exist after the assistant message', async () => { + const messages = [ + userMsg('hello'), + assistantMsg('response', { inputTokens: 50, outputTokens: 20, totalTokens: 70 }), + userMsg('follow up'), + ] + const countTokens = vi.fn().mockResolvedValue(15) + const model = mockModel(countTokens) + + const result = await estimateInputTokens(messages, model) + + expect(result).toBe(85) // 70 + 15 + expect(countTokens).toHaveBeenCalledWith([messages[2]]) + }) + + it('uses the last assistant message with usage (not earlier ones)', async () => { + const messages = [ + userMsg('hello'), + assistantMsg('first', { inputTokens: 10, outputTokens: 5, totalTokens: 15 }), + userMsg('second'), + assistantMsg('latest', { inputTokens: 80, outputTokens: 30, totalTokens: 110 }), + ] + const model = mockModel() + + const result = await estimateInputTokens(messages, model) + + expect(result).toBe(110) // 80 + 30 + }) + + it('falls back to model.countTokens when no assistant message has usage metadata', async () => { + const messages = [ + userMsg('hello'), + new Message({ role: 'assistant', content: [new TextBlock('no metadata')] }), + userMsg('world'), + ] + const countTokens = vi.fn().mockResolvedValue(42) + const model = mockModel(countTokens) + + const result = await estimateInputTokens(messages, model) + + expect(result).toBe(42) + expect(countTokens).toHaveBeenCalledWith(messages) + }) + + it('falls back to model.countTokens when there are no assistant messages', async () => { + const messages = [userMsg('hello')] + const countTokens = vi.fn().mockResolvedValue(10) + const model = mockModel(countTokens) + + const result = await estimateInputTokens(messages, model) + + expect(result).toBe(10) + expect(countTokens).toHaveBeenCalledWith(messages) + }) + + it('returns undefined on error', async () => { + const messages = [userMsg('hello')] + const countTokens = vi.fn().mockRejectedValue(new Error('API error')) + const model = mockModel(countTokens) + + const result = await estimateInputTokens(messages, model) + + expect(result).toBeUndefined() + }) +}) diff --git a/strands-ts/src/context-manager/compression/__tests__/context-compression.test.ts b/strands-ts/src/context-manager/compression/__tests__/context-compression.test.ts new file mode 100644 index 000000000..ccd13542d --- /dev/null +++ b/strands-ts/src/context-manager/compression/__tests__/context-compression.test.ts @@ -0,0 +1,354 @@ +import { describe, it, expect, vi } from 'vitest' +import { ContextCompression } from '../context-compression.js' +import { Message, TextBlock } from '../../../types/messages.js' +import { AfterInvocationEvent, AfterModelCallEvent, BeforeModelCallEvent } from '../../../hooks/events.js' +import { ContextWindowOverflowError } from '../../../errors.js' +import { createMockAgent, invokeTrackedHook } from '../../../__fixtures__/agent-helpers.js' +import type { BaseModelConfig } from '../../../models/model.js' + +function userMsg(text: string): Message { + return new Message({ role: 'user', content: [new TextBlock(text)] }) +} + +function assistantMsg(text: string): Message { + return new Message({ role: 'assistant', content: [new TextBlock(text)] }) +} + +describe('ContextCompression', () => { + describe('constructor', () => { + it('validates proactive threshold must be > 0', () => { + expect(() => new ContextCompression({ proactive: { threshold: 0 } })).toThrow( + 'proactive compression threshold must be between 0 (exclusive) and 1 (inclusive)' + ) + }) + + it('validates proactive threshold must be <= 1', () => { + expect(() => new ContextCompression({ proactive: { threshold: 1.5 } })).toThrow( + 'proactive compression threshold must be between 0 (exclusive) and 1 (inclusive)' + ) + }) + + it('accepts valid proactive threshold', () => { + expect(() => new ContextCompression({ proactive: { threshold: 0.8 } })).not.toThrow() + }) + + it('accepts threshold of exactly 1', () => { + expect(() => new ContextCompression({ proactive: { threshold: 1 } })).not.toThrow() + }) + + it('defaults to truncate method', () => { + const compression = new ContextCompression() + expect((compression as any)._method).toBe('truncate') + }) + + it('accepts summarize method', () => { + const compression = new ContextCompression({ method: 'summarize' }) + expect((compression as any)._method).toBe('summarize') + }) + + it('defaults proactive threshold to 0.7 when proactive is true or omitted', () => { + const compression = new ContextCompression() + expect((compression as any)._proactiveThreshold).toBe(0.7) + }) + + it('disables proactive compression when proactive is false', () => { + const compression = new ContextCompression({ proactive: false }) + expect((compression as any)._proactiveThreshold).toBeUndefined() + }) + }) + + describe('initAgent', () => { + it('registers AfterModelCallEvent hook', () => { + const compression = new ContextCompression() + const agent = createMockAgent() + compression.initAgent(agent) + + const hookTypes = agent.trackedHooks.map((h) => h.eventType) + expect(hookTypes).toContain(AfterModelCallEvent) + }) + + it('registers BeforeModelCallEvent hook', () => { + const compression = new ContextCompression() + const agent = createMockAgent() + compression.initAgent(agent) + + const hookTypes = agent.trackedHooks.map((h) => h.eventType) + expect(hookTypes).toContain(BeforeModelCallEvent) + }) + + it('registers AfterInvocationEvent hook when method is truncate', () => { + const compression = new ContextCompression({ method: 'truncate' }) + const agent = createMockAgent() + compression.initAgent(agent) + + const hookTypes = agent.trackedHooks.map((h) => h.eventType) + expect(hookTypes).toContain(AfterInvocationEvent) + }) + + it('does not register AfterInvocationEvent hook when method is summarize', () => { + const compression = new ContextCompression({ method: 'summarize' }) + const agent = createMockAgent() + compression.initAgent(agent) + + const hookTypes = agent.trackedHooks.map((h) => h.eventType) + expect(hookTypes).not.toContain(AfterInvocationEvent) + }) + }) + + describe('proactive hook (BeforeModelCallEvent)', () => { + it('skips when proactive threshold is undefined', async () => { + const compression = new ContextCompression({ proactive: false }) + const messages = Array.from({ length: 50 }, (_, i) => + i % 2 === 0 ? userMsg(`msg ${i}`) : assistantMsg(`resp ${i}`) + ) + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 900, // 90% of limit + }) + await invokeTrackedHook(mockAgent, event) + + // Messages should not have been modified + expect(mockAgent.messages).toHaveLength(50) + }) + + it('triggers reduce when ratio exceeds threshold', async () => { + const compression = new ContextCompression({ + proactive: { threshold: 0.7 }, + windowSize: 4, + }) + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + userMsg('Message 2'), + assistantMsg('Response 2'), + userMsg('Message 3'), + assistantMsg('Response 3'), + ] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 800, // 80% > 70% threshold + }) + await invokeTrackedHook(mockAgent, event) + + expect(mockAgent.messages.length).toBeLessThan(6) + }) + + it('does not trigger reduce when ratio is below threshold', async () => { + const compression = new ContextCompression({ proactive: { threshold: 0.7 } }) + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 500, // 50% < 70% threshold + }) + await invokeTrackedHook(mockAgent, event) + + expect(mockAgent.messages).toHaveLength(2) + }) + + it('uses estimateInputTokens when projectedInputTokens is not provided', async () => { + const compression = new ContextCompression({ + proactive: { threshold: 0.7 }, + windowSize: 2, + }) + const messages = [ + userMsg('Message 1'), + new Message({ + role: 'assistant', + content: [new TextBlock('Response 1')], + metadata: { usage: { inputTokens: 600, outputTokens: 200, totalTokens: 800 } }, + }), + userMsg('Message 2'), + assistantMsg('Response 2'), + userMsg('Message 3'), + ] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const mockModel = { + getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig, + countTokens: vi.fn().mockResolvedValue(100), + } as any + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + // No projectedInputTokens — will use estimateInputTokens + }) + await invokeTrackedHook(mockAgent, event) + + // 600+200+countTokens(remaining) = 900 > 700 threshold — should compress + expect(mockAgent.messages.length).toBeLessThan(5) + }) + + it('skips when projectedInputTokens is undefined and estimation returns undefined', async () => { + const compression = new ContextCompression({ proactive: { threshold: 0.7 } }) + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const mockModel = { + getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig, + countTokens: vi.fn().mockRejectedValue(new Error('fail')), + } as any + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(mockAgent.messages).toHaveLength(2) + }) + }) + + describe('reactive hook (AfterModelCallEvent)', () => { + it('retries on ContextWindowOverflowError', async () => { + const compression = new ContextCompression({ windowSize: 2 }) + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + userMsg('Message 2'), + assistantMsg('Response 2'), + ] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error: new ContextWindowOverflowError('overflow'), + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(event.retry).toBe(true) + expect(mockAgent.messages.length).toBeLessThan(4) + }) + + it('does not set retry when error is not ContextWindowOverflowError', async () => { + const compression = new ContextCompression() + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error: new Error('some other error'), + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(event.retry).toBeUndefined() + }) + + it('does not set retry when no error is present', async () => { + const compression = new ContextCompression() + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(event.retry).toBeUndefined() + }) + + it('does not set retry when reduce returns false', async () => { + const compression = new ContextCompression({ windowSize: 10 }) + // Only 2 messages - truncate returns false for messages.length <= 2 + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error: new ContextWindowOverflowError('overflow'), + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(event.retry).toBeUndefined() + }) + }) + + describe('AfterInvocationEvent hook (sliding window enforcement)', () => { + it('truncates when messages exceed window size', async () => { + const compression = new ContextCompression({ windowSize: 4 }) + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + userMsg('Message 2'), + assistantMsg('Response 2'), + userMsg('Message 3'), + assistantMsg('Response 3'), + ] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const event = new AfterInvocationEvent({ + agent: mockAgent, + invocationState: {}, + }) + + // Find AfterInvocationEvent hook specifically + const hook = mockAgent.trackedHooks.find((h) => h.eventType === AfterInvocationEvent) + expect(hook).toBeDefined() + await hook!.callback(event) + + expect(mockAgent.messages.length).toBeLessThanOrEqual(4) + }) + + it('does not truncate when messages are within window size', async () => { + const compression = new ContextCompression({ windowSize: 10 }) + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const mockAgent = createMockAgent({ messages }) + compression.initAgent(mockAgent) + + const event = new AfterInvocationEvent({ + agent: mockAgent, + invocationState: {}, + }) + + const hook = mockAgent.trackedHooks.find((h) => h.eventType === AfterInvocationEvent) + expect(hook).toBeDefined() + await hook!.callback(event) + + expect(mockAgent.messages).toHaveLength(2) + }) + }) + + describe('getTools', () => { + it('returns empty array', () => { + const compression = new ContextCompression() + expect(compression.getTools()).toEqual([]) + }) + }) +}) diff --git a/strands-ts/src/context-manager/compression/__tests__/protection.test.ts b/strands-ts/src/context-manager/compression/__tests__/protection.test.ts new file mode 100644 index 000000000..6deb8f55d --- /dev/null +++ b/strands-ts/src/context-manager/compression/__tests__/protection.test.ts @@ -0,0 +1,67 @@ +import { describe, it, expect } from 'vitest' +import { isProtected, pinMessage } from '../protection.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../../types/messages.js' + +function userMsg(text: string): Message { + return new Message({ role: 'user', content: [new TextBlock(text)] }) +} + +function assistantMsg(text: string): Message { + return new Message({ role: 'assistant', content: [new TextBlock(text)] }) +} + +function toolUseMsg(toolUseId: string): Message { + return new Message({ role: 'assistant', content: [new ToolUseBlock({ toolUseId, name: 'test', input: {} })] }) +} + +function toolResultMsg(toolUseId: string): Message { + return new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId, content: [new TextBlock('result')], status: 'success' })], + }) +} + +describe('isProtected', () => { + describe('no range, no pin', () => { + it('returns false for unprotected message', () => { + const messages = [userMsg('a'), assistantMsg('b')] + expect(isProtected(messages, 0)).toBe(false) + expect(isProtected(messages, 1)).toBe(false) + }) + }) + + describe('positive range (protect first N)', () => { + it('protects messages within range', () => { + const messages = [userMsg('a'), assistantMsg('b'), userMsg('c')] + expect(isProtected(messages, 0, 2)).toBe(true) + expect(isProtected(messages, 1, 2)).toBe(true) + expect(isProtected(messages, 2, 2)).toBe(false) + }) + + it('protects toolUse outside range if its toolResult is inside range', () => { + const messages = [userMsg('task'), toolUseMsg('t1'), toolResultMsg('t1'), userMsg('next')] + // range=2 protects [0] and [1]. [2] is toolResult — check if toolUse at [1] being in range protects [2] + // Actually [2] is outside range. But [1] (toolUse) is in range, so [2] (toolResult, partner) should be protected. + expect(isProtected(messages, 2, 2)).toBe(true) + }) + + it('protects toolResult outside range if its toolUse is inside range', () => { + const messages = [toolUseMsg('t1'), toolResultMsg('t1'), userMsg('a'), assistantMsg('b')] + // range=1 protects [0] (toolUse). [1] (toolResult) is outside but partner is protected. + expect(isProtected(messages, 1, 1)).toBe(true) + }) + }) + + describe('pinned messages', () => { + it('protects pinned message regardless of range', () => { + const messages = [userMsg('a'), pinMessage(assistantMsg('pinned')), userMsg('c')] + expect(isProtected(messages, 1)).toBe(true) + expect(isProtected(messages, 1, 0)).toBe(true) + }) + + it('protects tool-pair partner of pinned message', () => { + const messages = [pinMessage(toolUseMsg('t1')), toolResultMsg('t1'), userMsg('a')] + expect(isProtected(messages, 1)).toBe(true) // toolResult partner of pinned toolUse + }) + }) +}) diff --git a/strands-ts/src/context-manager/compression/context-compression.ts b/strands-ts/src/context-manager/compression/context-compression.ts new file mode 100644 index 000000000..9c9cdbde7 --- /dev/null +++ b/strands-ts/src/context-manager/compression/context-compression.ts @@ -0,0 +1,178 @@ +import type { Plugin } from '../../plugins/plugin.js' +import type { LocalAgent } from '../../types/agent.js' +import type { Tool } from '../../tools/tool.js' +import type { Message } from '../../types/messages.js' +import type { Model } from '../../models/model.js' +import { AfterInvocationEvent, AfterModelCallEvent, BeforeModelCallEvent } from '../../hooks/events.js' +import { ContextWindowOverflowError } from '../../errors.js' +import { truncate, type TruncateOptions } from './strategies/truncate.js' +import { summarize, type SummarizeOptions } from './strategies/summarize.js' +import { estimateInputTokens } from '../token-estimation.js' +import { logger } from '../../logging/logger.js' +import { warnOnce } from '../../logging/warn-once.js' + +const DEFAULT_CONTEXT_WINDOW_LIMIT = 200_000 +const DEFAULT_PROACTIVE_THRESHOLD = 0.7 +const DEFAULT_WINDOW_SIZE = 40 + +export type CompressionMethod = 'truncate' | 'summarize' + +type SharedCompressionOptions = { + /** + * Proactive compression before the model call. + * - `true`: compress when 70% of the context window is used (default threshold). + * - `{ threshold: number }`: compress at the specified ratio (0, 1]. + * - `false`: disable proactive compression; only reactive overflow recovery is used. + * - Omitted: defaults to `true`. + */ + proactive?: boolean | { threshold: number } + /** + * Number of messages at the start of the conversation to protect from eviction. + * These messages are never trimmed or summarized away. + */ + protectFirst?: number +} + +export type TruncateCompressionConfig = SharedCompressionOptions & { + method?: 'truncate' + /** Maximum messages to keep after trimming. Defaults to 40. */ + windowSize?: number +} + +export type SummarizeCompressionConfig = SharedCompressionOptions & { + method: 'summarize' + /** Ratio of messages to summarize (0.1–0.8). Defaults to 0.3. */ + summaryRatio?: number + /** Minimum recent messages to preserve during summarization. Defaults to 10. */ + preserveRecentMessages?: number +} + +/** + * Compression configuration (discriminated union on `method`). + * + * @example + * ```typescript + * contextManager: { compression: true } // defaults (truncate) + * contextManager: { compression: 'summarize' } // strategy shorthand + * contextManager: { compression: { method: 'truncate', windowSize: 30 } } // full config + * contextManager: { compression: { method: 'summarize', summaryRatio: 0.5 } } // full config + * ``` + */ +export type CompressionOptions = TruncateCompressionConfig | SummarizeCompressionConfig + +/** + * Plugin that handles context compression — both proactive (before model call when + * threshold is exceeded) and reactive (after model call on overflow error). + * + * Delegates reduction to strategy functions (truncate or summarize). + */ +export class ContextCompression implements Plugin { + readonly name = 'strands:context-compression' + + private readonly _proactiveThreshold: number | undefined + private readonly _method: CompressionMethod + private readonly _windowSize: number + private readonly _truncateOptions: TruncateOptions + private readonly _summarizeOptions: SummarizeOptions + + constructor(config?: CompressionOptions) { + const proactive = config?.proactive ?? true + if (proactive === false) { + this._proactiveThreshold = undefined + } else if (proactive === true) { + this._proactiveThreshold = DEFAULT_PROACTIVE_THRESHOLD + } else { + if (proactive.threshold <= 0 || proactive.threshold > 1) { + throw new Error( + `proactive compression threshold must be between 0 (exclusive) and 1 (inclusive), got ${proactive.threshold}` + ) + } + this._proactiveThreshold = proactive.threshold + } + + this._method = config?.method ?? 'truncate' + this._windowSize = (config as TruncateCompressionConfig | undefined)?.windowSize ?? DEFAULT_WINDOW_SIZE + + this._truncateOptions = { + ...(config?.protectFirst !== undefined && { protectFirst: config.protectFirst }), + } + this._summarizeOptions = { + ...(config?.protectFirst !== undefined && { protectFirst: config.protectFirst }), + ...((config as SummarizeCompressionConfig)?.summaryRatio !== undefined && { + summaryRatio: (config as SummarizeCompressionConfig).summaryRatio, + }), + ...((config as SummarizeCompressionConfig)?.preserveRecentMessages !== undefined && { + preserveRecentMessages: (config as SummarizeCompressionConfig).preserveRecentMessages, + }), + } + } + + getTools(): Tool[] { + return [] + } + + initAgent(agent: LocalAgent): void { + // Reactive overflow recovery + agent.addHook(AfterModelCallEvent, async (event) => { + if (event.error instanceof ContextWindowOverflowError) { + if (await this._reduce(event.agent.messages, event.model)) { + event.retry = true + } + } + }) + + // Proactive compression + agent.addHook(BeforeModelCallEvent, async (event) => { + if (this._proactiveThreshold === undefined) { + return + } + + let contextWindowLimit = event.model.getConfig().contextWindowLimit + if (contextWindowLimit === undefined) { + contextWindowLimit = DEFAULT_CONTEXT_WINDOW_LIMIT + warnOnce( + logger, + `context_compression | contextWindowLimit is not set on the model, using default of ${DEFAULT_CONTEXT_WINDOW_LIMIT} | set contextWindowLimit in your model config for accurate proactive compression` + ) + } + + const projectedInputTokens = + event.projectedInputTokens ?? (await estimateInputTokens(event.agent.messages, event.model)) + + if (projectedInputTokens === undefined) { + return + } + + const ratio = projectedInputTokens / contextWindowLimit + if (ratio >= this._proactiveThreshold) { + logger.debug( + `projected_tokens=<${projectedInputTokens}>, limit=<${contextWindowLimit}>, ratio=<${ratio.toFixed(2)}>, threshold=<${this._proactiveThreshold}> | compression threshold exceeded, reducing context` + ) + try { + await this._reduce(event.agent.messages, event.model) + } catch (e) { + logger.warn(`context_compression | proactive compression failed, continuing | error=<${e}>`) + } + } + }) + + // Sliding window enforcement after each invocation (truncate method only) + if (this._method === 'truncate') { + agent.addHook(AfterInvocationEvent, (event) => { + if (event.agent.messages.length > this._windowSize) { + truncate(event.agent.messages, this._windowSize, this._truncateOptions) + } + }) + } + } + + private async _reduce(messages: Message[], model: Model): Promise { + switch (this._method) { + case 'summarize': + return summarize(messages, model, this._summarizeOptions) + case 'truncate': + default: + return truncate(messages, this._windowSize, this._truncateOptions) + } + } +} diff --git a/strands-ts/src/context-manager/compression/protection.ts b/strands-ts/src/context-manager/compression/protection.ts new file mode 100644 index 000000000..ee845b84e --- /dev/null +++ b/strands-ts/src/context-manager/compression/protection.ts @@ -0,0 +1,134 @@ +import { z } from 'zod' +import { Message, type ToolUseBlock, type ToolResultBlock } from '../../types/messages.js' +import { tool } from '../../tools/tool-factory.js' + +// --- Pin utilities --- + +/** + * Check if a single message is pinned. + * + * @param message - The message to check + * @returns `true` if the message has `metadata.custom.pinned === true` + */ +export function isPinned(message: Message): boolean +/** + * Check if a message is pinned, including tool-pair partner protection. + * Returns `true` if the message at `index` is pinned, or if it is the + * adjacent tool-pair partner (toolUse/toolResult) of a pinned message, + * matched by toolUseId. + * + * @param messages - The full messages array + * @param index - The index to check + * @returns `true` if the message or its tool-pair partner is pinned + */ +export function isPinned(messages: Message[], index: number): boolean +export function isPinned(messageOrMessages: Message | Message[], index?: number): boolean { + if (index === undefined) { + return (messageOrMessages as Message).metadata?.custom?.pinned === true + } + + const messages = messageOrMessages as Message[] + const msg = messages[index]! + if (msg.metadata?.custom?.pinned === true) return true + + const toolResultBlocks = msg.content.filter((b): b is ToolResultBlock => b.type === 'toolResultBlock') + if (toolResultBlocks.length > 0 && index > 0) { + const prev = messages[index - 1]! + if (prev.metadata?.custom?.pinned === true) { + const resultIds = new Set(toolResultBlocks.map((b) => b.toolUseId)) + if (prev.content.some((b) => b.type === 'toolUseBlock' && resultIds.has((b as ToolUseBlock).toolUseId))) { + return true + } + } + } + + const toolUseBlocks = msg.content.filter((b): b is ToolUseBlock => b.type === 'toolUseBlock') + if (toolUseBlocks.length > 0 && index + 1 < messages.length) { + const next = messages[index + 1]! + if (next.metadata?.custom?.pinned === true) { + const useIds = new Set(toolUseBlocks.map((b) => b.toolUseId)) + if (next.content.some((b) => b.type === 'toolResultBlock' && useIds.has((b as ToolResultBlock).toolUseId))) { + return true + } + } + } + + return false +} + +/** + * Returns a new Message marked as pinned (protected from eviction during context reduction). + */ +export function pinMessage(message: Message): Message { + return new Message({ + role: message.role, + content: message.content, + metadata: { + ...message.metadata, + custom: { ...message.metadata?.custom, pinned: true }, + }, + }) +} + +/** + * Returns a new Message with pinning removed. + */ +export function unpinMessage(message: Message): Message { + const { pinned: _, ...restCustom } = message.metadata?.custom ?? {} + const { custom: __, ...restMetadata } = message.metadata ?? {} + const hasCustom = Object.keys(restCustom).length > 0 + const hasMetadata = hasCustom || Object.keys(restMetadata).length > 0 + const metadata = hasMetadata ? { ...restMetadata, ...(hasCustom ? { custom: restCustom } : {}) } : undefined + + return new Message({ + role: message.role, + content: message.content, + ...(metadata !== undefined ? { metadata } : {}), + }) +} + +/** + * Agent-invokable tool that pins or unpins a message in the conversation history. + */ +export const pinMessageTool = tool({ + name: 'pin_message', + description: + 'Pin or unpin a message in the conversation history. ' + + 'Pinned messages are protected from eviction during context reduction. ' + + 'Use this to preserve important context that should not be summarized or trimmed away.', + inputSchema: z.object({ + index: z.number().int().min(0).describe('The zero-based index of the message in the conversation history.'), + action: z.enum(['pin', 'unpin']).default('pin').describe('Whether to pin or unpin the message.'), + }), + callback: ({ index, action }, context) => { + const messages = context!.agent.messages + if (index >= messages.length) { + return `Invalid index ${index}. Conversation has ${messages.length} messages (indices 0-${messages.length - 1}).` + } + messages[index] = action === 'pin' ? pinMessage(messages[index]!) : unpinMessage(messages[index]!) + return `${action === 'pin' ? 'Pinned' : 'Unpinned'} message at index ${index}.` + }, +}) + +// --- Protection check --- + +/** + * Check if a message at the given index is protected from eviction. + * A message is protected if it is pinned, within the protectFirst range, + * or is a tool-pair partner of a protected message. + */ +export function isProtected(messages: Message[], index: number, protectFirst?: number): boolean { + if (isPinned(messages, index)) return true + if (protectFirst === undefined || protectFirst <= 0) return false + if (index < protectFirst) return true + + // Tool-pair partner: protect a toolResult just outside the range if its + // preceding toolUse is inside the range (prevents orphaning) + const msg = messages[index]! + if (msg.content.some((b) => b.type === 'toolResultBlock') && index > 0 && index - 1 < protectFirst) { + const prev = messages[index - 1]! + if (prev.content.some((b) => b.type === 'toolUseBlock')) return true + } + + return false +} diff --git a/strands-ts/src/context-manager/compression/strategies/__tests__/summarize.test.ts b/strands-ts/src/context-manager/compression/strategies/__tests__/summarize.test.ts new file mode 100644 index 000000000..56d39fb3d --- /dev/null +++ b/strands-ts/src/context-manager/compression/strategies/__tests__/summarize.test.ts @@ -0,0 +1,227 @@ +import { describe, it, expect, vi } from 'vitest' +import { summarize } from '../summarize.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../../../types/messages.js' +import { pinMessage } from '../../protection.js' +import type { Model } from '../../../../models/model.js' + +function userMsg(text: string): Message { + return new Message({ role: 'user', content: [new TextBlock(text)] }) +} + +function assistantMsg(text: string): Message { + return new Message({ role: 'assistant', content: [new TextBlock(text)] }) +} + +function createMockModel(summaryText = '## Summary\n* Conversation summary'): Model { + const streamAggregated = vi.fn().mockImplementation(() => { + let callCount = 0 + return { + next: () => { + callCount++ + if (callCount === 1) { + return Promise.resolve({ + done: true, + value: { + message: new Message({ role: 'assistant', content: [new TextBlock(summaryText)] }), + stopReason: 'endTurn', + }, + }) + } + return Promise.resolve({ done: true, value: undefined }) + }, + } + }) + + return { streamAggregated } as unknown as Model +} + +describe('summarize', () => { + it('summarizes oldest messages and replaces them with summary', async () => { + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + userMsg('Message 2'), + assistantMsg('Response 2'), + userMsg('Message 3'), + assistantMsg('Response 3'), + userMsg('Message 4'), + assistantMsg('Response 4'), + userMsg('Message 5'), + assistantMsg('Response 5'), + userMsg('Message 6'), + assistantMsg('Response 6'), + ] + const model = createMockModel() + + const result = await summarize(messages, model) + + expect(result).toBe(true) + // Default summaryRatio is 0.3, 12 * 0.3 = 3.6 -> 3 messages summarized + // But adjustSplitForToolPairs may adjust. Summary replaces those. + expect(messages.length).toBeLessThan(12) + // First message should be the summary (since no protected messages) + expect(messages[0]!.role).toBe('user') + }) + + it('preserves recent messages', async () => { + const messages = [ + userMsg('Old 1'), + assistantMsg('Old response 1'), + userMsg('Old 2'), + assistantMsg('Old response 2'), + userMsg('Old 3'), + assistantMsg('Old response 3'), + userMsg('Old 4'), + assistantMsg('Old response 4'), + userMsg('Old 5'), + assistantMsg('Old response 5'), + userMsg('Recent 1'), + assistantMsg('Recent response 1'), + ] + const model = createMockModel() + + await summarize(messages, model, { preserveRecentMessages: 10 }) + + // With 12 messages and preserveRecentMessages=10, count=max(1, floor(12*0.3))=3 + // count = min(3, 12-10)=2 -> summarize first 2 messages + // After: summary + 10 remaining = 11 + expect(messages.length).toBeLessThan(12) + // Recent messages should still be there + const lastMsg = messages[messages.length - 1]! + expect((lastMsg.content[0] as TextBlock).text).toBe('Recent response 1') + }) + + it('respects protectFirst (keeps protected messages verbatim)', async () => { + const messages = [ + userMsg('System instruction'), // index 0 - protected + assistantMsg('Acknowledged'), // index 1 - protected + userMsg('Old message'), + assistantMsg('Old response'), + userMsg('Message 3'), + assistantMsg('Response 3'), + userMsg('Message 4'), + assistantMsg('Response 4'), + userMsg('Message 5'), + assistantMsg('Response 5'), + userMsg('Recent'), + assistantMsg('Recent response'), + ] + const model = createMockModel() + + const result = await summarize(messages, model, { protectFirst: 2, preserveRecentMessages: 6 }) + + expect(result).toBe(true) + // Protected messages should be preserved verbatim at the start + expect((messages[0]!.content[0] as TextBlock).text).toBe('System instruction') + expect((messages[1]!.content[0] as TextBlock).text).toBe('Acknowledged') + }) + + it('returns false when insufficient messages to summarize', async () => { + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + const model = createMockModel() + + // preserveRecentMessages=10, count = min(floor(2*0.3)=0 -> max(1,0)=1, 2-10=-8) -> 0 -> returns false + const result = await summarize(messages, model, { preserveRecentMessages: 10 }) + + expect(result).toBe(false) + expect(messages).toHaveLength(2) + }) + + it('returns false when all messages in range are protected', async () => { + const messages = [ + pinMessage(userMsg('Pinned 1')), + pinMessage(assistantMsg('Pinned 2')), + userMsg('Message 3'), + assistantMsg('Response 3'), + userMsg('Message 4'), + assistantMsg('Response 4'), + userMsg('Message 5'), + assistantMsg('Response 5'), + userMsg('Message 6'), + assistantMsg('Response 6'), + userMsg('Recent'), + assistantMsg('Recent response'), + ] + const model = createMockModel() + + // protectFirst=2, summaryRatio=0.15 -> count=max(1, floor(12*0.15))=1 + // Only index 0 is in range, and it's pinned. hasUnprotected=false -> returns false + const result = await summarize(messages, model, { protectFirst: 2, summaryRatio: 0.08 }) + + expect(result).toBe(false) + }) + + it('summary is generated from ALL messages in range including protected', async () => { + const messages = [ + pinMessage(userMsg('Important context')), + assistantMsg('Response to important'), + userMsg('Regular message'), + assistantMsg('Regular response'), + userMsg('Message 3'), + assistantMsg('Response 3'), + userMsg('Message 4'), + assistantMsg('Response 4'), + userMsg('Message 5'), + assistantMsg('Response 5'), + userMsg('Recent'), + assistantMsg('Recent response'), + ] + + const streamAggregated = vi.fn().mockImplementation((_input: Message[]) => { + let callCount = 0 + return { + next: () => { + callCount++ + if (callCount === 1) { + return Promise.resolve({ + done: true, + value: { + message: new Message({ role: 'assistant', content: [new TextBlock('Summary')] }), + stopReason: 'endTurn', + }, + }) + } + return Promise.resolve({ done: true, value: undefined }) + }, + } + }) + const model = { streamAggregated } as unknown as Model + + await summarize(messages, model, { protectFirst: 1, summaryRatio: 0.3, preserveRecentMessages: 6 }) + + // streamAggregated should have been called with messages that include the pinned one + expect(streamAggregated).toHaveBeenCalled() + const inputMessages = streamAggregated.mock.calls[0]![0] as Message[] + // The input should include the pinned message at index 0 + const hasPinnedContent = inputMessages.some((m) => + m.content.some((b) => b.type === 'textBlock' && (b as TextBlock).text === 'Important context') + ) + expect(hasPinnedContent).toBe(true) + }) + + it('adjusts split point to avoid breaking tool pairs', async () => { + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + new Message({ role: 'assistant', content: [new ToolUseBlock({ toolUseId: 't1', name: 'test', input: {} })] }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 't1', content: [new TextBlock('result')], status: 'success' })], + }), + userMsg('Message after tool'), + assistantMsg('Response after tool'), + userMsg('Message 3'), + assistantMsg('Response 3'), + userMsg('Message 4'), + assistantMsg('Response 4'), + userMsg('Recent'), + assistantMsg('Recent response'), + ] + const model = createMockModel() + + const result = await summarize(messages, model, { summaryRatio: 0.3, preserveRecentMessages: 4 }) + + expect(result).toBe(true) + // The split should not break the toolUse/toolResult pair at indices 2-3 + }) +}) diff --git a/strands-ts/src/context-manager/compression/strategies/__tests__/truncate.test.ts b/strands-ts/src/context-manager/compression/strategies/__tests__/truncate.test.ts new file mode 100644 index 000000000..4c78a6c10 --- /dev/null +++ b/strands-ts/src/context-manager/compression/strategies/__tests__/truncate.test.ts @@ -0,0 +1,152 @@ +import { describe, it, expect } from 'vitest' +import { truncate } from '../truncate.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../../../types/messages.js' +import { pinMessage } from '../../protection.js' + +function userMsg(text: string): Message { + return new Message({ role: 'user', content: [new TextBlock(text)] }) +} + +function assistantMsg(text: string): Message { + return new Message({ role: 'assistant', content: [new TextBlock(text)] }) +} + +function toolUseMsg(toolUseId: string): Message { + return new Message({ role: 'assistant', content: [new ToolUseBlock({ toolUseId, name: 'test', input: {} })] }) +} + +function toolResultMsg(toolUseId: string): Message { + return new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId, content: [new TextBlock('result')], status: 'success' })], + }) +} + +describe('truncate', () => { + it('removes oldest messages when over window size', () => { + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + userMsg('Message 2'), + assistantMsg('Response 2'), + userMsg('Message 3'), + assistantMsg('Response 3'), + ] + + const result = truncate(messages, 4) + + expect(result).toBe(true) + expect(messages).toHaveLength(4) + expect((messages[0]!.content[0] as TextBlock).text).toBe('Message 2') + }) + + it('preserves tool use/result pairs', () => { + const messages = [ + userMsg('Message 1'), + toolUseMsg('t1'), + toolResultMsg('t1'), + userMsg('Message 2'), + assistantMsg('Response 2'), + ] + + // Window size 2, trimIndex starts at 3 (5 - 2 = 3), which is user 'Message 2' - valid + const result = truncate(messages, 2) + + expect(result).toBe(true) + // Should trim at a point that doesn't break tool use/result pairs + expect(messages[0]!.role).toBe('user') + }) + + it('returns false when messages.length <= 2', () => { + const messages = [userMsg('Message 1'), assistantMsg('Response 1')] + + const result = truncate(messages, 1) + + expect(result).toBe(false) + expect(messages).toHaveLength(2) + }) + + it('returns false for single message', () => { + const messages = [userMsg('only')] + + const result = truncate(messages, 0) + + expect(result).toBe(false) + expect(messages).toHaveLength(1) + }) + + it('respects protectFirst', () => { + const messages = [ + userMsg('Protected 1'), + assistantMsg('Protected 2'), + userMsg('Message 3'), + assistantMsg('Response 3'), + userMsg('Message 4'), + assistantMsg('Response 4'), + ] + + const result = truncate(messages, 3, { protectFirst: 2 }) + + expect(result).toBe(true) + // First 2 messages should still be there + expect((messages[0]!.content[0] as TextBlock).text).toBe('Protected 1') + expect((messages[1]!.content[0] as TextBlock).text).toBe('Protected 2') + }) + + it('returns false when all messages in trim range are protected', () => { + const messages = [ + pinMessage(userMsg('Pinned 1')), + pinMessage(assistantMsg('Pinned 2')), + pinMessage(userMsg('Pinned 3')), + assistantMsg('Response'), + userMsg('Last'), + ] + + // trimIndex will be in range where all are pinned + const result = truncate(messages, 4) + + // The trim range [0, trimIndex) only contains pinned messages + // trimIndex = max(2, 5-4) = 2, range [0,2) are pinned -> returns false... + // Actually trimIndex = max(2, 5-4=1) = 2, but findValidTrimPoint may adjust. + // Let's see: trimIndex = max(messages.length - windowSize, 2) = max(1, 2) = 2 + // findValidTrimPoint(messages, 2): messages[2] is pinned user, role=user, no toolResult. + // So trimIndex = 2, range [0,2) -> indices 0,1 are pinned -> returns false + expect(result).toBe(false) + }) + + it('does not orphan toolResult at trim boundary', () => { + const messages = [ + userMsg('Start'), + toolUseMsg('t1'), + toolResultMsg('t1'), + userMsg('After tools'), + assistantMsg('Response'), + ] + + // windowSize=2, trimIndex = 5-2=3, message[3] is user 'After tools', no toolResult -> valid + const result = truncate(messages, 2) + + expect(result).toBe(true) + // First message after trim should not be a tool result + const firstMsg = messages[0]! + const hasToolResult = firstMsg.content.some((b) => b.type === 'toolResultBlock') + expect(hasToolResult).toBe(false) + }) + + it('skips assistant messages to find valid user trim point', () => { + const messages = [ + userMsg('Message 1'), + assistantMsg('Response 1'), + assistantMsg('Response 2'), // Non-user at potential trim index + userMsg('Message 2'), + assistantMsg('Response 3'), + ] + + // windowSize=2, trimIndex = 5-2=3, messages[3] is user 'Message 2' -> valid + const result = truncate(messages, 2) + + expect(result).toBe(true) + expect(messages[0]!.role).toBe('user') + expect((messages[0]!.content[0] as TextBlock).text).toBe('Message 2') + }) +}) diff --git a/strands-ts/src/context-manager/compression/strategies/summarize.ts b/strands-ts/src/context-manager/compression/strategies/summarize.ts new file mode 100644 index 000000000..6c20e2a3b --- /dev/null +++ b/strands-ts/src/context-manager/compression/strategies/summarize.ts @@ -0,0 +1,149 @@ +import { Message, TextBlock } from '../../../types/messages.js' +import type { Model } from '../../../models/model.js' +import { isProtected } from '../protection.js' +import { logger } from '../../../logging/logger.js' + +const SUMMARIZATION_PROMPT = `You are a conversation summarizer. Provide a concise summary of the conversation history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. +- You MUST NOT comment on tool availability. + +Assumptions: +- You MUST NOT assume tool executions failed unless otherwise stated. + +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information + +## Tools Executed +* Tool X: Result Y` + +export type SummarizeOptions = { + /** Ratio of messages to summarize (0.1–0.8). Defaults to 0.3. */ + summaryRatio?: number + /** Minimum recent messages to preserve. Defaults to 10. */ + preserveRecentMessages?: number + /** Positive: protect first N messages. Negative: protect last N messages. */ + protectFirst?: number +} + +/** + * Summarize the oldest messages and replace them with a model-generated summary. + * + * @param messages - The messages array to mutate in place + * @param model - The model to use for generating the summary + * @param options - Summarization options + * @returns `true` if messages were summarized, `false` if not enough to summarize + */ +export async function summarize(messages: Message[], model: Model, options?: SummarizeOptions): Promise { + const summaryRatio = Math.max(0.1, Math.min(0.8, options?.summaryRatio ?? 0.3)) + const preserveRecent = options?.preserveRecentMessages ?? 10 + const protectedRange = options?.protectFirst + + let count = Math.max(1, Math.floor(messages.length * summaryRatio)) + count = Math.min(count, messages.length - preserveRecent) + + if (count <= 0) { + logger.warn( + `preserve_recent=<${preserveRecent}>, messages=<${messages.length}> | insufficient messages for summarization` + ) + return false + } + + count = adjustSplitForToolPairs(messages, count) + + // Identify protected messages that must survive verbatim + const protectedToPreserve: Message[] = [] + let hasUnprotected = false + for (let i = 0; i < count; i++) { + if (isProtected(messages, i, protectedRange)) { + protectedToPreserve.push(messages[i]!) + } else { + hasUnprotected = true + } + } + + if (!hasUnprotected) { + logger.warn(`messages=<${messages.length}> | all messages in summarize range are protected, unable to reduce`) + return false + } + + // Summarize ALL messages in range (including protected) for full context + const summary = await generateSummary(messages.slice(0, count), model) + + // Replace range with protected messages (verbatim) + summary + messages.splice(0, count, ...protectedToPreserve, summary) + return true +} + +async function generateSummary(messagesToSummarize: Message[], model: Model): Promise { + const input = [ + ...messagesToSummarize, + new Message({ role: 'user', content: [new TextBlock('Please summarize this conversation.')] }), + ] + + const stream = model.streamAggregated(input, { systemPrompt: SUMMARIZATION_PROMPT }) + + let result: Awaited> | undefined + for (;;) { + result = await stream.next() + if (result.done) break + } + + if (!result?.done || !result.value) { + throw new Error('Failed to generate summary: no response from model') + } + + const summaryText = result.value.message.content + .filter((block) => block.type === 'textBlock') + .map((block) => (block as TextBlock).text) + .join('\n') + + return new Message({ + role: 'user', + content: [new TextBlock(`\n${summaryText}\n`)], + }) +} + +/** + * Adjust split point forward to avoid breaking tool use/result pairs. + */ +function adjustSplitForToolPairs(messages: Message[], splitPoint: number): number { + if (splitPoint >= messages.length) return splitPoint + + let idx = splitPoint + while (idx < messages.length) { + const msg = messages[idx]! + + if (msg.content.some((b) => b.type === 'toolResultBlock')) { + idx++ + continue + } + + const hasToolUse = msg.content.some((b) => b.type === 'toolUseBlock') + if (hasToolUse) { + const next = messages[idx + 1] + if (!next?.content.some((b) => b.type === 'toolResultBlock')) { + idx++ + continue + } + } + + break + } + + return idx >= messages.length ? splitPoint : idx +} diff --git a/strands-ts/src/context-manager/compression/strategies/truncate.ts b/strands-ts/src/context-manager/compression/strategies/truncate.ts new file mode 100644 index 000000000..cac684937 --- /dev/null +++ b/strands-ts/src/context-manager/compression/strategies/truncate.ts @@ -0,0 +1,85 @@ +import type { Message } from '../../../types/messages.js' +import { isProtected } from '../protection.js' +import { logger } from '../../../logging/logger.js' + +export type TruncateOptions = { + /** Positive: protect first N messages. Negative: protect last N messages. */ + protectFirst?: number +} + +/** + * Truncate oldest messages from the conversation, preserving tool use/result pairs. + * Protected messages (by range) are never removed. + * + * @param messages - The messages array to mutate in place + * @param windowSize - Maximum messages to keep + * @param options - Options including protectFirst + * @returns `true` if messages were removed, `false` if no valid trim point found + */ +export function truncate(messages: Message[], windowSize: number, options?: TruncateOptions): boolean { + if (messages.length <= 2) return false + + const protectedRange = options?.protectFirst + + let trimIndex = messages.length <= windowSize ? 2 : messages.length - windowSize + trimIndex = findValidTrimPoint(messages, trimIndex) + + if (trimIndex >= messages.length) { + logger.warn(`window_size=<${windowSize}>, messages=<${messages.length}> | unable to trim, no valid trim point`) + return false + } + + // Collect non-protected indices in [0, trimIndex) to remove + const indicesToRemove: number[] = [] + for (let i = 0; i < trimIndex; i++) { + if (isProtected(messages, i, protectedRange)) continue + indicesToRemove.push(i) + } + + if (indicesToRemove.length === 0) { + logger.warn( + `window_size=<${windowSize}>, messages=<${messages.length}> | all messages in trim range are protected, unable to reduce` + ) + return false + } + + // Remove in reverse order to keep indices stable + for (let i = indicesToRemove.length - 1; i >= 0; i--) { + messages.splice(indicesToRemove[i]!, 1) + } + return true +} + +/** + * Find a valid trim point starting from the given index. + * Skips positions that would leave orphaned toolResults or toolUse without a following toolResult. + */ +function findValidTrimPoint(messages: Message[], startIndex: number): number { + let idx = startIndex + while (idx < messages.length) { + const msg = messages[idx] + if (!msg) break + + if (msg.role !== 'user') { + idx++ + continue + } + + if (msg.content.some((b) => b.type === 'toolResultBlock')) { + idx++ + continue + } + + const hasToolUse = msg.content.some((b) => b.type === 'toolUseBlock') + if (hasToolUse) { + const next = messages[idx + 1] + if (!next || !next.content.some((b) => b.type === 'toolResultBlock')) { + idx++ + continue + } + } + + break + } + return idx +} diff --git a/strands-ts/src/context-manager/context-manager.ts b/strands-ts/src/context-manager/context-manager.ts new file mode 100644 index 000000000..f27624bcb --- /dev/null +++ b/strands-ts/src/context-manager/context-manager.ts @@ -0,0 +1,182 @@ +import type { Storage } from '../vended-plugins/context-offloader/storage.js' +import type { Plugin } from '../plugins/plugin.js' +import type { Tool } from '../tools/tool.js' +import type { LocalAgent } from '../types/agent.js' +import { ContextCompression } from './compression/context-compression.js' +import { ContextOffloader } from '../vended-plugins/context-offloader/plugin.js' +import { InMemoryStorage } from '../vended-plugins/context-offloader/storage.js' + +export type ContextStrategyValue = 'auto' + +/** + * Configuration for the offloader component. + */ +export type OffloaderConfig = { + /** Token threshold above which tool results are offloaded. Defaults to 2500. */ + threshold?: number + /** Number of tokens to keep as an inline preview. Defaults to 1500. */ + previewTokens?: number +} + +/** + * Compression configuration accepted by contextManager. + * - `true`: enable with defaults (truncate, proactive at 0.7). + * - `'truncate'` / `'summarize'`: enable specific strategy with defaults. + * - Object: full config with strategy and options. + * - Omitted: disabled. + */ +export type CompressionConfig = + | true + | import('./compression/context-compression.js').CompressionMethod + | import('./compression/context-compression.js').CompressionOptions + +/** + * Configuration accepted by the {@link ContextManager} constructor. + * + * Config objects are additive — only features you explicitly set are enabled. + * Use `"auto"` to enable everything with defaults. + */ +export type ContextManagerConfig = { + /** Strategy name. Only "auto" is supported currently. */ + strategy?: ContextStrategyValue + /** Storage backend for cached tool results. Defaults to InMemoryStorage. */ + storage?: Storage + /** + * Context offloader configuration. + * - `true`: enable with defaults (threshold=2500, previewTokens=500). + * - Object: enable with custom settings. + * - Omitted: disabled. + */ + offloader?: true | OffloaderConfig + /** + * Compression configuration. + * - `true`: enable with defaults (truncate, proactive at 0.7). + * - `'truncate'` / `'summarize'`: enable specific strategy with defaults. + * - `CompressionStrategy.Truncate(...)` / `CompressionStrategy.Summarize(...)`: full config. + * - Omitted: disabled. + */ + compression?: CompressionConfig +} + +/** + * The `contextManager` parameter type accepted by AgentConfig. + * + * - `"auto"`: enables everything with defaults. + * - `{ strategy: 'auto', ... }`: auto with overrides (omitted features stay enabled). + * - `{ compression: true }`: additive — only what you set is enabled. + * - `undefined` (default): no context management facade. + */ +export type ContextManagerParam = ContextStrategyValue | ContextManagerConfig + +/** + * Pre-composed context management for agents. + * + * Internal plugin that composes sub-plugins (ContextCompression, ContextOffloader) + * for compression and caching behavior. + * + * @example + * ```typescript + * const agent = new Agent({ contextManager: "auto" }) + * ``` + */ +export class ContextManager implements Plugin { + readonly name = 'strands:context-manager' + readonly storage: Storage + + private readonly _config: ContextManagerConfig + private _subPlugins: Plugin[] | undefined + + constructor(config?: ContextManagerConfig) { + this._config = config ?? {} + this.storage = this._config.storage ?? new InMemoryStorage() + } + + /** + * Resolve sub-plugins, skipping any that the user already provides. + * Called once before plugin initialization. + * @internal + */ + _resolveSubPlugins(userPlugins?: Plugin[]): void { + this._subPlugins = this._buildSubPlugins(userPlugins) + } + + getTools(): Tool[] { + const plugins = this._subPlugins ?? [] + const tools: Tool[] = [] + for (const plugin of plugins) { + if (plugin.getTools) { + tools.push(...plugin.getTools()) + } + } + return tools + } + + initAgent(agent: LocalAgent): void { + if (!this._subPlugins) { + this._subPlugins = this._buildSubPlugins() + } + + for (const plugin of this._subPlugins) { + plugin.initAgent(agent) + } + } + + private _buildSubPlugins(userPlugins?: Plugin[]): Plugin[] { + const config = this._config + const plugins: Plugin[] = [] + + if (config.compression) { + const userProvided = userPlugins?.some((p) => p.name === 'strands:context-compression') + if (!userProvided) { + let compressionConfig: import('./compression/context-compression.js').CompressionOptions | undefined + if (config.compression === true) { + compressionConfig = undefined + } else if (typeof config.compression === 'string') { + compressionConfig = { method: config.compression } + } else { + compressionConfig = config.compression + } + plugins.push(new ContextCompression(compressionConfig)) + } + } + + if (config.offloader) { + const userProvided = userPlugins?.some((p) => p.name === 'strands:context-offloader') + if (!userProvided) { + const offloaderConfig = config.offloader === true ? {} : config.offloader + plugins.push( + new ContextOffloader({ + storage: this.storage, + maxResultTokens: offloaderConfig.threshold ?? 2500, + previewTokens: offloaderConfig.previewTokens ?? 1500, + includeRetrievalTool: true, + }) + ) + } + } + + return plugins + } +} + +/** + * Resolve a `contextManager` parameter into a ContextManager plugin instance. + * User-provided plugins that overlap with sub-plugins take precedence. + * + * @param param - The contextManager config (strategy string, config object, or class instance) + * @param userPlugins - User-provided plugins array, used for dedup checking + * @internal + */ +const STRATEGY_DEFAULTS = { + auto: { compression: true, offloader: true }, +} satisfies Record> + +export function resolveContextManager(param: ContextManagerParam, userPlugins?: Plugin[]): ContextManager { + const base = typeof param === 'string' ? { strategy: param } : param + const defaults = base.strategy ? STRATEGY_DEFAULTS[base.strategy] : undefined + const config = defaults ? { ...defaults, ...base } : base + + const instance = new ContextManager(config) + instance._resolveSubPlugins(userPlugins) + return instance +} diff --git a/strands-ts/src/context-manager/token-estimation.ts b/strands-ts/src/context-manager/token-estimation.ts new file mode 100644 index 000000000..204bb61f6 --- /dev/null +++ b/strands-ts/src/context-manager/token-estimation.ts @@ -0,0 +1,31 @@ +import type { Message } from '../types/messages.js' +import type { Model } from '../models/model.js' +import { logger } from '../logging/logger.js' + +/** + * Estimate input tokens for a conversation. + * + * Uses an incremental strategy: if the last assistant message has usage metadata, + * uses (inputTokens + outputTokens) as a baseline and only counts new messages + * added after it. Falls back to full model estimation otherwise. + * + * @param messages - The conversation messages + * @param model - The model to use for token counting + * @returns Estimated token count, or undefined if estimation fails + */ +export async function estimateInputTokens(messages: Message[], model: Model): Promise { + try { + for (let i = messages.length - 1; i >= 0; i--) { + const usage = messages[i]!.metadata?.usage + if (messages[i]!.role === 'assistant' && usage) { + const baseline = usage.inputTokens + usage.outputTokens + const newMessages = messages.slice(i + 1) + return newMessages.length === 0 ? baseline : baseline + (await model.countTokens(newMessages)) + } + } + return await model.countTokens(messages) + } catch (e) { + logger.debug(`error=<${e}> | token estimation failed`) + return undefined + } +} diff --git a/strands-ts/src/index.ts b/strands-ts/src/index.ts index 46b1022c0..bad13b255 100644 --- a/strands-ts/src/index.ts +++ b/strands-ts/src/index.ts @@ -19,6 +19,10 @@ export type { ToolCaller, ToolCallerProxy, ToolHandle, DirectToolCallOptions } f export type { InvocationState, InvokeArgs, InvokeOptions, LocalAgent } from './types/agent.js' export type { LifecycleObserver } from './types/lifecycle-observer.js' +// Context Manager +export type { ContextManagerParam } from './context-manager/context-manager.js' +export { pinMessageTool } from './context-manager/compression/protection.js' + // Snapshot types export { SNAPSHOT_SCHEMA_VERSION } from './types/snapshot.js' export type { Scope, Snapshot } from './types/snapshot.js'