Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/tools/search/cohere-reranker.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import { CohereReranker, createReranker } from './rerankers';
import { createDefaultLogger } from './utils';

// Helper to access private apiUrl property for testing
const getApiUrl = (reranker: CohereReranker): string =>
(reranker as unknown as { apiUrl: string }).apiUrl;

describe('CohereReranker', () => {
const mockLogger = createDefaultLogger();

describe('constructor', () => {
it('should use default API URL when no apiUrl is provided', () => {
const reranker = new CohereReranker({
apiKey: 'test-key',
logger: mockLogger,
});

expect(getApiUrl(reranker)).toBe('https://api.cohere.com/v2/rerank');
});

it('should use custom API URL when provided', () => {
const customUrl = 'https://custom-cohere-endpoint.com/v2/rerank';
const reranker = new CohereReranker({
apiKey: 'test-key',
apiUrl: customUrl,
logger: mockLogger,
});

expect(getApiUrl(reranker)).toBe(customUrl);
});

it('should use environment variable COHERE_API_URL when available', () => {
const originalEnv = process.env.COHERE_API_URL;
process.env.COHERE_API_URL = 'https://env-cohere-endpoint.com/v2/rerank';

const reranker = new CohereReranker({
apiKey: 'test-key',
logger: mockLogger,
});

expect(getApiUrl(reranker)).toBe(
'https://env-cohere-endpoint.com/v2/rerank'
);

// Restore original environment
if (originalEnv !== undefined) {
process.env.COHERE_API_URL = originalEnv;
} else {
delete process.env.COHERE_API_URL;
}
});

it('should prioritize explicit apiUrl over environment variable', () => {
const originalEnv = process.env.COHERE_API_URL;
process.env.COHERE_API_URL = 'https://env-cohere-endpoint.com/v2/rerank';

const customUrl = 'https://explicit-cohere-endpoint.com/v2/rerank';
const reranker = new CohereReranker({
apiKey: 'test-key',
apiUrl: customUrl,
logger: mockLogger,
});

expect(getApiUrl(reranker)).toBe(customUrl);

// Restore original environment
if (originalEnv !== undefined) {
process.env.COHERE_API_URL = originalEnv;
} else {
delete process.env.COHERE_API_URL;
}
});
});

describe('rerank method', () => {
it('should log the API URL being used', async () => {
const customUrl = 'https://test-cohere-endpoint.com/v2/rerank';
const reranker = new CohereReranker({
apiKey: 'test-key',
apiUrl: customUrl,
logger: mockLogger,
});

const logSpy = jest.spyOn(mockLogger, 'debug');

try {
await reranker.rerank('test query', ['document1', 'document2'], 2);
} catch (_error) {
// Expected to fail due to network error, but we can check the log
}

expect(logSpy).toHaveBeenCalledWith(
expect.stringContaining(
`Reranking 2 chunks with Cohere using API URL: ${customUrl}`
)
);

logSpy.mockRestore();
});
});
});

describe('createReranker for Cohere', () => {
it('should create CohereReranker with cohereApiUrl when provided', () => {
const customUrl = 'https://custom-cohere-endpoint.com/v2/rerank';
const reranker = createReranker({
rerankerType: 'cohere',
cohereApiKey: 'test-key',
cohereApiUrl: customUrl,
});

expect(reranker).toBeInstanceOf(CohereReranker);
expect(getApiUrl(reranker as CohereReranker)).toBe(customUrl);
});

it('should create CohereReranker with default URL when cohereApiUrl is not provided', () => {
const reranker = createReranker({
rerankerType: 'cohere',
cohereApiKey: 'test-key',
});

expect(reranker).toBeInstanceOf(CohereReranker);
expect(getApiUrl(reranker as CohereReranker)).toBe(
'https://api.cohere.com/v2/rerank'
);
});
});
13 changes: 10 additions & 3 deletions src/tools/search/rerankers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,23 +115,28 @@ export class JinaReranker extends BaseReranker {
}

export class CohereReranker extends BaseReranker {
private apiUrl: string;

constructor({
apiKey = process.env.COHERE_API_KEY,
apiUrl = process.env.COHERE_API_URL || 'https://api.cohere.com/v2/rerank',
logger,
}: {
apiKey?: string;
apiUrl?: string;
logger?: t.Logger;
}) {
super(logger);
this.apiKey = apiKey;
this.apiUrl = apiUrl;
}

async rerank(
query: string,
documents: string[],
topK: number = 5
): Promise<t.Highlight[]> {
this.logger.debug(`Reranking ${documents.length} chunks with Cohere`);
this.logger.debug(`Reranking ${documents.length} chunks with Cohere using API URL: ${this.apiUrl}`);

try {
if (this.apiKey == null || this.apiKey === '') {
Expand All @@ -147,7 +152,7 @@ export class CohereReranker extends BaseReranker {
};

const response = await axios.post<t.CohereRerankerResponse | undefined>(
'https://api.cohere.com/v2/rerank',
this.apiUrl,
requestData,
{
headers: {
Expand Down Expand Up @@ -208,9 +213,10 @@ export const createReranker = (config: {
jinaApiKey?: string;
jinaApiUrl?: string;
cohereApiKey?: string;
cohereApiUrl?: string;
logger?: t.Logger;
}): BaseReranker | undefined => {
const { rerankerType, jinaApiKey, jinaApiUrl, cohereApiKey, logger } = config;
const { rerankerType, jinaApiKey, jinaApiUrl, cohereApiKey, cohereApiUrl, logger } = config;

// Create a default logger if none is provided
const defaultLogger = logger || createDefaultLogger();
Expand All @@ -221,6 +227,7 @@ export const createReranker = (config: {
case 'cohere':
return new CohereReranker({
apiKey: cohereApiKey,
apiUrl: cohereApiUrl,
logger: defaultLogger,
});
case 'infinity':
Expand Down