Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,29 @@ public CompressionQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullab
public Query transform(Query query) {
Assert.notNull(query, "query cannot be null");

if (query.history().isEmpty()) {
logger.debug("No conversation history to compress. Returning the original query.");
return query;
}

logger.debug("Compressing conversation history and follow-up query into a standalone query");

// Exclude the last history entry if it's a duplicate of the current query
List<Message> history = query.history();
List<Message> effectiveHistory;
int lastIdx = history.size() - 1;
Message lastMessage = history.get(lastIdx);
String lastText = lastMessage.getText();
if (lastMessage.getMessageType() == MessageType.USER && lastText != null && lastText.equals(query.text())) {
effectiveHistory = history.subList(0, lastIdx);
}
else {
effectiveHistory = history;
}

var compressedQueryText = this.chatClient.prompt()
.user(user -> user.text(this.promptTemplate.getTemplate())
.param("history", formatConversationHistory(query.history()))
.param("history", formatConversationHistory(effectiveHistory))
.param("query", query.text()))
.call()
.content();
Expand Down Expand Up @@ -136,4 +154,4 @@ public CompressionQueryTransformer build() {

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@
import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* Unit tests for {@link CompressionQueryTransformer}.
Expand Down Expand Up @@ -69,4 +75,55 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() {
.hasMessageContaining("query");
}

@Test
void whenHistoryIsEmptyThenReturnOriginalQuery() {
ChatClient.Builder chatClientBuilder = mock(ChatClient.Builder.class);
ChatClient chatClient = mock(ChatClient.class);
when(chatClientBuilder.build()).thenReturn(chatClient);

QueryTransformer queryTransformer = CompressionQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.build();

Query query = Query.builder().text("What is Spring AI?").build();

Query result = queryTransformer.transform(query);

assertThat(result).isEqualTo(query);
}

@Test
void whenLastHistoryEntryMatchesCurrentQueryThenExcludeItFromHistory() {
ChatClient.Builder chatClientBuilder = mock(ChatClient.Builder.class);
ChatClient chatClient = mock(ChatClient.class);
ChatClient.Callable mockedCallable = mock(ChatClient.Callable.class);
when(chatClientBuilder.build()).thenReturn(chatClient);
when(chatClient.prompt()).thenReturn(mockedCallable);
when(mockedCallable.user(any())).thenReturn(mockedCallable);
when(mockedCallable.call()).thenReturn(
org.springframework.ai.chat.client.ChatResponse.builder().content("What is Spring AI?").build());

QueryTransformer queryTransformer = CompressionQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.build();

String currentQueryText = "What is Spring AI?";
Message userMessage = Message.builder().messageType(MessageType.USER).text(currentQueryText).build();
Message assistantMessage = Message.builder()
.messageType(MessageType.ASSISTANT)
.text("Spring AI is a framework for AI-native applications.")
.build();

Query query = Query.builder()
.text(currentQueryText)
.history(java.util.List.of(assistantMessage, userMessage))
.build();

queryTransformer.transform(query);

// Verify the prompt was called - the history should not contain the duplicate
// query
verify(chatClient).prompt();
}

}