diff --git a/src/models/generative_models.ts b/src/models/generative_models.ts index 1e2820e9..0bc9bbf6 100644 --- a/src/models/generative_models.ts +++ b/src/models/generative_models.ts @@ -249,7 +249,7 @@ export class GenerativeModel { * @param request - {@link StartChatParams} * @returns {@link ChatSession} */ - startChat(request?: StartChatParams): ChatSession { + async startChat(request?: StartChatParams): Promise<ChatSession> { const startChatRequest: StartChatSessionRequest = { project: this.project, location: this.location, @@ -273,6 +273,7 @@ export class GenerativeModel { startChatRequest.systemInstruction = request.systemInstruction ?? this.systemInstruction; } + await this.fetchToken(); return new ChatSession(startChatRequest, this.requestOptions); } } diff --git a/src/models/test/models_test.ts b/src/models/test/models_test.ts index 0f07413d..0bb9ef1a 100644 --- a/src/models/test/models_test.ts +++ b/src/models/test/models_test.ts @@ -260,25 +260,25 @@ class ChatSessionPreviewForTest extends ChatSessionPreview { } describe('GenerativeModel startChat', () => { - it('returns ChatSession when pass no arg', () => { + it('returns ChatSession when pass no arg', async () => { const model = new GenerativeModel({ model: 'gemini-pro', project: PROJECT, location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat(); + const chat = await model.startChat(); expect(chat).toBeInstanceOf(ChatSession); }); - it('returns ChatSession when pass an arg', () => { + it('returns ChatSession when pass an arg', async () => { const model = new GenerativeModel({ model: 'gemini-pro', project: PROJECT, location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); @@ -295,7 +295,7 @@ describe('GenerativeModel startChat', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat({ + const chat = await model.startChat({ generationConfig: { 
...TEST_GENERATION_CONFIG, responseMimeType: 'application/json', @@ -304,7 +304,7 @@ describe('GenerativeModel startChat', () => { expect(chat).toBeInstanceOf(ChatSession); }); - it('set timeout info in ChatSession', () => { + it('set timeout info in ChatSession', async () => { const model = new GenerativeModel({ model: 'gemini-pro', project: PROJECT, @@ -312,9 +312,9 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, requestOptions: TEST_REQUEST_OPTIONS, }); - const chat = model.startChat({ + const chat = (await model.startChat({ history: TEST_USER_CHAT_MESSAGE, - }) as ChatSessionForTest; + })) as ChatSessionForTest; expect(chat.requestOptions).toEqual(TEST_REQUEST_OPTIONS); }); @@ -332,7 +332,7 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = @@ -356,7 +356,7 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); @@ -381,7 +381,7 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, systemInstruction: TEST_SYSTEM_INSTRUCTION, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = @@ -408,7 +408,7 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, systemInstruction: TEST_SYSTEM_INSTRUCTION_1, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, // this is different from constructor systemInstruction: TEST_SYSTEM_INSTRUCTION, @@ -437,7 +437,7 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, 
systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = @@ -464,7 +464,7 @@ describe('GenerativeModel startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, systemInstruction: TEST_SYSTEM_INSTRUCTION_1, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, // this is different from constructor systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE, @@ -482,25 +482,25 @@ describe('GenerativeModel startChat', () => { }); describe('GenerativeModelPreview startChat', () => { - it('returns ChatSessionPreview when pass no arg', () => { + it('returns ChatSessionPreview when pass no arg', async () => { const model = new GenerativeModelPreview({ model: 'gemini-pro', project: PROJECT, location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat(); + const chat = await model.startChat(); expect(chat).toBeInstanceOf(ChatSessionPreview); }); - it('returns ChatSessionPreview when pass an arg', () => { + it('returns ChatSessionPreview when pass an arg', async () => { const model = new GenerativeModelPreview({ model: 'gemini-pro', project: PROJECT, location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); @@ -513,7 +513,7 @@ describe('GenerativeModelPreview startChat', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat({ + const chat = await model.startChat({ generationConfig: { ...TEST_GENERATION_CONFIG, responseMimeType: 'application/json', @@ -550,7 +550,7 @@ describe('GenerativeModelPreview startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = @@ -599,7 +599,7 
@@ describe('GenerativeModelPreview startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, systemInstruction: TEST_SYSTEM_INSTRUCTION, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = @@ -655,7 +655,7 @@ describe('GenerativeModelPreview startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = @@ -682,7 +682,7 @@ describe('GenerativeModelPreview startChat', () => { googleAuth: FAKE_GOOGLE_AUTH, systemInstruction: TEST_SYSTEM_INSTRUCTION_1, }); - const chat = model.startChat({ + const chat = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, // this is different from constructor systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE, @@ -2266,12 +2266,12 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - chatSession = model.startChat({ + chatSession = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); expect(await chatSession.getHistory()).toEqual(TEST_USER_CHAT_MESSAGE); - chatSessionWithNoArgs = model.startChat(); - chatSessionWithFunctionCall = model.startChat({ + chatSessionWithNoArgs = await model.startChat(); + chatSessionWithFunctionCall = await model.startChat({ tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, }); }); @@ -2301,7 +2301,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithShortName = modelWithShortName.startChat(); + const chatSessionWithShortName = await modelWithShortName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2326,7 +2326,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithLongName = modelWithLongName.startChat(); + const 
chatSessionWithLongName = await modelWithLongName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2351,7 +2351,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithFullName = modelWithFullName.startChat(); + const chatSessionWithFullName = await modelWithFullName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2379,7 +2379,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat({ + const chat = await model.startChat({ generationConfig: { ...TEST_GENERATION_CONFIG, responseMimeType: 'application/json', @@ -2403,9 +2403,10 @@ describe('ChatSession', () => { googleAuth: FAKE_GOOGLE_AUTH, requestOptions: TEST_REQUEST_OPTIONS, }); - const chatSessionWithRequestOptions = modelWithRequestOptions.startChat({ - history: TEST_USER_CHAT_MESSAGE, - }) as ChatSessionForTest; + const chatSessionWithRequestOptions = + (await modelWithRequestOptions.startChat({ + history: TEST_USER_CHAT_MESSAGE, + })) as ChatSessionForTest; const req = 'How are you doing today?'; const generateContentSpy: jasmine.Spy = spyOn( GenerateContentFunctions, @@ -2516,7 +2517,7 @@ describe('ChatSession', () => { response: Promise.resolve(TEST_MODEL_RESPONSE), stream: testGenerator(), }; - const chatSession = model.startChat({ + const chatSession = await model.startChat({ history: [ { role: constants.USER_ROLE, @@ -2546,7 +2547,7 @@ describe('ChatSession', () => { }; spyOn(PostFetchFunctions, 'processStream').and.resolveTo(expectedResult); - const chatSession = model.startChat({ + const chatSession = await model.startChat({ history: [ { role: constants.USER_ROLE, @@ -2573,7 +2574,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithShortName = modelWithShortName.startChat(); + const 
chatSessionWithShortName = await modelWithShortName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2599,7 +2600,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithLongName = modelWithLongName.startChat(); + const chatSessionWithLongName = await modelWithLongName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2625,7 +2626,7 @@ describe('ChatSession', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithFullName = modelWithFullName.startChat(); + const chatSessionWithFullName = await modelWithFullName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2651,9 +2652,10 @@ describe('ChatSession', () => { googleAuth: FAKE_GOOGLE_AUTH, requestOptions: TEST_REQUEST_OPTIONS, }); - const chatSessionWithRequestOptions = modelWithRequestOptions.startChat({ - history: TEST_USER_CHAT_MESSAGE, - }) as ChatSessionForTest; + const chatSessionWithRequestOptions = + (await modelWithRequestOptions.startChat({ + history: TEST_USER_CHAT_MESSAGE, + })) as ChatSessionForTest; const req = 'How are you doing today?'; const generateContentSpy: jasmine.Spy = spyOn( GenerateContentFunctions, @@ -2739,12 +2741,12 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - chatSession = model.startChat({ + chatSession = await model.startChat({ history: TEST_USER_CHAT_MESSAGE, }); expect(await chatSession.getHistory()).toEqual(TEST_USER_CHAT_MESSAGE); - chatSessionWithNoArgs = model.startChat(); - chatSessionWithFunctionCall = model.startChat({ + chatSessionWithNoArgs = await model.startChat(); + chatSessionWithFunctionCall = await model.startChat({ tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, }); expectedStreamResult = { @@ -2776,7 +2778,7 @@ 
describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithShortName = modelWithShortName.startChat(); + const chatSessionWithShortName = await modelWithShortName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2801,7 +2803,7 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithLongName = modelWithLongName.startChat(); + const chatSessionWithLongName = await modelWithLongName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2826,7 +2828,7 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithFullName = modelWithFullName.startChat(); + const chatSessionWithFullName = await modelWithFullName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -2850,7 +2852,7 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chat = model.startChat({ + const chat = await model.startChat({ generationConfig: { ...TEST_GENERATION_CONFIG, responseMimeType: 'application/json', @@ -2872,9 +2874,10 @@ describe('ChatSessionPreview', () => { googleAuth: FAKE_GOOGLE_AUTH, requestOptions: TEST_REQUEST_OPTIONS, }); - const chatSessionWithRequestOptions = modelWithRequestOptions.startChat({ - history: TEST_USER_CHAT_MESSAGE, - }) as ChatSessionPreviewForTest; + const chatSessionWithRequestOptions = + (await modelWithRequestOptions.startChat({ + history: TEST_USER_CHAT_MESSAGE, + })) as ChatSessionPreviewForTest; const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy: jasmine.Spy = spyOn( GenerateContentFunctions, @@ -2975,7 +2978,7 @@ describe('ChatSessionPreview', () => { response: Promise.resolve(TEST_MODEL_RESPONSE), stream: testGenerator(), }; - const 
chatSession = model.startChat({ + const chatSession = await model.startChat({ history: [ { role: constants.USER_ROLE, @@ -3005,7 +3008,7 @@ describe('ChatSessionPreview', () => { }; spyOn(PostFetchFunctions, 'processStream').and.resolveTo(expectedResult); - const chatSession = model.startChat({ + const chatSession = await model.startChat({ history: [ { role: constants.USER_ROLE, @@ -3032,7 +3035,7 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithShortName = modelWithShortName.startChat(); + const chatSessionWithShortName = await modelWithShortName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -3058,7 +3061,7 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithLongName = modelWithLongName.startChat(); + const chatSessionWithLongName = await modelWithLongName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -3084,7 +3087,7 @@ describe('ChatSessionPreview', () => { location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSessionWithFullName = modelWithFullName.startChat(); + const chatSessionWithFullName = await modelWithFullName.startChat(); const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy = spyOn( GenerateContentFunctions, @@ -3110,9 +3113,10 @@ describe('ChatSessionPreview', () => { googleAuth: FAKE_GOOGLE_AUTH, requestOptions: TEST_REQUEST_OPTIONS, }); - const chatSessionWithRequestOptions = modelWithRequestOptions.startChat({ - history: TEST_USER_CHAT_MESSAGE, - }) as ChatSessionPreviewForTest; + const chatSessionWithRequestOptions = + (await modelWithRequestOptions.startChat({ + history: TEST_USER_CHAT_MESSAGE, + })) as ChatSessionPreviewForTest; const req = TEST_CHAT_MESSSAGE_TEXT; const generateContentSpy: jasmine.Spy = spyOn( GenerateContentFunctions, @@ -3264,13 +3268,13 @@ 
describe('GenerativeModelPreview countTokens', () => { }); describe('GenerativeModel when exception at fetch', () => { + let chatSession: ChatSession; const model = new GenerativeModel({ model: 'gemini-pro', project: PROJECT, location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSession = model.startChat(); const message = 'hi'; const req: GenerateContentRequest = { contents: TEST_USER_CHAT_MESSAGE, @@ -3278,6 +3282,10 @@ describe('GenerativeModel when exception at fetch', () => { const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.throwError('error'); }); @@ -3300,13 +3308,14 @@ describe('GenerativeModel when exception at fetch', () => { }); describe('GenerativeModelPreview when exception at fetch', () => { + // @ts-ignore + let chatSession: ChatSessionPreview; const model = new GenerativeModelPreview({ model: 'gemini-pro', project: PROJECT, location: LOCATION, googleAuth: FAKE_GOOGLE_AUTH, }); - const chatSession = model.startChat(); const message = 'hi'; const req: GenerateContentRequest = { contents: TEST_USER_CHAT_MESSAGE, @@ -3314,6 +3323,10 @@ describe('GenerativeModelPreview when exception at fetch', () => { const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.throwError('error'); }); @@ -3336,6 +3349,7 @@ describe('GenerativeModelPreview when exception at fetch', () => { }); describe('GenerativeModel when response is undefined', () => { + let chatSession: ChatSession; const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: response is undefined'; const model = new GenerativeModel({ @@ -3348,10 +3362,14 @@ describe('GenerativeModel when response is undefined', () => { contents: TEST_USER_CHAT_MESSAGE, }; const message = 'hi'; - const 
chatSession = model.startChat(); const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.resolveTo(); }); @@ -3386,6 +3404,7 @@ describe('GenerativeModel when response is undefined', () => { }); describe('GenerativeModelPreview when response is undefined', () => { + let chatSession: ChatSessionPreview; const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: response is undefined'; const model = new GenerativeModelPreview({ @@ -3398,10 +3417,14 @@ describe('GenerativeModelPreview when response is undefined', () => { contents: TEST_USER_CHAT_MESSAGE, }; const message = 'hi'; - const chatSession = model.startChat(); const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.resolveTo(); }); @@ -3436,6 +3459,7 @@ describe('GenerativeModelPreview when response is undefined', () => { }); describe('GeneratvieModel when response is 4XX', () => { + let chatSession: ChatSession; const req: GenerateContentRequest = { contents: TEST_USER_CHAT_MESSAGE, }; @@ -3457,10 +3481,14 @@ describe('GeneratvieModel when response is 4XX', () => { googleAuth: FAKE_GOOGLE_AUTH, }); const message = 'hi'; - const chatSession = model.startChat(); const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.resolveTo(response); }); @@ -3499,6 +3527,7 @@ describe('GeneratvieModel when response is 4XX', () => { }); describe('GeneratvieModelPreview when response is 4XX', () => { + let chatSession: ChatSessionPreview; const req: GenerateContentRequest = { contents: TEST_USER_CHAT_MESSAGE, }; @@ -3520,10 +3549,14 @@ describe('GeneratvieModelPreview when response 
is 4XX', () => { googleAuth: FAKE_GOOGLE_AUTH, }); const message = 'hi'; - const chatSession = model.startChat(); const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.resolveTo(response); }); @@ -3562,6 +3595,7 @@ describe('GeneratvieModelPreview when response is 4XX', () => { }); describe('GenerativeModel when response is not OK and not 4XX', () => { + let chatSession: ChatSession; const req: GenerateContentRequest = { contents: TEST_USER_CHAT_MESSAGE, }; @@ -3583,10 +3617,14 @@ describe('GenerativeModel when response is not OK and not 4XX', () => { googleAuth: FAKE_GOOGLE_AUTH, }); const message = 'hi'; - const chatSession = model.startChat(); const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.resolveTo(response); }); @@ -3625,6 +3663,7 @@ describe('GenerativeModel when response is not OK and not 4XX', () => { }); describe('GenerativeModelPreview when response is not OK and not 4XX', () => { + let chatSession: ChatSessionPreview; const req: GenerateContentRequest = { contents: TEST_USER_CHAT_MESSAGE, }; @@ -3646,10 +3685,14 @@ describe('GenerativeModelPreview when response is not OK and not 4XX', () => { googleAuth: FAKE_GOOGLE_AUTH, }); const message = 'hi'; - const chatSession = model.startChat(); const countTokenReq: CountTokensRequest = { contents: TEST_USER_CHAT_MESSAGE, }; + + beforeAll(async () => { + chatSession = await model.startChat(); + }); + beforeEach(() => { spyOn(global, 'fetch').and.resolveTo(response); }); diff --git a/system_test/end_to_end_sample_test.ts b/system_test/end_to_end_sample_test.ts index af9d59c4..d1d27c00 100644 --- a/system_test/end_to_end_sample_test.ts +++ b/system_test/end_to_end_sample_test.ts @@ -875,7 +875,7 @@ 
describe('sendMessage', () => { jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000; }); it('should populate history and return a chat response', async () => { - const chat = generativeTextModel.startChat(); + const chat = await generativeTextModel.startChat(); const chatInput1 = 'How can I learn more about Node.js?'; const result1 = await chat.sendMessage(chatInput1); const response1 = result1.response; @@ -899,7 +899,7 @@ describe('sendMessage', () => { model: TEXT_MODEL_NAME, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); - const chat = generativeTextModel.startChat(); + const chat = await generativeTextModel.startChat(); const result = await chat.sendMessage('Why is the sky blue?'); const response = result.response; const groundingMetadata = response.candidates![0].groundingMetadata; @@ -916,7 +916,7 @@ describe('sendMessage', () => { const generativeTextModel = vertexAI.getGenerativeModel({ model: TEXT_MODEL_NAME, }); - const chat = generativeTextModel.startChat({ + const chat = await generativeTextModel.startChat({ tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); const result = await chat.sendMessage('Why is the sky blue?'); @@ -936,7 +936,7 @@ describe('sendMessage', () => { model: TEXT_MODEL_NAME, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); - const chat = generativeTextModel.startChat(); + const chat = await generativeTextModel.startChat(); const result = await chat.sendMessage('Why is the sky blue?'); const response = result.response; const groundingMetadata = response.candidates![0].groundingMetadata; @@ -953,7 +953,7 @@ describe('sendMessage', () => { const generativeTextModel = vertexAI.preview.getGenerativeModel({ model: TEXT_MODEL_NAME, }); - const chat = generativeTextModel.startChat({ + const chat = await generativeTextModel.startChat({ tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); const result = await chat.sendMessage('Why is the sky blue?'); @@ -975,7 +975,7 @@ describe('sendMessageStream', () => { jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000; }); it('should 
should return a stream and populate history when generationConfig is passed to startChat', async () => { - const chat = generativeTextModel.startChat({ + const chat = await generativeTextModel.startChat({ generationConfig: { maxOutputTokens: 256, }, @@ -998,7 +998,7 @@ describe('sendMessageStream', () => { expect((await chat.getHistory()).length).toBe(2); }); it('in preview should should return a stream and populate history when generationConfig is passed to startChat', async () => { - const chat = generativeTextModelPreview.startChat({ + const chat = await generativeTextModelPreview.startChat({ generationConfig: { maxOutputTokens: 256, }, @@ -1022,7 +1022,7 @@ describe('sendMessageStream', () => { }); it('should should return a stream and populate history when startChat is passed no request obj', async () => { - const chat = generativeTextModel.startChat(); + const chat = await generativeTextModel.startChat(); const chatInput1 = 'How can I learn more about Node.js?'; const result1 = await chat.sendMessageStream(chatInput1); for await (const item of result1.stream) { @@ -1061,7 +1061,7 @@ describe('sendMessageStream', () => { }); xit('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { - const chat = generativeTextModel.startChat({ + const chat = await generativeTextModel.startChat({ tools: TOOLS_WITH_FUNCTION_DECLARATION, }); const chatInput1 = 'What is the weather in Boston?'; @@ -1109,7 +1109,7 @@ describe('sendMessageStream', () => { ); }); xit('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { - const chat = generativeTextModelPreview.startChat({ + const chat = await generativeTextModelPreview.startChat({ tools: TOOLS_WITH_FUNCTION_DECLARATION, }); const chatInput1 = 'What is the weather in Boston?'; @@ -1158,7 +1158,7 @@ describe('sendMessageStream', () => { model: TEXT_MODEL_NAME, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); - const 
chat = generativeTextModel.startChat(); + const chat = await generativeTextModel.startChat(); const result = await chat.sendMessageStream('Why is the sky blue?'); const response = await result.response; const groundingMetadata = response.candidates![0].groundingMetadata; @@ -1175,7 +1175,7 @@ describe('sendMessageStream', () => { const generativeTextModel = vertexAI.getGenerativeModel({ model: TEXT_MODEL_NAME, }); - const chat = generativeTextModel.startChat({ + const chat = await generativeTextModel.startChat({ tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); const result = await chat.sendMessageStream('Why is the sky blue?'); @@ -1195,7 +1195,7 @@ describe('sendMessageStream', () => { model: TEXT_MODEL_NAME, tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); - const chat = generativeTextModel.startChat(); + const chat = await generativeTextModel.startChat(); const result = await chat.sendMessageStream('Why is the sky blue?'); const response = await result.response; const groundingMetadata = response.candidates![0].groundingMetadata; @@ -1212,7 +1212,7 @@ describe('sendMessageStream', () => { const generativeTextModel = vertexAI.preview.getGenerativeModel({ model: TEXT_MODEL_NAME, }); - const chat = generativeTextModel.startChat({ + const chat = await generativeTextModel.startChat({ tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); const result = await chat.sendMessageStream('Why is the sky blue?');