Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vs/platform/agentHost/common/agentService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export interface IAgentSessionMetadata {
readonly project?: IAgentSessionProjectInfo;
readonly summary?: string;
readonly status?: SessionStatus;
readonly model?: string;
readonly workingDirectory?: URI;
readonly isRead?: boolean;
readonly isDone?: boolean;
Expand Down
5 changes: 4 additions & 1 deletion src/vs/platform/agentHost/node/agentService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ export class AgentService extends Disposable implements IAgentService {
const withStatus = result.map(s => {
const liveState = this._stateManager.getSessionState(s.session.toString());
if (liveState) {
return { ...s, status: liveState.summary.status };
return { ...s, status: liveState.summary.status, model: liveState.summary.model ?? s.model };
}
return s;
});
Expand Down Expand Up @@ -230,6 +230,7 @@ export class AgentService extends Disposable implements IAgentService {
createdAt: Date.now(),
modifiedAt: Date.now(),
...(created.project ? { project: { uri: created.project.uri.toString(), displayName: created.project.displayName } } : {}),
model: config?.model,
workingDirectory: config.workingDirectory?.toString(),
};
const state = this._stateManager.createSession(summary);
Expand All @@ -244,6 +245,7 @@ export class AgentService extends Disposable implements IAgentService {
createdAt: Date.now(),
modifiedAt: Date.now(),
...(created.project ? { project: { uri: created.project.uri.toString(), displayName: created.project.displayName } } : {}),
model: config?.model,
workingDirectory: config?.workingDirectory?.toString(),
};
this._stateManager.createSession(summary);
Expand Down Expand Up @@ -423,6 +425,7 @@ export class AgentService extends Disposable implements IAgentService {
createdAt: meta.startTime,
modifiedAt: meta.modifiedTime,
...(meta.project ? { project: { uri: meta.project.uri.toString(), displayName: meta.project.displayName } } : {}),
model: meta.model,
workingDirectory: meta.workingDirectory?.toString(),
isRead,
isDone,
Expand Down
12 changes: 7 additions & 5 deletions src/vs/platform/agentHost/node/copilot/copilotAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ export class CopilotAgent extends Disposable implements IAgent {
const projectByContext = new Map<string, Promise<IAgentSessionProjectInfo | undefined>>();
const result: IAgentSessionMetadata[] = await Promise.all(sessions.map(async s => {
const session = AgentSession.uri(this.id, s.sessionId);
const metadata = await this._readSessionMetadata(session);
let { project, resolved } = await this._readSessionProject(session);
if (!resolved) {
project = await this._resolveSessionProject(s.context, projectLimiter, projectByContext);
Expand All @@ -177,6 +178,7 @@ export class CopilotAgent extends Disposable implements IAgent {
modifiedTime: s.modifiedTime.getTime(),
...(project ? { project } : {}),
summary: s.summary,
model: metadata.model,
workingDirectory: typeof s.context?.cwd === 'string' ? URI.file(s.context.cwd) : undefined,
};
}));
Expand Down Expand Up @@ -233,7 +235,7 @@ export class CopilotAgent extends Disposable implements IAgent {
const session = agentSession.sessionUri;
this._logService.info(`[Copilot] Forked session created: ${session.toString()}`);
const project = await projectFromCopilotContext({ cwd: config.workingDirectory?.fsPath });
this._storeSessionMetadata(session, undefined, config.workingDirectory, project, true);
this._storeSessionMetadata(session, config.model, config.workingDirectory, project, true);
return { session, ...(project ? { project } : {}) };
});
}
Expand Down Expand Up @@ -467,11 +469,12 @@ export class CopilotAgent extends Disposable implements IAgent {
const parsedPlugins = await this._plugins.getAppliedPlugins();

const sessionUri = AgentSession.uri(this.id, sessionId);
const storedMetadata = await this._readSessionMetadata(sessionUri);
const sessionMetadata = await client.getSessionMetadata(sessionId).catch(err => {
this._logService.warn(`[Copilot:${sessionId}] getSessionMetadata failed`, err);
return undefined;
});
const workingDirectory = typeof sessionMetadata?.context?.cwd === 'string' ? URI.file(sessionMetadata.context.cwd) : undefined;
const workingDirectory = typeof sessionMetadata?.context?.cwd === 'string' ? URI.file(sessionMetadata.context.cwd) : storedMetadata.workingDirectory;
const shellManager = this._instantiationService.createInstance(ShellManager, sessionUri);
const sessionConfig = this._buildSessionConfig(parsedPlugins, shellManager);

Expand All @@ -492,13 +495,12 @@ export class CopilotAgent extends Disposable implements IAgent {
}

this._logService.warn(`[Copilot:${sessionId}] Resume failed (session not found in SDK), recreating`);
const metadata = await this._readSessionMetadata(sessionUri);
const raw = await client.createSession({
...config,
sessionId,
streaming: true,
model: metadata.model,
workingDirectory: metadata.workingDirectory?.fsPath,
model: storedMetadata.model,
workingDirectory: workingDirectory?.fsPath,
});

return new CopilotSessionWrapper(raw);
Expand Down
1 change: 1 addition & 0 deletions src/vs/platform/agentHost/node/protocolServerHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ export class ProtocolServerHandler extends Disposable {
createdAt: s.startTime,
modifiedAt: s.modifiedTime,
...(s.project ? { project: { uri: s.project.uri.toString(), displayName: s.project.displayName } } : {}),
model: s.model,
workingDirectory: s.workingDirectory?.toString(),
isRead: s.isRead,
isDone: s.isDone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import assert from 'assert';
import { timeout } from '../../../../../base/common/async.js';
import { ISubscribeResult } from '../../../common/state/protocol/commands.js';
import type { IResponsePartAction, ISessionAddedNotification, ITitleChangedAction } from '../../../common/state/sessionActions.js';
import type { IModelChangedAction, IResponsePartAction, ISessionAddedNotification, ITitleChangedAction } from '../../../common/state/sessionActions.js';
import { PROTOCOL_VERSION } from '../../../common/state/sessionCapabilities.js';
import type { IListSessionsResult, INotificationBroadcastParams } from '../../../common/state/sessionProtocol.js';
import { PendingMessageKind, ResponsePartKind, type ISessionState } from '../../../common/state/sessionState.js';
Expand Down Expand Up @@ -118,6 +118,51 @@ suite('Protocol WebSocket — Session Features', function () {
assert.strictEqual(session.title, 'Persisted Title');
});

// ---- Session model --------------------------------------------------------

test('session model flows through create, subscribe, listSessions, and modelChanged', async function () {
this.timeout(10_000);

await client.call('initialize', { protocolVersion: PROTOCOL_VERSION, clientId: 'test-model-summary' });

const sessionUri = nextSessionUri();
await client.call('createSession', { session: sessionUri, provider: 'mock', model: 'mock-model' });

const addedNotif = await client.waitForNotification(n =>
n.method === 'notification' && (n.params as INotificationBroadcastParams).notification.type === 'notify/sessionAdded'
);
const addedSession = (addedNotif.params as INotificationBroadcastParams).notification as ISessionAddedNotification;
assert.strictEqual(addedSession.summary.model, 'mock-model');
const createdSessionUri = addedSession.summary.resource;

const initialSnapshot = await client.call<ISubscribeResult>('subscribe', { resource: createdSessionUri });
const initialState = initialSnapshot.snapshot.state as ISessionState;
assert.strictEqual(initialState.summary.model, 'mock-model');

const initialList = await client.call<IListSessionsResult>('listSessions');
assert.strictEqual(initialList.items.find(s => s.resource === createdSessionUri)?.model, 'mock-model');

client.notify('dispatchAction', {
clientSeq: 1,
action: {
type: 'session/modelChanged',
session: createdSessionUri,
model: 'mock-model-2',
},
});

const modelNotif = await client.waitForNotification(n => isActionNotification(n, 'session/modelChanged'));
const modelAction = getActionEnvelope(modelNotif).action as IModelChangedAction;
assert.strictEqual(modelAction.model, 'mock-model-2');

const updatedSnapshot = await client.call<ISubscribeResult>('subscribe', { resource: createdSessionUri });
const updatedState = updatedSnapshot.snapshot.state as ISessionState;
assert.strictEqual(updatedState.summary.model, 'mock-model-2');

const updatedList = await client.call<IListSessionsResult>('listSessions');
assert.strictEqual(updatedList.items.find(s => s.resource === createdSessionUri)?.model, 'mock-model-2');
});

// ---- Reasoning events ------------------------------------------------------

test('reasoning events produce reasoning response parts and append actions', async function () {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import { IChatInputPickerOptions } from '../../../../workbench/contrib/chat/brow
import { EnhancedModelPickerActionItem } from '../../../../workbench/contrib/chat/browser/widget/input/modelPickerActionItem2.js';
import { HoverPosition } from '../../../../base/browser/ui/hover/hoverWidget.js';
import { IContextKeyService, ContextKeyExpr } from '../../../../platform/contextkey/common/contextkey.js';
import { IStorageService, StorageScope, StorageTarget } from '../../../../platform/storage/common/storage.js';
import { Menus } from '../../../browser/menus.js';
import { ISessionsManagementService } from '../../../services/sessions/common/sessionsManagement.js';
import { ISessionsProvidersService } from '../../../services/sessions/browser/sessionsProvidersService.js';
Expand Down Expand Up @@ -187,7 +186,6 @@ class CopilotPickerActionViewItemContribution extends Disposable implements IWor
@ILanguageModelsService languageModelsService: ILanguageModelsService,
@ISessionsManagementService sessionsManagementService: ISessionsManagementService,
@ISessionsProvidersService sessionsProvidersService: ISessionsProvidersService,
@IStorageService storageService: IStorageService,
) {
super();

Expand Down Expand Up @@ -219,8 +217,6 @@ class CopilotPickerActionViewItemContribution extends Disposable implements IWor
const delegate: IModelPickerDelegate = {
currentModel,
setModel: (model: ILanguageModelChatMetadataAndIdentifier) => {
currentModel.set(model, undefined);
storageService.store('sessions.localModelPicker.selectedModelId', model.identifier, StorageScope.PROFILE, StorageTarget.MACHINE);
const session = sessionsManagementService.activeSession.get();
if (session) {
const provider = sessionsProvidersService.getProviders().find(p => p.id === session.providerId);
Expand All @@ -240,29 +236,24 @@ class CopilotPickerActionViewItemContribution extends Disposable implements IWor
const action = { id: 'sessions.modelPicker', label: '', enabled: true, class: undefined, tooltip: '', run: () => { } };
const modelPicker = instantiationService.createInstance(EnhancedModelPickerActionItem, action, delegate, pickerOptions);

// Initialize with remembered model or first available model
const rememberedModelId = storageService.get('sessions.localModelPicker.selectedModelId', StorageScope.PROFILE);
const initModel = () => {
const updatePickerModel = (session: ISession | undefined, sessionModelId: string | undefined) => {
const models = getAvailableModels(languageModelsService, sessionsManagementService);
modelPicker.setEnabled(models.length > 0);
if (!currentModel.get() && models.length > 0) {
const remembered = rememberedModelId ? models.find(m => m.identifier === rememberedModelId) : undefined;
delegate.setModel(remembered ?? models[0]);
}
currentModel.set(sessionModelId ? models.find(model => model.identifier === sessionModelId) : undefined, undefined);
};
Comment thread
roblourens marked this conversation as resolved.
initModel();
const updatePickerModelFromActiveSession = () => {
const session = sessionsManagementService.activeSession.get();
updatePickerModel(session, session?.modelId.get());
};
updatePickerModelFromActiveSession();

const disposableStore = new DisposableStore();
disposableStore.add(languageModelsService.onDidChangeLanguageModels(() => initModel()));
disposableStore.add(languageModelsService.onDidChangeLanguageModels(() => updatePickerModelFromActiveSession()));

// When the active session changes, push the selected model to the new session
disposableStore.add(autorun(reader => {
const session = sessionsManagementService.activeSession.read(reader);
const model = currentModel.read(reader);
if (session && model) {
const provider = sessionsProvidersService.getProviders().find(p => p.id === session.providerId);
provider?.setModel(session.sessionId, model.identifier);
}
const sessionModelId = session?.modelId.read(reader);
updatePickerModel(session, sessionModelId);
}));

return new PickerActionViewItem(modelPicker, disposableStore);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class LocalSessionAdapter implements ISession {
readonly updatedAt: ISettableObservable<Date>;
readonly status = observableValue<SessionStatus>('status', SessionStatus.Completed);
readonly changes = observableValue<readonly IChatSessionFileChange[]>('changes', []);
readonly modelId = observableValue<string | undefined>('modelId', undefined);
readonly modelId: ISettableObservable<string | undefined>;
readonly mode = observableValue<{ readonly id: string; readonly kind: string } | undefined>('mode', undefined);
readonly loading = observableValue('loading', false);
readonly isArchived = observableValue('isArchived', false);
Expand Down Expand Up @@ -89,6 +89,7 @@ class LocalSessionAdapter implements ISession {
this.createdAt = new Date(metadata.startTime);
this.title = observableValue('title', metadata.summary ?? `Session ${rawId.substring(0, 8)}`);
this.updatedAt = observableValue('updatedAt', new Date(metadata.modifiedTime));
this.modelId = observableValue<string | undefined>('modelId', metadata.model ? `${logicalSessionType}:${metadata.model}` : undefined);
this.lastTurnEnd = observableValue('lastTurnEnd', metadata.modifiedTime ? new Date(metadata.modifiedTime) : undefined);
this.description = observableValue('description', new MarkdownString().appendText(localize('localAgentHostDescription', "Local")));
this.workspace = observableValue('workspace', LocalAgentHostSessionsProvider.buildWorkspace(metadata.project, metadata.workingDirectory));
Expand Down Expand Up @@ -155,6 +156,12 @@ class LocalSessionAdapter implements ISession {
didChange = true;
}

const modelId = metadata.model ? `${this.sessionType}:${metadata.model}` : undefined;
if (modelId !== this.modelId.get()) {
this.modelId.set(modelId, undefined);
didChange = true;
}

return didChange;
}
}
Expand Down Expand Up @@ -209,6 +216,7 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
private _selectedModelId: string | undefined;
private _currentNewSession: ISession | undefined;
private _currentNewSessionStatus: ISettableObservable<SessionStatus> | undefined;
private _currentNewSessionModelId: ISettableObservable<string | undefined> | undefined;

private _cacheInitialized = false;

Expand Down Expand Up @@ -245,6 +253,8 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
this._refreshSessions();
} else if (e.action.type === ActionType.SessionTitleChanged && isSessionAction(e.action)) {
this._handleTitleChanged(e.action.session, e.action.title);
} else if (e.action.type === ActionType.SessionModelChanged && isSessionAction(e.action)) {
this._handleModelChanged(e.action.session, e.action.model);
} else if (e.action.type === ActionType.SessionIsReadChanged && isSessionAction(e.action)) {
this._handleIsReadChanged(e.action.session, e.action.isRead);
} else if (e.action.type === ActionType.SessionIsDoneChanged && isSessionAction(e.action)) {
Expand Down Expand Up @@ -376,6 +386,7 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi

this._currentNewSession = undefined;
this._selectedModelId = undefined;
this._currentNewSessionModelId = undefined;

const defaultType = this.sessionTypes[0];
if (!defaultType) {
Expand Down Expand Up @@ -429,6 +440,7 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
};
this._currentNewSession = session;
this._currentNewSessionStatus = status;
this._currentNewSessionModelId = modelId;
return session;
}

Expand Down Expand Up @@ -458,6 +470,18 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
setModel(sessionId: string, modelId: string): void {
if (this._currentNewSession?.sessionId === sessionId) {
this._selectedModelId = modelId;
this._currentNewSessionModelId?.set(modelId, undefined);
return;
}

const rawId = this._rawIdFromChatId(sessionId);
const cached = rawId ? this._sessionCache.get(rawId) : undefined;
if (cached && rawId) {
cached.modelId.set(modelId, undefined);
this._onDidChangeSessions.fire({ added: [], removed: [], changed: [cached] });
const rawModelId = modelId.startsWith(`${cached.sessionType}:`) ? modelId.substring(cached.sessionType.length + 1) : modelId;
const action = { type: ActionType.SessionModelChanged as const, session: AgentSession.uri(cached.agentProvider, rawId).toString(), model: rawModelId };
this._agentHostService.dispatch(action);
}
}

Expand Down Expand Up @@ -581,11 +605,13 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi

this._selectedModelId = undefined;
this._currentNewSessionStatus = undefined;
this._currentNewSessionModelId = undefined;

try {
const committedSession = await this._waitForNewSession(existingKeys);
if (committedSession) {
this._currentNewSession = undefined;
this._currentNewSessionModelId = undefined;
this._onDidReplaceSession.fire({ from: newSession, to: committedSession });
return committedSession;
}
Expand All @@ -596,6 +622,7 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
}

this._currentNewSession = undefined;
this._currentNewSessionModelId = undefined;
return newSession;
}

Expand Down Expand Up @@ -676,7 +703,7 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
}
}

private _handleSessionAdded(summary: { resource: string; provider: string; title: string; createdAt: number; modifiedAt: number; project?: { uri: string; displayName: string }; workingDirectory?: string; isRead?: boolean; isDone?: boolean }): void {
private _handleSessionAdded(summary: { resource: string; provider: string; title: string; createdAt: number; modifiedAt: number; project?: { uri: string; displayName: string }; model?: string; workingDirectory?: string; isRead?: boolean; isDone?: boolean }): void {
const sessionUri = URI.parse(summary.resource);
const rawId = AgentSession.id(sessionUri);
if (this._sessionCache.has(rawId)) {
Expand All @@ -692,6 +719,7 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
modifiedTime: summary.modifiedAt,
summary: summary.title,
...(summary.project ? { project: { uri: URI.parse(summary.project.uri), displayName: summary.project.displayName } } : {}),
model: summary.model,
workingDirectory: workingDir,
isRead: summary.isRead,
isDone: summary.isDone,
Expand Down Expand Up @@ -725,6 +753,16 @@ export class LocalAgentHostSessionsProvider extends Disposable implements ISessi
}
}

private _handleModelChanged(session: string, model: string): void {
const rawId = AgentSession.id(session);
const cached = this._sessionCache.get(rawId);
const modelId = cached ? `${cached.sessionType}:${model}` : undefined;
if (cached && cached.modelId.get() !== modelId) {
cached.modelId.set(modelId, undefined);
this._onDidChangeSessions.fire({ added: [], removed: [], changed: [cached] });
}
}

private _handleIsReadChanged(session: string, isRead: boolean): void {
const rawId = AgentSession.id(session);
const cached = this._sessionCache.get(rawId);
Expand Down
Loading
Loading