Feat/add OpenAI reranking #288
base: dev
New file (base reranker config):

```yaml
provider: ""
model_name: ${oc.env:RERANKER_MODEL, Alibaba-NLP/gte-multilingual-reranker-base}
top_k: ${oc.decode:${oc.env:RERANKER_TOP_K, 10}}  # Number of documents to return after reranking. Increase for better results if your LLM has a wider context window.
base_url: ${oc.env:RERANKER_BASE_URL, http://reranker:${oc.env:RERANKER_PORT, 7997}}
semaphore: ${oc.decode:${oc.env:RERANKER_SEMAPHORE, 40}}  # Number of concurrent reranking operations. Adjust based on your server capacity.
enabled: ${oc.decode:${oc.env:RERANKER_ENABLED, true}}
```
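The `${oc.env:VAR, default}` entries above are OmegaConf resolvers: each reads an environment variable and falls back to the given default, while `oc.decode` parses the resulting string into a typed value (`"10"` becomes an int, `"true"` a bool). A minimal stdlib sketch of the same fallback-and-decode pattern (the `env`/`decode` helper names are hypothetical, not part of OmegaConf):

```python
import json
import os

def env(name: str, default: str) -> str:
    # Mirrors ${oc.env:NAME, default}: the environment variable wins, else the literal default
    return os.environ.get(name, default)

def decode(raw: str):
    # Mirrors ${oc.decode:...}: "10" -> 10, "true" -> True, anything unparseable stays a string
    try:
        return json.loads(raw)
    except ValueError:
        return raw

config = {
    "top_k": decode(env("RERANKER_TOP_K", "10")),
    "enabled": decode(env("RERANKER_ENABLED", "true")),
    "base_url": env("RERANKER_BASE_URL", "http://reranker:" + env("RERANKER_PORT", "7997")),
}
```

With no overrides in the environment, this yields `top_k=10` as an int, `enabled=True` as a bool, and the default reranker URL.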
New file (infinity provider config):

```yaml
defaults:
  - base

provider: infinity
base_url: ${oc.env:RERANKER_BASE_URL, http://reranker:${oc.env:RERANKER_PORT, 7997}}
```
New file (openai provider config):

```yaml
defaults:
  - base

provider: openai
api_key: ${oc.env:RERANKER_API_KEY, "EMPTY"}
base_url: ${oc.env:RERANKER_BASE_URL, http://reranker:${oc.env:RERANKER_PORT, 8000}}
```
New file (vLLM reranker compose services):

```yaml
x-vllm-env: &vllm_env
  HUGGING_FACE_HUB_TOKEN:
  VLLM_SLEEP_WHEN_IDLE: 1  # Avoid 100% CPU usage when idle

x-reranker: &reranker_template
  networks:
    default:
      aliases:
        - reranker
  # restart: on-failure
  environment:
    - HUGGING_FACE_HUB_TOKEN
    - VLLM_SLEEP_WHEN_IDLE=1  # Avoid 100% CPU usage when idle
  ipc: "host"
  volumes:
    - ${VLLM_CACHE:-/root/.cache/huggingface}:/root/.cache/huggingface
  command: >
    --model ${RERANKER_MODEL:-BAAI/bge-reranker-v2-m3}
    --trust-remote-code
    --gpu_memory_utilization 0.3
  healthcheck:
    test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
    interval: 20s
    timeout: 5s
    retries: 4
    start_period: 90s
  ports:
    - ${RERANKER_PORT:-8003}:8000

services:
  reranker-gpu:
    <<: *reranker_template
    image: vllm/vllm-openai:v0.17.1
    environment:
      <<: *vllm_env
      NVIDIA_VISIBLE_DEVICES: all
      NVIDIA_DRIVER_CAPABILITIES: compute,utility
    runtime: nvidia
    profiles:
      - ""
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]

  reranker-cpu:
    <<: *reranker_template
    image: vllm/vllm-openai-cpu:v0.17.1
    deploy: {}
    environment:
      <<: *vllm_env
      VLLM_CPU_KVCACHE_SPACE: 8
    command: >
      --model ${RERANKER_MODEL:-BAAI/bge-reranker-v2-m3}
      --trust-remote-code
      --dtype float32
    profiles:
      - "cpu"
```
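The compose services above are driven entirely by environment variables. A hypothetical `.env` file next to the compose file might override the defaults like so (variable names are the ones the template references; values are purely illustrative):

```
# Illustrative overrides; every variable has a default in the compose file
RERANKER_MODEL=BAAI/bge-reranker-v2-m3
RERANKER_PORT=8003
VLLM_CACHE=/data/huggingface-cache
HUGGING_FACE_HUB_TOKEN=hf_xxx   # needed for gated models on the Hub
```

The GPU variant runs under the default profile, while `docker compose --profile cpu up reranker-cpu` selects the CPU variant.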
This file was deleted.
New file (reranker factory):

```python
from .base import BaseReranker
from .infinity import InfinityReranker
from .openai import OpenAIReranker

RERANKER_MAPPING = {
    "infinity": InfinityReranker,
    "openai": OpenAIReranker,
}


class RerankerFactory:
    @staticmethod
    def get_reranker(config: dict) -> BaseReranker:
        provider = config.reranker.get("provider")
        reranker_class = RERANKER_MAPPING.get(provider, None)

        if not reranker_class:
            raise ValueError(f"Unsupported reranker provider: {provider}")

        return reranker_class(config)
```
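Provider resolution through the factory can be sketched as follows. The reranker classes here are dummy stand-ins (the real ones require a running backend); the point is the mapping-plus-fallback lookup:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the real reranker classes
class InfinityReranker:
    def __init__(self, config):
        self.config = config

class OpenAIReranker:
    def __init__(self, config):
        self.config = config

RERANKER_MAPPING = {"infinity": InfinityReranker, "openai": OpenAIReranker}

def get_reranker(config):
    # Look up the provider name; unknown providers fail loudly
    provider = config.reranker.get("provider")
    reranker_class = RERANKER_MAPPING.get(provider)
    if not reranker_class:
        raise ValueError(f"Unsupported reranker provider: {provider}")
    return reranker_class(config)

cfg = SimpleNamespace(reranker={"provider": "openai"})
reranker = get_reranker(cfg)  # an OpenAIReranker instance
```

An unsupported provider string (e.g. `"cohere"`) raises `ValueError` instead of failing later at request time.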
New file (reranker base class):

```python
from langchain_core.documents.base import Document


class BaseReranker:
    async def rerank(self, query: str, documents: list[Document], top_k: int | None = None) -> list[Document]:
        """Rerank a list of documents based on a query and an optional top_k parameter."""
        raise NotImplementedError("Rerank method must be implemented by subclasses")

    @staticmethod
    def rrf_reranking(doc_lists: list[list[Document]], k: int = 60) -> list[Document]:
        """Reciprocal rank fusion over multiple lists of ranked documents.

        RRF formula: score(d) = \\sum_{i=1}^{n} \\frac{1}{k + rank_i},
        where rank_i is the rank of the document in the i-th list and n is the number of lists.

        k small: high sensitivity to top ranks.
        k large: more balanced sensitivity across ranks.
        k = 60 is a common and balanced choice in practice.
        """
        if len(doc_lists) == 1:
            return doc_lists[0]

        # Accumulate a fused score for each unique document
        fused_scores = {}
        for doc_list in doc_lists:
            for rank, doc in enumerate(doc_list, start=1):
                doc_id = doc.metadata.get("_id")
                score, d = fused_scores.get(doc_id, (0.0, doc))
                fused_scores[doc_id] = (score + 1 / (rank + k), d)

        # Sort the docs by fused score, highest first
        reranked_docs = [doc for _, doc in sorted(fused_scores.values(), key=lambda x: x[0], reverse=True)]
        return reranked_docs
```
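The fusion above can be checked with a small worked example. This standalone sketch uses plain string IDs in place of `Document` objects, but applies the same `1 / (rank + k)` accumulation:

```python
# Worked reciprocal rank fusion over two ranked lists with k=60
def rrf(doc_lists, k=60):
    fused = {}
    for doc_list in doc_lists:
        for rank, doc_id in enumerate(doc_list, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + 1 / (rank + k)
    return sorted(fused, key=fused.get, reverse=True)

dense_order = ["d1", "d2", "d3"]  # e.g. vector-search ranking
bm25_order = ["d3", "d1", "d4"]   # e.g. keyword-search ranking

fused = rrf([dense_order, bm25_order])
# d1: 1/61 + 1/62 ≈ 0.0325; d3: 1/63 + 1/61 ≈ 0.0323; d2: 1/62; d4: 1/63
print(fused)  # ['d1', 'd3', 'd2', 'd4']
```

`d1` wins because it ranks high in both lists; `d3` edges out `d2` because appearing in two lists beats one good rank.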
New file (Infinity reranker client):

```python
import asyncio

from infinity_client import Client
from infinity_client.api.default import rerank
from infinity_client.models import RerankInput, ReRankResult
from langchain_core.documents.base import Document
from utils.logger import get_logger

from .base import BaseReranker

logger = get_logger()


class InfinityReranker(BaseReranker):
    def __init__(self, config):
        self.model_name = config.reranker["model_name"]
        self.client = Client(base_url=config.reranker["base_url"])
        semaphore = config.reranker.get("semaphore", 40)
        self.semaphore = asyncio.Semaphore(semaphore)
        logger.debug("Reranker initialized", model_name=self.model_name)

    async def rerank(self, query: str, documents: list[Document], top_k: int | None = None) -> list[Document]:
        async with self.semaphore:
            logger.debug("Reranking documents", documents_count=len(documents), top_k=top_k)
            top_k = min(top_k, len(documents)) if top_k is not None else len(documents)
            rerank_input = RerankInput.from_dict(
                {
                    "model": self.model_name,
                    "query": query,
                    "documents": [doc.page_content for doc in documents],
                    "top_n": top_k,
                    "return_documents": True,
                    "raw_scores": True,  # Normalized score between 0 and 1
                }
            )
            try:
                rerank_result: ReRankResult = await rerank.asyncio(client=self.client, body=rerank_input)
                output = []
                for rerank_res in rerank_result.results:
                    doc = documents[rerank_res.index]
                    doc.metadata["relevance_score"] = rerank_res.relevance_score
                    output.append(doc)
                return output
            except Exception as e:
                logger.error(
                    "Reranking failed",
                    error=str(e),
                    model_name=self.model_name,
                    documents_count=len(documents),
                )
                raise e
```

Review comment on the `utils.logger` import: per coding guidelines, imports should use absolute paths. Suggested change:

```diff
-from utils.logger import get_logger
+from openrag.utils.logger import get_logger
```
Review comment: `RERANKER_PORT` override breaks service connectivity when overridden.

The openai.yaml vLLM service lacks a `--port` flag in its commands (lines 17-20 and 56-59), so it always listens on port 8000 internally. However, the port mapping on line 28 (`${RERANKER_PORT:-8003}:8000`) changes when `RERANKER_PORT` is set to a non-default value. When overridden, the port mapping remaps the container's port 8000 to the host, but the Hydra configuration constructs URLs using the overridden port number, causing a mismatch. Compare this to the infinity.yaml reranker variant, which correctly passes `--port ${RERANKER_PORT:-7997}` to the service.

Add the `--port` flag to both reranker command definitions and update the port mapping to maintain port alignment across the stack.

Proposed fix:

```diff
     command: >
       --model ${RERANKER_MODEL:-BAAI/bge-reranker-v2-m3}
       --trust-remote-code
       --gpu_memory_utilization 0.3
+      --port ${RERANKER_PORT:-8000}
     ports:
-      - ${RERANKER_PORT:-8003}:8000
+      - ${RERANKER_PORT:-8003}:${RERANKER_PORT:-8000}
```

Apply the same change to the reranker-cpu command block (lines 56-59).