Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
"cheerio": "^1.0.0",
"dotenv": "^16.4.7",
"https-proxy-agent": "^7.0.6",
"jsonwebtoken": "^9.0.2",
"mathjs": "^15.1.0",
"nanoid": "^3.3.7",
"okapibm25": "^1.4.1",
Expand All @@ -151,6 +152,7 @@
"@rollup/plugin-typescript": "^12.1.2",
"@swc/core": "^1.6.13",
"@types/jest": "^30.0.0",
"@types/jsonwebtoken": "^9.0.10",
"@types/node": "^20.14.11",
"@types/node-fetch": "^2.6.13",
"@types/yargs-parser": "^21.0.3",
Expand Down
105 changes: 101 additions & 4 deletions src/tools/search/rerankers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import axios from 'axios';
import type * as t from './types';
import { createDefaultLogger } from './utils';
import jwt from 'jsonwebtoken';
import { nanoid } from 'nanoid';

export abstract class BaseReranker {
protected apiKey: string | undefined;
Expand All @@ -27,6 +29,87 @@ export abstract class BaseReranker {
}
}

/**
 * Reranker that delegates scoring to a `/rerank` endpoint located under the
 * base URL given by the `RAG_API_URL` environment variable.
 *
 * Each request is authorized with a short-lived JWT signed with `JWT_SECRET`.
 * Whenever the endpoint cannot be used (missing configuration, request
 * failure, empty or malformed response) the inherited default ranking is
 * returned instead of throwing.
 */
export class SimpleReranker extends BaseReranker {
  /** Fully-qualified rerank endpoint, or `undefined` when `RAG_API_URL` is unset/empty. */
  private readonly instanceUrl: string | undefined;

  /**
   * @param logger - Optional logger forwarded to the base class.
   */
  constructor({ logger }: { logger?: t.Logger }) {
    super(logger);
    const baseUrl = process.env.RAG_API_URL;
    if (baseUrl !== undefined && baseUrl !== '') {
      // Strip trailing slashes so a base URL such as "http://host/" does not
      // yield "http://host//rerank", which many routers do not match.
      this.instanceUrl = baseUrl.replace(/\/+$/, '') + '/rerank';
    }
  }

  /**
   * Reranks `documents` against `query` using the remote endpoint.
   *
   * @param query - The search query the documents are scored against.
   * @param documents - Text chunks to rerank.
   * @param topK - Number of results requested from the service (sent as `k`).
   * @returns The service's response when every item has a string `text` and a
   *   numeric `score`; otherwise the result of `getDefaultRanking`.
   */
  async rerank(
    query: string,
    documents: string[],
    topK: number = 5
  ): Promise<t.Highlight[]> {
    this.logger.debug(
      `Reranking ${documents.length} chunks with SimpleReranker`
    );

    if (this.instanceUrl === undefined || this.instanceUrl === '') {
      this.logger.warn('RAG_API_URL is not set. Using default ranking.');
      return this.getDefaultRanking(documents, topK);
    }

    // Check the signing secret up front: without it no authorized request can
    // be made, so there is no point in building the payload.
    const jwtSecret = process.env.JWT_SECRET;
    if (jwtSecret === undefined || jwtSecret === '') {
      this.logger.warn('JWT_SECRET is not set. Using default ranking.');
      return this.getDefaultRanking(documents, topK);
    }

    try {
      const requestData = {
        query: query,
        docs: documents,
        k: topK,
      };

      // The token carries only a random nonce and expires after ten minutes.
      const stateToken = jwt.sign({ nonce: nanoid() }, jwtSecret, {
        expiresIn: '10m',
      });

      const resp = await axios.post<t.SimpleRerankerResponse[]>(
        this.instanceUrl,
        requestData,
        {
          headers: {
            'Content-Type': 'application/json',
            Authorization: 'Bearer ' + stateToken,
          },
        }
      );

      if (resp.data && Array.isArray(resp.data) && resp.data.length > 0) {
        // The payload is external input: verify the shape of every item
        // before handing it back to callers.
        const isValid = resp.data.every(
          (item: t.SimpleRerankerResponse) =>
            typeof item.text === 'string' && typeof item.score === 'number'
        );
        if (isValid) {
          return resp.data;
        }
        this.logger.warn(
          'Unexpected response format from Simple reranker. Using default ranking.'
        );
      }
      return this.getDefaultRanking(documents, topK);
    } catch (error) {
      this.logger.error('Error using Simple reranker:', error);
      // Fallback to default ranking on error
      return this.getDefaultRanking(documents, topK);
    }
  }
}

export class JinaReranker extends BaseReranker {
private apiUrl: string;

Expand All @@ -49,7 +132,9 @@ export class JinaReranker extends BaseReranker {
documents: string[],
topK: number = 5
): Promise<t.Highlight[]> {
this.logger.debug(`Reranking ${documents.length} chunks with Jina using API URL: ${this.apiUrl}`);
this.logger.debug(
`Reranking ${documents.length} chunks with Jina using API URL: ${this.apiUrl}`
);

try {
if (this.apiKey == null || this.apiKey === '') {
Expand Down Expand Up @@ -217,22 +302,34 @@ export const createReranker = (config: {

switch (rerankerType.toLowerCase()) {
case 'jina':
return new JinaReranker({ apiKey: jinaApiKey, apiUrl: jinaApiUrl, logger: defaultLogger });
return new JinaReranker({
apiKey: jinaApiKey,
apiUrl: jinaApiUrl,
logger: defaultLogger,
});
case 'cohere':
return new CohereReranker({
apiKey: cohereApiKey,
logger: defaultLogger,
});
case 'simple':
return new SimpleReranker({
logger: defaultLogger,
});
case 'infinity':
return new InfinityReranker(defaultLogger);
case 'none':
defaultLogger.debug('Skipping reranking as reranker is set to "none"');
return undefined;
default:
defaultLogger.warn(
`Unknown reranker type: ${rerankerType}. Defaulting to InfinityReranker.`
`Unknown reranker type: ${rerankerType}. Defaulting to JinaReranker.`
);
return new JinaReranker({ apiKey: jinaApiKey, apiUrl: jinaApiUrl, logger: defaultLogger });
return new JinaReranker({
apiKey: jinaApiKey,
apiUrl: jinaApiUrl,
logger: defaultLogger,
});
}
};

Expand Down
Loading