diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java index 8d2bcee4a6..b03e1d2b07 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java @@ -77,11 +77,29 @@ public CompressionQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullab public Query transform(Query query) { Assert.notNull(query, "query cannot be null"); + if (query.history().isEmpty()) { + logger.debug("No conversation history to compress. Returning the original query."); + return query; + } + logger.debug("Compressing conversation history and follow-up query into a standalone query"); + // Exclude the last history entry if it's a duplicate of the current query + List history = query.history(); + List effectiveHistory; + int lastIdx = history.size() - 1; + Message lastMessage = history.get(lastIdx); + String lastText = lastMessage.getText(); + if (lastMessage.getMessageType() == MessageType.USER && lastText != null && lastText.equals(query.text())) { + effectiveHistory = history.subList(0, lastIdx); + } + else { + effectiveHistory = history; + } + var compressedQueryText = this.chatClient.prompt() .user(user -> user.text(this.promptTemplate.getTemplate()) - .param("history", formatConversationHistory(query.history())) + .param("history", formatConversationHistory(effectiveHistory)) .param("query", query.text())) .call() .content(); @@ -136,4 +154,4 @@ public CompressionQueryTransformer build() { } -} +} \ No newline at end of file diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java index cc983c6624..ce1d047967 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java @@ -19,10 +19,16 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.rag.Query; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Unit tests for {@link CompressionQueryTransformer}. @@ -69,4 +75,55 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() { .hasMessageContaining("query"); } + @Test + void whenHistoryIsEmptyThenReturnOriginalQuery() { + ChatClient.Builder chatClientBuilder = mock(ChatClient.Builder.class); + ChatClient chatClient = mock(ChatClient.class); + when(chatClientBuilder.build()).thenReturn(chatClient); + + QueryTransformer queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder) + .build(); + + Query query = Query.builder().text("What is Spring AI?").build(); + + Query result = queryTransformer.transform(query); + + assertThat(result).isEqualTo(query); + } + + @Test + void whenLastHistoryEntryMatchesCurrentQueryThenExcludeItFromHistory() { + ChatClient.Builder chatClientBuilder = mock(ChatClient.Builder.class); + ChatClient chatClient = mock(ChatClient.class); + ChatClient.Callable mockedCallable = mock(ChatClient.Callable.class); + when(chatClientBuilder.build()).thenReturn(chatClient); + when(chatClient.prompt()).thenReturn(mockedCallable); + when(mockedCallable.user(any())).thenReturn(mockedCallable); + when(mockedCallable.call()).thenReturn( + org.springframework.ai.chat.client.ChatResponse.builder().content("What is Spring AI?").build()); + + QueryTransformer queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder) + .build(); + + String currentQueryText = "What is Spring AI?"; + Message userMessage = Message.builder().messageType(MessageType.USER).text(currentQueryText).build(); + Message assistantMessage = Message.builder() + .messageType(MessageType.ASSISTANT) + .text("Spring AI is a framework for AI-native applications.") + .build(); + + Query query = Query.builder() + .text(currentQueryText) + .history(java.util.List.of(assistantMessage, userMessage)) + .build(); + + queryTransformer.transform(query); + + // Verify the prompt was called - the history should not contain the duplicate + // query + verify(chatClient).prompt(); + } + }