diff --git a/openrag/api.py b/openrag/api.py
index 349069bf..922a33ea 100644
--- a/openrag/api.py
+++ b/openrag/api.py
@@ -167,6 +167,16 @@ async def openrag_exception_handler(request: Request, exc: OpenRAGError):
return JSONResponse(status_code=exc.status_code, content=exc.to_dict())
+@app.exception_handler(Exception)
+async def unhandled_exception_handler(request: Request, exc: Exception):
+    """Catch-all handler: log the unhandled exception and answer with a generic 500.
+
+    The exception text is written to the log only; the response body is a
+    fixed message, so no internal detail is sent back to the client.
+    """
+    get_logger().exception("Unhandled exception", error=str(exc))
+    payload = {
+        "detail": "[UNEXPECTED_ERROR]: An unexpected error occurred",
+        "extra": {},
+    }
+    return JSONResponse(status_code=500, content=payload)
+
+
# Add CORS middleware
allow_origins = [
"http://localhost:3042",
diff --git a/openrag/app_front.py b/openrag/app_front.py
index 5735bbc1..ebc66fc7 100644
--- a/openrag/app_front.py
+++ b/openrag/app_front.py
@@ -66,7 +66,7 @@ async def on_chat_resume(thread):
@cl.password_auth_callback
async def auth_callback(username: str, password: str):
try:
- async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=httpx.Timeout(4 * 60.0))) as client:
+ async with httpx.AsyncClient(timeout=httpx.Timeout(4 * 60.0)) as client:
response = await client.get(
url=f"{INTERNAL_BASE_URL}/users/info",
headers=get_headers(password),
@@ -131,7 +131,7 @@ async def on_chat_start():
api_key = user.metadata.get("api_key", "sk-1234") if user else "sk-1234"
logger.debug("New Chat Started", internal_base_url=INTERNAL_BASE_URL)
try:
- async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=httpx.Timeout(4 * 60.0))) as client:
+ async with httpx.AsyncClient(timeout=httpx.Timeout(4 * 60.0)) as client:
response = await client.get(
url=f"{INTERNAL_BASE_URL}/health_check",
headers=get_headers(api_key),
diff --git a/openrag/components/indexer/chunker/chunker.py b/openrag/components/indexer/chunker/chunker.py
index 3a8e5f1d..63431ae4 100644
--- a/openrag/components/indexer/chunker/chunker.py
+++ b/openrag/components/indexer/chunker/chunker.py
@@ -2,6 +2,7 @@
import openai
from components.indexer.utils.text_sanitizer import sanitize_text
+from utils.exceptions.base import OpenRAGError
from components.prompts import CHUNK_CONTEXTUALIZER_PROMPT
from components.utils import detect_language, get_vlm_semaphore, load_config
from langchain_core.documents.base import Document
@@ -68,18 +69,20 @@ async def _generate_context(
]
output = await self.context_generator.ainvoke(messages)
return output.content
+
except openai.APITimeoutError:
- logger.warning(
- f"OpenAI API timeout contextualizing chunk after {CONTEXTUALIZATION_TIMEOUT}s",
- filename=filename,
- )
+ # VLM timeout - graceful degradation
+ logger.warning("VLM context generation timeout", timeout=CONTEXTUALIZATION_TIMEOUT)
return ""
- except Exception as e:
- logger.warning(
- "Error contextualizing chunk of document",
- filename=filename,
- error=str(e),
- )
+
+ except openai.APIError as e:
+ # Other VLM API errors - log but don't fail chunking
+ logger.error("VLM context generation failed", error=str(e))
+ return ""
+
+ except Exception:
+ # Unexpected errors - log but still gracefully degrade
+ logger.exception("Unexpected error during context generation")
return ""
async def contextualize_chunks(
@@ -130,8 +133,10 @@ async def contextualize_chunks(
for chunk, context in zip(chunks, contexts, strict=True)
]
- except Exception as e:
- logger.warning(f"Error contextualizing chunks from `{filename}`: {e}")
+ except OpenRAGError:
+ raise
+ except Exception:
+ logger.exception("Error contextualizing chunks", filename=filename)
return chunks
diff --git a/openrag/components/indexer/embeddings/openai.py b/openrag/components/indexer/embeddings/openai.py
index 9dc031e7..45a829b9 100644
--- a/openrag/components/indexer/embeddings/openai.py
+++ b/openrag/components/indexer/embeddings/openai.py
@@ -23,8 +23,18 @@ def embedding_dimension(self) -> int:
# Test call to get embedding dimension
output = self.embed_documents([Document(page_content="test")])
return len(output[0])
+
+ except openai.APIError as e:
+ logger.error("Failed to get embedding dimension", error=str(e))
+ raise EmbeddingAPIError(f"API error: {e}", model_name=self.embedding_model)
+
+ except (IndexError, AttributeError) as e:
+ logger.error("Invalid embedding response format", error=str(e))
+ raise EmbeddingResponseError("Unexpected response format", error=str(e))
+
except Exception:
- raise
+ logger.exception("Unexpected error getting embedding dimension")
+ raise UnexpectedEmbeddingError("An unexpected error occurred")
def embed_documents(self, texts: list[str | Document]) -> list[list[float]]:
"""
@@ -62,7 +72,7 @@ def embed_documents(self, texts: list[str | Document]) -> list[list[float]]:
except Exception as e:
logger.exception("Unexpected error while embedding documents", error=str(e))
raise UnexpectedEmbeddingError(
- f"Failed to embed documents: {e!s}",
+ "An unexpected error occurred during document embedding",
model_name=self.embedding_model,
base_url=self.base_url,
error=str(e),
diff --git a/openrag/components/indexer/indexer.py b/openrag/components/indexer/indexer.py
index 6c3930b7..3e0c16fd 100644
--- a/openrag/components/indexer/indexer.py
+++ b/openrag/components/indexer/indexer.py
@@ -10,6 +10,8 @@
import torch
from config import load_config
from langchain_core.documents.base import Document
+from utils.exceptions.base import OpenRAGError
+from utils.exceptions.common import UnexpectedError
from .chunker import BaseChunker, ChunkerFactory
from .utils import serialize_file
@@ -119,25 +121,34 @@ async def add_file(
# Mark task as completed
await task_state_manager.set_state.remote(task_id, "COMPLETED")
- except Exception as e:
- log.exception(f"Task {task_id} failed in add_file")
+ except OpenRAGError as e:
+ log.error("Operation failed during file ingestion", code=e.code, error=e.message)
tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
await task_state_manager.set_state.remote(task_id, "FAILED")
await task_state_manager.set_error.remote(task_id, tb)
raise
+ except Exception as e:
+ # Truly unexpected errors
+ log.exception("Unexpected error during file ingestion", task_id=task_id)
+ tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+ await task_state_manager.set_state.remote(task_id, "FAILED")
+ await task_state_manager.set_error.remote(task_id, tb)
+ raise UnexpectedError("An unexpected error occurred during file processing") from e
+
finally:
+ # GPU cleanup
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
+
+ # File cleanup - nest try/except per Phase 2 decision
try:
- # Cleanup input file
if not save_uploaded_files:
Path(path).unlink(missing_ok=True)
- log.debug(f"Deleted input file: {path}")
except Exception as cleanup_err:
- log.warning(f"Failed to delete input file {path}: {cleanup_err}")
+ log.warning("Failed to delete input file", path=path, error=str(cleanup_err))
return True
@ray.method(concurrency_group="insert")
@@ -157,10 +168,13 @@ async def delete_file(self, file_id: str, partition: str) -> bool:
await vectordb.delete_file.remote(file_id, partition)
log.info("Deleted file from partition.", file_id=file_id, partition=partition)
- except Exception as e:
- log.exception("Error in delete_file", error=str(e))
+ except OpenRAGError:
raise
+ except Exception as e:
+ log.exception("Unexpected error in delete_file")
+ raise UnexpectedError("An unexpected error occurred during file deletion") from e
+
@ray.method(concurrency_group="update")
async def update_file_metadata(
self,
@@ -184,10 +198,14 @@ async def update_file_metadata(
await vectordb.async_add_documents.remote(docs, user=user)
log.info("Metadata updated for file.")
- except Exception as e:
- log.exception("Error in update_file_metadata", error=str(e))
+
+ except OpenRAGError:
raise
+ except Exception as e:
+ log.exception("Unexpected error in update_file_metadata")
+ raise UnexpectedError("An unexpected error occurred during metadata update") from e
+
@ray.method(concurrency_group="update")
async def copy_file(
self,
@@ -216,10 +234,14 @@ async def copy_file(
new_file_id=metadata.get("file_id"),
new_partition=metadata.get("partition"),
)
- except Exception as e:
- log.exception("Error in copy_file", error=str(e))
+
+ except OpenRAGError:
raise
+ except Exception as e:
+ log.exception("Unexpected error in copy_file")
+ raise UnexpectedError("An unexpected error occurred during file copy") from e
+
@ray.method(concurrency_group="search")
async def asearch(
self,
diff --git a/openrag/components/indexer/loaders/base.py b/openrag/components/indexer/loaders/base.py
index 7a9e4e40..7071a19e 100644
--- a/openrag/components/indexer/loaders/base.py
+++ b/openrag/components/indexer/loaders/base.py
@@ -116,8 +116,13 @@ async def get_image_description(
base64.b64decode(image_data)
image_url = f"data:image/png;base64,{image_data}"
logger.debug("Processing raw base64 string")
- except Exception:
- logger.error(f"Invalid image data type or format: {type(image_data)}")
+ except (ValueError, base64.binascii.Error) as e:
+ # Invalid base64 data
+ logger.warning("Failed to decode base64 image", error=str(e)[:100])
+ return """\n\nInvalid image data format\n\n"""
+ except Exception as e:
+ # PIL image opening errors or other unexpected issues
+ logger.warning("Failed to process image data", error=str(e)[:100])
return """\n\nInvalid image data format\n\n"""
else:
logger.error(f"Unsupported image data type: {type(image_data)}")
diff --git a/openrag/components/indexer/loaders/eml_loader.py b/openrag/components/indexer/loaders/eml_loader.py
index c23bcca7..742a0a2f 100644
--- a/openrag/components/indexer/loaders/eml_loader.py
+++ b/openrag/components/indexer/loaders/eml_loader.py
@@ -1,5 +1,6 @@
import datetime
import email
+import email.errors
import io
import os
import tempfile
@@ -7,7 +8,8 @@
from pathlib import Path
from langchain_core.documents.base import Document
-from PIL import Image
+from PIL import Image, UnidentifiedImageError
+from utils.exceptions.common import FileStorageError, UnexpectedError
from . import get_loader_classes
from .base import BaseLoader
@@ -52,7 +54,8 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
if email_data["header"]["date"]:
try:
email_data["header"]["date"] = parsedate_to_datetime(email_data["header"]["date"]).isoformat()
- except Exception:
+ except (ValueError, TypeError, email.errors.MessageError):
+ # Invalid date format - keep original string
pass
# Extract body content and attachments
@@ -98,7 +101,8 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
# Use plain text as primary body content
if content_type == "text/plain" or not body_content:
body_content = text_content
- except Exception as e:
+ except (UnicodeDecodeError, email.errors.MessageError) as e:
+ # Failed to decode email text part - skip this part
print(f"Failed to decode text content: {e}")
# Extract body content
@@ -140,7 +144,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
metadata={"source": f"attachment:{filename}"},
)
attachments_text += f"Content:\n{attachment_doc.page_content}\n"
+ except OSError as e:
+ # File I/O error - skip this attachment
+ attachments_text += f"Cannot read attachment file: {str(e)[:200]}...\n"
except Exception as e:
+ # Loader-specific errors - try fallback loaders
attachments_text += f"Failed to process attachment with loader ({loader_cls.__name__}): {str(e)[:200]}...\n"
# Special fallback handling for PDFs with alternative loaders
@@ -179,7 +187,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
attachments_text += f"Content (via {fallback_loader_name}):\n{attachment_doc.page_content}\n"
fallback_success = True
break
+ except OSError as fallback_e:
+ # File I/O error - skip fallback
+ attachments_text += f"Fallback {fallback_loader_name} file error: {str(fallback_e)[:100]}...\n"
except Exception as fallback_e:
+ # Fallback loader failed - try next
attachments_text += f"Fallback {fallback_loader_name} also failed: {str(fallback_e)[:100]}...\n"
if not fallback_success:
@@ -197,7 +209,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
attachments_text += (
"Image attachment present but image captioning disabled\n"
)
+ except (OSError, UnidentifiedImageError) as img_e:
+ # Invalid or unreadable image
+ attachments_text += f"Image fallback failed (invalid image): {str(img_e)[:100]}...\n"
except Exception as img_e:
+ # Unexpected image processing error
attachments_text += f"Image fallback also failed: {str(img_e)[:100]}...\n"
# Try text fallback for other text-based formats
@@ -213,7 +229,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
)
else:
attachments_text += "No readable text found in attachment\n"
+ except UnicodeDecodeError as text_e:
+ # Text decoding failed
+ attachments_text += f"Text fallback failed (encoding error): {str(text_e)[:100]}...\n"
except Exception as text_e:
+ # Unexpected text extraction error
attachments_text += f"Text fallback failed: {str(text_e)[:100]}...\n"
finally:
# Clean up temporary file
@@ -237,8 +257,9 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
# Generate caption using the base loader's method
caption = await self.get_image_description(image_data=image)
attachments_text += f"Image Description:\n{caption}\n"
- except Exception as e:
- attachments_text += f"Failed to generate image caption: {str(e)[:200]}...\n"
+ except (OSError, UnidentifiedImageError) as e:
+ # Invalid or corrupted image
+ attachments_text += f"Failed to generate image caption (invalid image): {str(e)[:200]}...\n"
# Try to show basic image info if available
try:
size_info = f"Image size: {len(attachment['raw'])} bytes"
@@ -247,6 +268,9 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
)
except Exception:
attachments_text += "Image attachment present but corrupted or unreadable\n"
+ except Exception as e:
+ # Unexpected image captioning error (VLM errors handled in base.py)
+ attachments_text += f"Failed to generate image caption: {str(e)[:200]}...\n"
elif content_type.startswith("text/"):
# For text attachments, decode directly
@@ -255,7 +279,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
else:
# For other binary content, just show metadata
attachments_text += f"Binary content (size: {len(attachment['raw'])} bytes)\n"
+ except OSError as e:
+ # File I/O error creating temp file
+ attachments_text += f"Cannot create temp file for attachment: {e}\n"
except Exception as e:
+ # Unexpected error processing attachment
attachments_text += f"Content could not be processed: {e}\n"
attachments_text += "---\n"
@@ -298,8 +326,15 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar
with open(markdown_path, "w", encoding="utf-8") as md_file:
md_file.write(content_body)
metadata["markdown_path"] = str(markdown_path)
+ except OSError as e:
+ # File I/O error reading email file
+ raise FileStorageError(f"Cannot read email file: {e}") from e
+ except email.errors.MessageError as e:
+ # Email parsing error
+ raise UnexpectedError(f"Invalid email format: {e}") from e
except Exception as e:
- raise ValueError(f"Failed to parse the EML file {file_path}: {e}")
+ # Unexpected error
+ raise UnexpectedError(f"Failed to parse the EML file {file_path}: {e}") from e
document = Document(page_content=content_body, metadata=metadata)
return document
diff --git a/openrag/components/indexer/loaders/image.py b/openrag/components/indexer/loaders/image.py
index a1356c80..0bf7a3fe 100644
--- a/openrag/components/indexer/loaders/image.py
+++ b/openrag/components/indexer/loaders/image.py
@@ -3,7 +3,7 @@
import cairosvg
from langchain_core.documents import Document
-from PIL import Image
+from PIL import Image, UnidentifiedImageError
from utils.logger import get_logger
from .base import BaseLoader
@@ -29,7 +29,16 @@ async def aload_document(self, file_path, metadata=None, save_markdown=False):
img = Image.open(BytesIO(png_data))
else:
img = Image.open(path)
+ except OSError as e:
+ # File not found, permission denied, etc.
+ log.error("Cannot read image file", file_path=str(path), error=str(e))
+ raise ImageLoadError(f"Cannot read image file: {e}") from e
+ except UnidentifiedImageError as e:
+ # Invalid image format
+ log.error("Invalid image format", file_path=str(path), error=str(e))
+ raise ImageLoadError(f"Invalid image format: {e}") from e
except Exception as e:
+ # SVG conversion errors or other unexpected issues
log.error(
"Failed to load image file",
file_path=str(path),
diff --git a/openrag/components/indexer/loaders/media_loader.py b/openrag/components/indexer/loaders/media_loader.py
index 41c9f635..e90b4871 100644
--- a/openrag/components/indexer/loaders/media_loader.py
+++ b/openrag/components/indexer/loaders/media_loader.py
@@ -7,7 +7,7 @@
import numpy as np
from components.utils import get_audio_semaphore
from langchain_core.documents.base import Document
-from openai import AsyncOpenAI
+from openai import APIError, AsyncOpenAI
from pydub import AudioSegment, silence
from tqdm.asyncio import tqdm
from utils.logger import get_logger
@@ -58,8 +58,13 @@ async def _process_chunk(self, index: int, segment: AudioSegment, wav_path: Path
try:
result = await self._transcribe_chunk(tmp_path, language)
return result
+ except APIError as e:
+ # OpenAI API errors - gracefully degrade by returning empty transcript
+ logger.warning("Audio transcription API error", chunk=tmp_path.name, error=str(e)[:200])
+ return ""
except Exception as e:
- logger.exception(f"Error transcribing chunk {tmp_path.name}", error=str(e))
+ # Unexpected errors - log and return empty transcript for graceful degradation
+ logger.exception("Error transcribing chunk", chunk=tmp_path.name, error=str(e))
return ""
finally:
tmp_path.unlink(missing_ok=True)
@@ -77,8 +82,13 @@ async def _transcribe_chunk(self, wav_path: Path, language: str = None) -> str:
kwargs["language"] = language
result = await self.client.audio.transcriptions.create(**kwargs)
return result.text.strip()
+ except APIError as e:
+ # OpenAI API errors - gracefully degrade by returning empty transcript
+ logger.warning("Audio transcription API error", chunk=wav_path.name, error=str(e)[:200])
+ return ""
except Exception as e:
- logger.exception(f"Error transcribing chunk {wav_path.name}", error=str(e))
+ # Unexpected errors - log and return empty transcript for graceful degradation
+ logger.exception("Error transcribing chunk", chunk=wav_path.name, error=str(e))
return ""
async def _get_audio_chunks(self, sound: AudioSegment) -> list[AudioSegment]:
@@ -122,7 +132,12 @@ async def _detect_language(self, sound: AudioSegment, wav_path, fallback_languag
return fallback_language # Fallback to English
try:
return langdetect.detect(text)
+ except langdetect.LangDetectException as e:
+ # Expected failure for non-textual content or too-short text
+ logger.warning("Language detection failed", error=str(e))
+ return fallback_language
except Exception as e:
+ # Unexpected language detection errors
logger.exception("Language detection failed", error=str(e))
return fallback_language
finally:
diff --git a/openrag/components/indexer/loaders/pdf_loaders/marker.py b/openrag/components/indexer/loaders/pdf_loaders/marker.py
index 7e5561aa..0e095f77 100644
--- a/openrag/components/indexer/loaders/pdf_loaders/marker.py
+++ b/openrag/components/indexer/loaders/pdf_loaders/marker.py
@@ -9,6 +9,7 @@
from config import load_config
from langchain_core.documents.base import Document
from marker.converters.pdf import PdfConverter
+from utils.exceptions.common import FileStorageError, UnexpectedError
from utils.logger import get_logger
from ..base import BaseLoader
@@ -102,9 +103,16 @@ def _process_pdf(file_path, config):
)
render = converter(file_path)
return render
+ except asyncio.CancelledError:
+ # Cancellation request - propagate immediately
+ logger.info("PDF processing cancelled", path=file_path)
+ raise
+ except OSError as e:
+ logger.error("Cannot read PDF file", path=file_path, error=str(e))
+ raise FileStorageError(f"Cannot read PDF file: {e}") from e
except Exception as e:
logger.exception("Error processing PDF", path=file_path, error=str(e))
- raise
+ raise UnexpectedError("Failed to process PDF document") from e
finally:
gc.collect()
if torch.cuda.is_available():
@@ -124,6 +132,10 @@ def run_with_timeout():
return result
except MPTimeoutError:
self.logger.exception("MarkerWorker child process timed out", path=file_path)
+ raise UnexpectedError("PDF processing timed out")
+ except asyncio.CancelledError:
+ # Cancellation - propagate
+ self.logger.info("PDF processing cancelled", path=file_path)
raise
except Exception:
self.logger.exception("Error processing with MarkerWorker", path=file_path)
@@ -229,7 +241,7 @@ async def aload_document(
)
if not markdown:
- raise RuntimeError(f"Conversion failed for {file_path_str}")
+ raise UnexpectedError(f"Conversion failed for {file_path_str}")
if self.image_captioning:
keys = list(images.keys())
@@ -253,6 +265,14 @@ async def aload_document(
logger.info(f"Processed {file_path_str} in {duration:.2f}s")
return doc
+ except asyncio.CancelledError:
+ # Cancellation - propagate immediately (MUST be first)
+ logger.info("PDF loading cancelled", path=file_path_str)
+ raise
+ except OSError as e:
+ logger.error("Cannot read PDF file", path=file_path_str, error=str(e))
+ raise FileStorageError(f"Cannot read PDF file: {e}") from e
except Exception:
+ # Ray actor errors or PDF processing failures
logger.exception("Error in aload_document", path=file_path_str)
raise
diff --git a/openrag/components/indexer/loaders/pptx_loader.py b/openrag/components/indexer/loaders/pptx_loader.py
index 69fa45be..ea46c18f 100644
--- a/openrag/components/indexer/loaders/pptx_loader.py
+++ b/openrag/components/indexer/loaders/pptx_loader.py
@@ -125,11 +125,20 @@ def _convert_chart_to_markdown(self, chart):
separator = "|" + "|".join(["---"] * len(data[0])) + "|"
return md + "\n".join([header, separator] + markdown_table[1:])
except ValueError as e:
- # Handle the specific error for unsupported chart types
+ # Handle unsupported chart types (expected error)
if "unsupported plot type" in str(e):
+ logger.debug("Unsupported chart type encountered")
return "\n\n[unsupported chart]\n\n"
- except Exception:
- # Catch any other exceptions that might occur
+ # Other ValueError - log and return placeholder
+ logger.warning("Chart conversion value error", error=str(e))
+ return "\n\n[unsupported chart]\n\n"
+ except (AttributeError, IndexError) as e:
+ # Missing chart data or unexpected structure
+ logger.warning("Chart structure error", error=str(e))
+ return "\n\n[unsupported chart]\n\n"
+ except Exception as e:
+ # Unexpected errors - log and gracefully degrade
+ logger.warning("Chart conversion failed", error=str(e))
return "\n\n[unsupported chart]\n\n"
diff --git a/openrag/components/indexer/loaders/serializer.py b/openrag/components/indexer/loaders/serializer.py
index 9f687a86..e620f5c6 100644
--- a/openrag/components/indexer/loaders/serializer.py
+++ b/openrag/components/indexer/loaders/serializer.py
@@ -5,6 +5,7 @@
import torch
from config import load_config
from langchain_core.documents.base import Document
+from utils.exceptions.common import FileStorageError, UnexpectedError
from . import get_loader_classes
@@ -84,6 +85,9 @@ async def serialize_document(
torch.cuda.ipc_collect()
log.info("Document serialized successfully")
return doc
+ except OSError as e:
+ log.error("File operation failed during serialization", path=str(path), error=str(e))
+ raise FileStorageError(f"Cannot read file: {e}") from e
except Exception as e:
- log.exception("Failed to serialize document", error=str(e))
- raise
+ log.exception("Failed to serialize document", path=str(path), file_type=file_ext, error=str(e))
+ raise UnexpectedError("Failed to serialize document: unsupported format or corrupted file") from e
diff --git a/openrag/components/indexer/vectordb/utils.py b/openrag/components/indexer/vectordb/utils.py
index a7717b39..9137ff78 100644
--- a/openrag/components/indexer/vectordb/utils.py
+++ b/openrag/components/indexer/vectordb/utils.py
@@ -97,6 +97,24 @@ class User(Base):
memberships = relationship("PartitionMembership", back_populates="user", cascade="all, delete-orphan")
+class FileDomain(Base):
+ """Row tagging a file (within a partition) with one domain label.
+
+ One row is stored per (file_id, partition_name, domain) triple, so a file
+ carrying several domains is represented by several rows.
+ """
+
+ __tablename__ = "file_domains"
+
+ # Surrogate primary key.
+ id = Column(Integer, primary_key=True)
+ # NOTE(review): no ForeignKey is declared on file_id here, so these rows are not tied
+ # to a files table at the schema level - confirm that deleting a file also clears them.
+ file_id = Column(String, nullable=False)
+ # References partitions.partition; the foreign key is declared with ondelete="CASCADE".
+ partition_name = Column(
+ String,
+ ForeignKey("partitions.partition", ondelete="CASCADE"),
+ nullable=False,
+ )
+ domain = Column(String, nullable=False)
+
+ __table_args__ = (
+ # The same domain cannot be recorded twice for one file in one partition.
+ UniqueConstraint("file_id", "partition_name", "domain", name="uix_file_domain"),
+ # Composite index on (partition_name, domain).
+ Index("ix_partition_domain", "partition_name", "domain"),
+ )
+
+
class PartitionMembership(Base):
__tablename__ = "partition_memberships"
@@ -134,10 +152,10 @@ def __init__(self, database_url: str, logger=logger):
AUTH_TOKEN = os.getenv("AUTH_TOKEN")
self._ensure_admin_user(AUTH_TOKEN)
- except Exception as e:
+ except Exception:
raise VDBConnectionError(
- f"Failed to connect to database: {e!s}",
- db_url=database_url,
+ "An unexpected database error occurred",
+ db_url=str(database_url),
db_type="SQLAlchemy",
)
@@ -225,10 +243,14 @@ def add_file_to_partition(
session.commit()
log.info("Added file successfully")
return True
- except Exception:
+ except Exception as e:
session.rollback()
- log.exception("Error adding file to partition")
- raise
+ log.exception("Error adding file to partition", error=str(e))
+ raise VDBInsertError(
+ "An unexpected database error occurred",
+ file_id=file_id,
+ partition=partition,
+ )
def remove_file_from_partition(self, file_id: str, partition: str):
"""Remove a file from its partition - Optimized without join"""
@@ -246,8 +268,12 @@ def remove_file_from_partition(self, file_id: str, partition: str):
return False
except Exception as e:
session.rollback()
- log.error(f"Error removing file: {e}")
- raise e
+ log.exception("Error removing file", error=str(e))
+ raise VDBDeleteError(
+ "An unexpected database error occurred",
+ file_id=file_id,
+ partition=partition,
+ )
def delete_partition(self, partition: str):
"""Delete a partition and all its files"""
@@ -293,6 +319,58 @@ def file_exists_in_partition(self, file_id: str, partition: str):
session.query(File).filter(File.file_id == file_id, File.partition_name == partition).exists()
).scalar()
+ # Domains
+
+    def get_file_ids_by_domains(self, partition: str | None, domains: list[str]) -> list[str]:
+        """Return the distinct file_ids tagged with at least one of ``domains``.
+
+        When ``partition`` is None the lookup spans every partition; otherwise
+        it is restricted to rows belonging to that partition.
+        """
+        with self.Session() as session:
+            matching = session.query(FileDomain.file_id).filter(FileDomain.domain.in_(domains))
+            if partition is not None:
+                matching = matching.filter(FileDomain.partition_name == partition)
+            return [file_id for (file_id,) in matching.distinct().all()]
+
+    def set_file_domains(self, file_id: str, partition: str, domains: list[str]):
+        """Replace all domains for a file with the given list.
+
+        Existing rows for (file_id, partition) are deleted and one row per
+        distinct domain is inserted, followed by a single commit.
+
+        Args:
+            file_id: Identifier of the file whose domains are replaced.
+            partition: Partition the file belongs to.
+            domains: New domain labels; repeated labels are stored once
+                (first occurrence wins).
+
+        Raises:
+            VDBInsertError: If the delete/insert/commit fails. The session is
+                rolled back first and the original exception is chained as
+                the cause.
+        """
+        with self.Session() as session:
+            try:
+                # Repeated labels would violate the uix_file_domain unique constraint
+                # and abort the whole replacement, so collapse them in caller order.
+                unique_domains = list(dict.fromkeys(domains))
+                session.query(FileDomain).filter(
+                    FileDomain.file_id == file_id,
+                    FileDomain.partition_name == partition,
+                ).delete()
+                session.add_all(
+                    FileDomain(file_id=file_id, partition_name=partition, domain=domain)
+                    for domain in unique_domains
+                )
+                session.commit()
+            except Exception as e:
+                session.rollback()
+                self.logger.exception("Error setting file domains", error=str(e))
+                raise VDBInsertError(
+                    "An unexpected database error occurred",
+                    file_id=file_id,
+                    partition=partition,
+                ) from e
+
+    def get_file_domains(self, file_id: str, partition: str) -> list[str]:
+        """Return every domain recorded for ``file_id`` inside ``partition``."""
+        with self.Session() as session:
+            domain_query = session.query(FileDomain.domain).filter(
+                FileDomain.file_id == file_id,
+                FileDomain.partition_name == partition,
+            )
+            return [domain for (domain,) in domain_query.all()]
+
+    def list_partition_domains(self, partition: str) -> list[str]:
+        """Return the distinct domain labels present in ``partition``."""
+        with self.Session() as session:
+            domain_query = session.query(FileDomain.domain).filter(FileDomain.partition_name == partition)
+            unique_rows = domain_query.distinct().all()
+            return [domain for (domain,) in unique_rows]
+
# Users
def create_user(
diff --git a/openrag/components/indexer/vectordb/vectordb.py b/openrag/components/indexer/vectordb/vectordb.py
index 291f8ac0..666e6931 100644
--- a/openrag/components/indexer/vectordb/vectordb.py
+++ b/openrag/components/indexer/vectordb/vectordb.py
@@ -16,6 +16,7 @@
MilvusException,
RRFRanker,
)
+from sqlalchemy import URL
from utils.exceptions.base import EmbeddingError
from utils.exceptions.vectordb import *
from utils.logger import get_logger
@@ -177,7 +178,7 @@ def __init__(self):
except Exception as e:
self.logger.exception("Unexpected error initializing Milvus clients", error=str(e))
raise VDBConnectionError(
- f"Unexpected error initializing Milvus clients: {e!s}",
+ "An unexpected database error occurred",
db_url=uri,
db_type="Milvus",
)
@@ -225,8 +226,16 @@ def load_collection(self):
operation="load_collection",
)
+ database_url = URL.create(
+ drivername="postgresql",
+ username=self.rdb_user,
+ password=self.rdb_password,
+ host=self.rdb_host,
+ port=self.rdb_port,
+ database=f"partitions_for_collection_{self.collection_name}",
+ )
self.partition_file_manager = PartitionFileManager(
- database_url=f"postgresql://{self.rdb_user}:{self.rdb_password}@{self.rdb_host}:{self.rdb_port}/partitions_for_collection_{self.collection_name}",
+ database_url=database_url,
logger=self.logger,
)
self.logger.info("Milvus collection loaded.")
@@ -238,7 +247,7 @@ def load_collection(self):
error=str(e),
)
raise UnexpectedVDBError(
- f"Unexpected error setting collection name `{self.collection_name}`: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
)
@@ -361,6 +370,9 @@ async def async_add_documents(self, chunks: list[Document], user: dict) -> None:
file_id=file_id,
)
+ # Extract domains before inserting into Milvus (not stored as chunk metadata)
+ domains = file_metadata.pop("domains", None)
+
entities = []
vectors = await self.embedder.aembed_documents(chunks)
order_metadata_l: list[dict] = _gen_chunk_order_metadata(n=len(chunks))
@@ -386,6 +398,10 @@ async def async_add_documents(self, chunks: list[Document], user: dict) -> None:
file_metadata=file_metadata,
user_id=user.get("id"),
)
+
+ if domains:
+ self.partition_file_manager.set_file_domains(file_id, partition, domains)
+
self.logger.info(f"File '{file_id}' added to partition '{partition}'")
except EmbeddingError as e:
self.logger.exception("Embedding failed", error=str(e))
@@ -394,10 +410,16 @@ async def async_add_documents(self, chunks: list[Document], user: dict) -> None:
self.logger.exception("VectorDB operation failed", error=str(e))
raise
+ except MilvusException as e:
+ self.logger.exception("Milvus insert operation failed", error=str(e))
+ raise VDBInsertError(
+ "Failed to insert document into collection",
+ collection_name=self.collection_name,
+ )
except Exception as e:
self.logger.exception("Unexpected error while adding a document", error=str(e))
raise UnexpectedVDBError(
- f"Unexpected error while adding a document: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
)
@@ -431,6 +453,8 @@ async def async_multi_query_search(
retrieved_chunks[document.metadata["_id"]] = document
return list(retrieved_chunks.values())
+ DOMAIN_BATCH_SIZE = 1000
+
async def async_search(
self,
query: str,
@@ -439,6 +463,51 @@ async def async_search(
partition: list[str] = None,
filter: dict | None = None,
with_surrounding_chunks: bool = False,
+    ) -> list[Document]:
+        """Search chunks, resolving an optional ``domains`` filter to file ids first.
+
+        A ``"domains"`` entry in ``filter`` is replaced by a ``file_id`` condition built from
+        ``partition_file_manager.get_file_ids_by_domains``. Up to ``DOMAIN_BATCH_SIZE`` ids are
+        searched in one call; larger sets are split into batches that are searched concurrently
+        and merged. Returns ``[]`` when no file id is found for the requested domains.
+        """
+        # Work on a copy: the pop below must not mutate the caller's dict.
+        filter = dict(filter) if filter else {}
+
+        if "domains" in filter:
+            domains = filter.pop("domains")
+            # Resolve domains to file ids via PostgreSQL for EVERY requested partition, not only
+            # the first one. `None` is what an ["all"] search hands to the file manager.
+            lookup_partitions = [None] if partition == ["all"] else list(partition)
+            file_ids = []
+            for name in lookup_partitions:
+                file_ids.extend(self.partition_file_manager.get_file_ids_by_domains(name, domains) or [])
+            # NOTE(review): ids are pooled across partitions; assumes a file_id is not reused in
+            # two searched partitions with different domains - confirm.
+            file_ids = list(dict.fromkeys(file_ids))  # de-duplicate, keep first-seen order
+            if not file_ids:
+                return []
+
+            if len(file_ids) <= self.DOMAIN_BATCH_SIZE:
+                filter["file_id"] = file_ids
+            else:
+                # More ids than DOMAIN_BATCH_SIZE: one search per batch, run concurrently, then merge.
+                tasks = [
+                    self._search_with_filter(
+                        query,
+                        top_k,
+                        similarity_threshold,
+                        partition,
+                        {**filter, "file_id": file_ids[start : start + self.DOMAIN_BATCH_SIZE]},
+                        with_surrounding_chunks,
+                    )
+                    for start in range(0, len(file_ids), self.DOMAIN_BATCH_SIZE)
+                ]
+                merged = {}
+                for results in await asyncio.gather(*tasks):
+                    for doc in results:
+                        # Keyed by chunk id so a chunk returned by several batches is kept once.
+                        merged[doc.metadata["_id"]] = doc
+                # NOTE(review): this orders the merged hits by `_id`, not by similarity, so the
+                # top_k cut is not relevance-based - confirm which metadata key holds the score.
+                return sorted(merged.values(), key=lambda d: d.metadata.get("_id", 0), reverse=True)[:top_k]
+
+        return await self._search_with_filter(
+            query, top_k, similarity_threshold, partition, filter, with_surrounding_chunks
+        )
+
+ async def _search_with_filter(
+ self,
+ query: str,
+ top_k: int,
+ similarity_threshold: float,
+ partition: list[str],
+ filter: dict,
+ with_surrounding_chunks: bool,
) -> list[Document]:
expr_parts = []
if partition != ["all"]:
@@ -446,7 +515,11 @@ async def async_search(
if filter:
for key, value in filter.items():
- expr_parts.append(f"{key} == '{value}'")
+ if isinstance(value, list):
+ formatted = ", ".join(f"'{v}'" for v in value)
+ expr_parts.append(f"{key} in [{formatted}]")
+ else:
+ expr_parts.append(f"{key} == '{value}'")
# Join all parts with " and " only if there are multiple conditions
expr = " and ".join(expr_parts) if expr_parts else ""
@@ -520,7 +593,7 @@ async def async_search(
except Exception as e:
self.logger.exception("Unexpected error occurred", error=str(e))
raise UnexpectedVDBError(
- f"Unexpected error occurred: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
partition=partition,
)
@@ -594,7 +667,7 @@ async def delete_file(self, file_id: str, partition: str):
except Exception as e:
log.exception("Unexpected error while deleting file chunks", error=str(e))
raise UnexpectedVDBError(
- f"Unexpected error while deleting file chunks {file_id}: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
partition=partition,
file_id=file_id,
@@ -651,8 +724,8 @@ async def get_file_chunks(self, file_id: str, partition: str, include_id: bool =
except Exception as e:
log.exception("Unexpected error while getting file chunks", error=str(e))
- raise VDBSearchError(
- f"Unexpected error while getting file chunks {file_id}: {e!s}",
+ raise UnexpectedVDBError(
+ "An unexpected database error occurred",
collection_name=self.collection_name,
partition=partition,
file_id=file_id,
@@ -696,7 +769,7 @@ async def get_chunk_by_id(self, chunk_id: str):
except Exception as e:
log.exception("Unexpected error while retrieving chunk", error=str(e))
raise UnexpectedVDBError(
- f"Unexpected error while retrieving chunk {chunk_id}: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
)
@@ -706,6 +779,8 @@ def file_exists(self, file_id: str, partition: str):
"""
try:
return self.partition_file_manager.file_exists_in_partition(file_id=file_id, partition=partition)
+ except VDBError:
+ raise
except Exception as e:
self.logger.exception(
"File existence check failed.",
@@ -713,7 +788,12 @@ def file_exists(self, file_id: str, partition: str):
partition=partition,
error=str(e),
)
- return False
+ raise UnexpectedVDBError(
+ "An unexpected database error occurred",
+ collection_name=self.collection_name,
+ partition=partition,
+ file_id=file_id,
+ )
def list_partition_files(self, partition: str, limit: int | None = None):
try:
@@ -729,7 +809,7 @@ def list_partition_files(self, partition: str, limit: int | None = None):
error=str(e),
)
raise UnexpectedVDBError(
- f"Unexpected error while listing files in partition {partition}: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
partition=partition,
)
@@ -737,9 +817,14 @@ def list_partition_files(self, partition: str, limit: int | None = None):
def list_partitions(self):
try:
return self.partition_file_manager.list_partitions()
+ except VDBError:
+ raise
except Exception as e:
self.logger.exception("Failed to list partitions", error=str(e))
- raise
+ raise UnexpectedVDBError(
+ "An unexpected database error occurred",
+ collection_name=self.collection_name,
+ )
def collection_exists(self, collection_name: str):
"""
@@ -773,7 +858,7 @@ async def delete_partition(self, partition: str):
except Exception as e:
log.exception("Unexpected error while deleting partition", error=str(e))
raise UnexpectedVDBError(
- f"Unexpected error while deleting partition {partition}: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
partition=partition,
)
@@ -785,9 +870,15 @@ def partition_exists(self, partition: str):
log = self.logger.bind(partition=partition)
try:
return self.partition_file_manager.partition_exists(partition=partition)
+ except VDBError:
+ raise
except Exception as e:
log.exception("Partition existence check failed.", error=str(e))
- return False
+ raise UnexpectedVDBError(
+ "An unexpected database error occurred",
+ collection_name=self.collection_name,
+ partition=partition,
+ )
async def list_all_chunk(self, partition: str, include_embedding: bool = True):
"""
@@ -855,7 +946,7 @@ def prepare_metadata(res: dict):
error=str(e),
)
raise UnexpectedVDBError(
- f"Unexpected error while listing all chunks in partition {partition}: {e!s}",
+ "An unexpected database error occurred",
collection_name=self.collection_name,
partition=partition,
)
@@ -920,6 +1011,20 @@ async def remove_partition_member(self, partition: str, user_id: int) -> bool:
self.partition_file_manager.remove_partition_member(partition, user_id)
self.logger.info(f"User_id {user_id} removed from partition '{partition}'.")
+ # Domain management (pure PostgreSQL, no Milvus interaction)
+
+ async def set_file_domains(self, file_id: str, partition: str, domains: list[str]):
+ """Hand ``domains`` for ``file_id`` in ``partition`` to ``partition_file_manager.set_file_domains``.
+
+ ``_check_file_exists`` runs first (presumably it rejects an unknown file/partition pair; its body is
+ not visible here, so confirm). Both calls are synchronous even though the method is ``async``.
+ """
+ self._check_file_exists(file_id, partition)
+ self.partition_file_manager.set_file_domains(file_id, partition, domains)
+
+ async def get_file_domains(self, file_id: str, partition: str) -> list[str]:
+ """Return what ``partition_file_manager.get_file_domains`` reports for ``file_id`` in ``partition``.
+
+ ``_check_file_exists`` runs first (presumably it rejects an unknown file/partition pair; its body is
+ not visible here, so confirm).
+ """
+ self._check_file_exists(file_id, partition)
+ return self.partition_file_manager.get_file_domains(file_id, partition)
+
+ async def list_partition_domains(self, partition: str) -> list[str]:
+ """Return what ``partition_file_manager.list_partition_domains`` reports for ``partition``.
+
+ ``_check_partition_exists`` runs first (presumably it rejects an unknown partition; its body is
+ not visible here, so confirm).
+ """
+ self._check_partition_exists(partition)
+ return self.partition_file_manager.list_partition_domains(partition)
+
def _check_user_exists(self, user_id: int):
if not self.partition_file_manager.user_exists(user_id):
self.logger.warning(f"User with ID {user_id} does not exist.")
diff --git a/openrag/components/llm.py b/openrag/components/llm.py
index 634663c4..0d06b549 100644
--- a/openrag/components/llm.py
+++ b/openrag/components/llm.py
@@ -1,7 +1,9 @@
+import asyncio
import copy
import json
import httpx
+from utils.exceptions.common import UnexpectedError
from utils.logger import get_logger
logger = get_logger()
@@ -37,10 +39,10 @@ async def completions(self, request: dict):
data = response.json()
yield data
except httpx.HTTPStatusError as e:
- error_detail = e.response.text
- raise ValueError(f"LLM API error ({e.response.status_code}): {error_detail}")
+ logger.error("LLM API returned error", status_code=e.response.status_code)
+ raise UnexpectedError(f"LLM API error ({e.response.status_code})") from e
except json.JSONDecodeError as e:
- raise ValueError(f"Invalid JSON in API response: {str(e)}")
+ raise UnexpectedError("Invalid JSON in LLM API response") from e
async def chat_completion(self, request: dict):
request.pop("model")
@@ -58,17 +60,28 @@ async def chat_completion(self, request: dict):
headers=self.headers,
json=payload,
) as response:
- if response.status_code >= 400:
- await response.aread()
- error_detail = response.text
- raise ValueError(f"LLM API error ({response.status_code}): {error_detail}")
+ response.raise_for_status()
async for line in response.aiter_lines():
yield line
- except ValueError:
+
+ except asyncio.CancelledError:
+ # Handle cancellation before the other handlers: log it and re-raise so it is never wrapped in UnexpectedError
+ logger.info("LLM streaming cancelled by client")
raise
+
+ except httpx.HTTPStatusError as e:
+ # 4xx/5xx responses
+ logger.error("LLM API returned error", status_code=e.response.status_code)
+ raise UnexpectedError(f"LLM API error ({e.response.status_code})") from e
+
+ except httpx.RequestError as e:
+ # Network/connection failures
+ logger.error("Network error during LLM streaming", error=str(e))
+ raise UnexpectedError("Network error during LLM streaming") from e
+
except Exception as e:
- logger.error(f"Error while streaming chat completion: {str(e)}")
- raise
+ logger.exception("Unexpected error during LLM streaming")
+ raise UnexpectedError("An unexpected error occurred during streaming") from e
else: # Handle non-streaming response
try:
@@ -81,7 +94,7 @@ async def chat_completion(self, request: dict):
data = response.json()
yield data
except httpx.HTTPStatusError as e:
- error_detail = e.response.text
- raise ValueError(f"LLM API error ({e.response.status_code}): {error_detail}")
+ logger.error("LLM API returned error", status_code=e.response.status_code)
+ raise UnexpectedError(f"LLM API error ({e.response.status_code})") from e
except json.JSONDecodeError as e:
- raise ValueError(f"Invalid JSON in API response: {str(e)}")
+ raise UnexpectedError("Invalid JSON in LLM API response") from e
diff --git a/openrag/components/map_reduce.py b/openrag/components/map_reduce.py
index e9f89f34..9572a7c4 100644
--- a/openrag/components/map_reduce.py
+++ b/openrag/components/map_reduce.py
@@ -80,7 +80,8 @@ async def infer_chunk_relevancy(self, query, chunk: Document) -> SummarizedChunk
)
return output_chunk
except Exception as e:
- logger.error("Error during chunk relevancy inference", error=str(e))
+ # Graceful degradation - mark chunk as irrelevant on error.
+ # The chunk id is read from "_id", the metadata key used for chunk identity elsewhere in this change.
+ logger.warning("Failed to infer chunk relevancy", chunk_id=chunk.metadata.get("_id"), error=str(e)[:200])
return SummarizedChunk(relevancy=False, summary="")
async def map_batch(
diff --git a/openrag/components/pipeline.py b/openrag/components/pipeline.py
index 3c20a3fc..36af5ae9 100644
--- a/openrag/components/pipeline.py
+++ b/openrag/components/pipeline.py
@@ -44,8 +44,10 @@ def __init__(self, config) -> None:
if self.reranker_enabled:
self.reranker = Reranker(logger, config)
- async def retrieve_docs(self, partition: list[str], query: str, use_map_reduce: bool = False) -> list[Document]:
- docs = await self.retriever.retrieve(partition=partition, query=query)
+ async def retrieve_docs(
+ self, partition: list[str], query: str, use_map_reduce: bool = False, filter: dict | None = None
+ ) -> list[Document]:
+ docs = await self.retriever.retrieve(partition=partition, query=query, filter=filter)
top_k = max(self.map_reduce_max_docs, self.reranker_top_k) if use_map_reduce else self.reranker_top_k
logger.debug("Documents retreived", document_count=len(docs))
if docs:
@@ -117,20 +119,25 @@ async def _prepare_for_chat_completion(self, partition: list[str], payload: dict
query = await self.generate_query(messages)
logger.debug("Prepared query for chat completion", query=query)
- metadata = payload.get("metadata", {})
+ metadata = payload.get("metadata", {}) or {}
use_map_reduce = metadata.get("use_map_reduce", False)
spoken_style_answer = metadata.get("spoken_style_answer", False)
+ domains = metadata.get("domains")
logger.debug(
"Metadata parameters",
use_map_reduce=use_map_reduce,
spoken_style_answer=spoken_style_answer,
+ domains=domains,
)
+ # Build filter from metadata
+ search_filter = {"domains": domains} if domains else None
+
# 2. get docs
docs = await self.retriever_pipeline.retrieve_docs(
- partition=partition, query=query, use_map_reduce=use_map_reduce
+ partition=partition, query=query, use_map_reduce=use_map_reduce, filter=search_filter
)
if use_map_reduce and docs:
@@ -157,11 +164,14 @@ async def _prepare_for_chat_completion(self, partition: list[str], payload: dict
async def _prepare_for_completions(self, partition: list[str], payload: dict):
prompt = payload["prompt"]
+ metadata = payload.get("metadata", {}) or {}
+ domains = metadata.get("domains")
+ search_filter = {"domains": domains} if domains else None
# 1. get the query
query = await self.generate_query(messages=[{"role": "user", "content": prompt}])
# 2. get docs
- docs = await self.retriever_pipeline.retrieve_docs(partition=partition, query=query)
+ docs = await self.retriever_pipeline.retrieve_docs(partition=partition, query=query, filter=search_filter)
# 3. Format the retrieved docs
context = format_context(docs, max_context_tokens=self.max_context_tokens)
@@ -178,25 +188,17 @@ async def _prepare_for_completions(self, partition: list[str], payload: dict):
return payload, docs
async def completions(self, partition: list[str], payload: dict):
- try:
- if partition is None:
- docs = []
- else:
- payload, docs = await self._prepare_for_completions(partition=partition, payload=payload)
- llm_output = self.llm_client.completions(request=payload)
- return llm_output, docs
- except Exception as e:
- logger.error(f"Error during chat completion: {e!s}")
- raise e
+ if partition is None:
+ docs = []
+ else:
+ payload, docs = await self._prepare_for_completions(partition=partition, payload=payload)
+ llm_output = self.llm_client.completions(request=payload)
+ return llm_output, docs
async def chat_completion(self, partition: list[str] | None, payload: dict):
- try:
- if partition is None:
- docs = []
- else:
- payload, docs = await self._prepare_for_chat_completion(partition=partition, payload=payload)
- llm_output = self.llm_client.chat_completion(request=payload)
- return llm_output, docs
- except Exception as e:
- logger.error(f"Error during chat completion: {e!s}")
- raise e
+ if partition is None:
+ docs = []
+ else:
+ payload, docs = await self._prepare_for_chat_completion(partition=partition, payload=payload)
+ llm_output = self.llm_client.chat_completion(request=payload)
+ return llm_output, docs
diff --git a/openrag/components/ray_utils.py b/openrag/components/ray_utils.py
index 9d1c31a5..06bba133 100644
--- a/openrag/components/ray_utils.py
+++ b/openrag/components/ray_utils.py
@@ -3,6 +3,7 @@
import ray
from ray.exceptions import RayTaskError, TaskCancelledError
+from utils.exceptions.common import RayActorError
from utils.logger import get_logger
logger = get_logger()
@@ -33,7 +34,7 @@ async def call_ray_actor_with_timeout(
TimeoutError: If the task exceeds the timeout
asyncio.CancelledError: If the calling coroutine is cancelled
TaskCancelledError: If the Ray task was cancelled
- RuntimeError: If the Ray task failed with an error
+ RayActorError: If the Ray task failed with an error
"""
try:
result = await asyncio.wait_for(asyncio.gather(future), timeout=timeout)
@@ -54,4 +55,4 @@ async def call_ray_actor_with_timeout(
raise
except RayTaskError as e:
- raise RuntimeError(f"{task_description} failed") from e
+ raise RayActorError(f"{task_description} failed") from e
diff --git a/openrag/components/reranker.py b/openrag/components/reranker.py
index 3768720f..3e19c315 100644
--- a/openrag/components/reranker.py
+++ b/openrag/components/reranker.py
@@ -4,6 +4,7 @@
from infinity_client.api.default import rerank
from infinity_client.models import RerankInput, ReRankResult
from langchain_core.documents.base import Document
+from utils.exceptions.common import UnexpectedError
class Reranker:
@@ -38,10 +39,5 @@ async def rerank(self, query: str, documents: list[Document], top_k: int) -> lis
return output
except Exception as e:
- self.logger.error(
- "Reranking failed",
- error=str(e),
- model_name=self.model_name,
- documents_count=len(documents),
- )
- raise e
+ self.logger.exception("Reranking failed", query=query[:100], doc_count=len(documents))
+ raise UnexpectedError("An unexpected error occurred during reranking") from e
diff --git a/openrag/components/retriever.py b/openrag/components/retriever.py
index 3c5caf32..8bf2c140 100644
--- a/openrag/components/retriever.py
+++ b/openrag/components/retriever.py
@@ -27,7 +27,7 @@ def __init__(
pass
@abstractmethod
- async def retrieve(self, partition: list[str], query: str) -> list[Document]:
+ async def retrieve(self, partition: list[str], query: str, filter: dict | None = None) -> list[Document]:
pass
@@ -43,6 +43,7 @@ async def retrieve(
self,
partition: list[str],
query: str,
+ filter: dict | None = None,
) -> list[Document]:
db = get_vectordb()
chunks = await db.async_search.remote(
@@ -50,6 +51,7 @@ async def retrieve(
partition=partition,
top_k=self.top_k,
similarity_threshold=self.similarity_threshold,
+ filter=filter,
with_surrounding_chunks=self.with_surrounding_chunks,
)
return chunks
@@ -78,7 +80,7 @@ def __init__(
prompt: ChatPromptTemplate = ChatPromptTemplate.from_template(MULTI_QUERY_PROMPT)
self.generate_queries = prompt | llm | StrOutputParser() | (lambda x: x.split("[SEP]"))
- async def retrieve(self, partition: list[str], query: str) -> list[Document]:
+ async def retrieve(self, partition: list[str], query: str, filter: dict | None = None) -> list[Document]:
db = get_vectordb()
logger.debug("Generating multiple queries", k_queries=self.k_queries)
generated_queries = await self.generate_queries.ainvoke(
@@ -92,6 +94,7 @@ async def retrieve(self, partition: list[str], query: str) -> list[Document]:
partition=partition,
top_k_per_query=self.top_k,
similarity_threshold=self.similarity_threshold,
+ filter=filter,
with_surrounding_chunks=self.with_surrounding_chunks,
)
return chunks
@@ -121,7 +124,7 @@ async def get_hyde(self, query: str):
hyde_document = await self.hyde_generator.ainvoke({"query": query})
return hyde_document
- async def retrieve(self, partition: list[str], query: str) -> list[Document]:
+ async def retrieve(self, partition: list[str], query: str, filter: dict | None = None) -> list[Document]:
db = get_vectordb()
hyde = await self.get_hyde(query)
queries = [hyde]
@@ -133,6 +136,7 @@ async def retrieve(self, partition: list[str], query: str) -> list[Document]:
partition=partition,
top_k_per_query=self.top_k,
similarity_threshold=self.similarity_threshold,
+ filter=filter,
with_surrounding_chunks=self.with_surrounding_chunks,
)
diff --git a/openrag/models/indexer.py b/openrag/models/indexer.py
index 61a14227..8dfccf81 100644
--- a/openrag/models/indexer.py
+++ b/openrag/models/indexer.py
@@ -1,6 +1,35 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, field_validator
class SearchRequest(BaseModel):
query: str
top_k: int | None = 5 # default to 5 if not provided
+
+
+class FileMetadataSchema(BaseModel):
+    """Validated shape of the JSON metadata sent with a file upload.
+
+    Declares the optional ``mimetype`` hint and the ``domains`` list used for
+    domain filtering; other keys are allowed through unvalidated.
+    """
+
+    mimetype: str | None = None
+    domains: list[str] = Field(default_factory=list)
+
+    # Extra fields are allowed so existing callers that send keys not declared
+    # here keep working.
+    model_config = {"extra": "allow"}
+
+    @field_validator("domains")
+    @classmethod
+    def validate_domains(cls, v):
+        """Return ``v`` unchanged unless it is not a list or holds a non-string or blank entry."""
+        # Falsy values (e.g. the empty default) need no further checks.
+        if not v:
+            return v
+        if not isinstance(v, list):
+            raise ValueError("domains must be a list")
+        for entry in v:
+            if not isinstance(entry, str):
+                raise ValueError("All domains must be strings")
+            if entry.strip() == "":
+                raise ValueError("Domain names cannot be empty")
+        return v
diff --git a/openrag/routers/actors.py b/openrag/routers/actors.py
index 49ab1dfa..9f30b884 100644
--- a/openrag/routers/actors.py
+++ b/openrag/routers/actors.py
@@ -51,24 +51,17 @@
)
async def list_ray_actors():
"""List all known Ray actors and their status."""
- try:
- actors = [
- {
- "actor_id": a.actor_id,
- "name": a.name,
- "class_name": a.class_name,
- "state": a.state,
- "namespace": a.ray_namespace,
- }
- for a in list_actors()
- ]
- return JSONResponse(status_code=status.HTTP_200_OK, content={"actors": actors})
- except Exception:
- logger.exception("Error getting actor summaries")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve actor summaries.",
- )
+ actors = [
+ {
+ "actor_id": a.actor_id,
+ "name": a.name,
+ "class_name": a.class_name,
+ "state": a.state,
+ "namespace": a.ray_namespace,
+ }
+ for a in list_actors()
+ ]
+ return JSONResponse(status_code=status.HTTP_200_OK, content={"actors": actors})
@router.post(
@@ -116,29 +109,16 @@ async def restart_actor(
logger.info(f"Killed actor: {actor_name}")
except ValueError:
logger.warning("Actor not found. Creating new instance.", actor=actor_name)
- except Exception as e:
- logger.exception("Failed to kill actor", actor=actor_name)
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Failed to kill actor {actor_name}: {e!s}",
- )
- try:
- new_actor = actor_creation_map[actor_name]()
- if "Semaphore" in actor_name:
- new_actor = new_actor._actor
- logger.info(f"Restarted actor: {actor_name}")
- return JSONResponse(
- status_code=status.HTTP_200_OK,
- content={
- "message": f"Actor {actor_name} restarted successfully.",
- "actor_name": actor_name,
- "actor_id": new_actor._actor_id.hex(),
- },
- )
- except Exception as e:
- logger.exception("Failed to restart actor", actor=actor_name)
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Failed to restart actor {actor_name}: {e!s}",
- )
+ new_actor = actor_creation_map[actor_name]()
+ if "Semaphore" in actor_name:
+ new_actor = new_actor._actor
+ logger.info(f"Restarted actor: {actor_name}")
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={
+ "message": f"Actor {actor_name} restarted successfully.",
+ "actor_name": actor_name,
+ "actor_id": new_actor._actor_id.hex(),
+ },
+ )
diff --git a/openrag/routers/extract.py b/openrag/routers/extract.py
index 32f51e5a..5179a146 100644
--- a/openrag/routers/extract.py
+++ b/openrag/routers/extract.py
@@ -48,31 +48,23 @@ async def get_extract(
user_partitions=Depends(current_user_or_admin_partitions_list),
):
log = logger.bind(extract_id=extract_id)
- try:
- chunk = await vectordb.get_chunk_by_id.remote(extract_id)
- if chunk is None:
- log.warning("Extract not found.")
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail=f"Extract '{extract_id}' not found.",
- )
- chunk_partition = chunk.metadata["partition"]
- log.info(f"User partitions: {user_partitions}, Chunk partition: {chunk_partition}")
- if chunk_partition not in user_partitions and user_partitions != ["all"]:
- log.warning("User does not have access to this extract.")
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail=f"User does not have access to extract '{extract_id}'.",
- )
- log.info("Extract successfully retrieved.")
- except HTTPException:
- raise
- except Exception as e:
- log.exception("Failed to retrieve extract.", error=str(e))
+
+ chunk = await vectordb.get_chunk_by_id.remote(extract_id)
+ if chunk is None:
+ log.warning("Extract not found.")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Extract '{extract_id}' not found.",
+ )
+ chunk_partition = chunk.metadata["partition"]
+ log.info(f"User partitions: {user_partitions}, Chunk partition: {chunk_partition}")
+ if chunk_partition not in user_partitions and user_partitions != ["all"]:
+ log.warning("User does not have access to this extract.")
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Failed to retrieve extract: {e!s}",
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"User does not have access to extract '{extract_id}'.",
)
+ log.info("Extract successfully retrieved.")
return JSONResponse(
status_code=status.HTTP_200_OK,
diff --git a/openrag/routers/indexer.py b/openrag/routers/indexer.py
index 8f023d54..5a17c1fd 100644
--- a/openrag/routers/indexer.py
+++ b/openrag/routers/indexer.py
@@ -123,13 +123,6 @@ async def add_file(
vectordb=Depends(get_vectordb),
user=Depends(require_partition_editor),
):
- log = logger.bind(
- file_id=file_id,
- partition=partition,
- filename=file.filename,
- user=user.get("display_name"),
- )
-
if await vectordb.file_exists.remote(file_id, partition):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -137,16 +130,9 @@ async def add_file(
)
save_dir = Path(DATA_DIR)
- try:
- original_filename = file.filename
- file.filename = sanitize_filename(file.filename)
- file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True)
- except Exception as e:
- log.exception("Failed to save file to disk.", error=str(e))
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=str(e),
- )
+ original_filename = file.filename
+ file.filename = sanitize_filename(file.filename)
+ file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True)
metadata.update(
{
@@ -240,8 +226,6 @@ async def put_file(
vectordb=Depends(get_vectordb),
user=Depends(require_partition_editor),
):
- log = logger.bind(file_id=file_id, partition=partition, filename=file.filename)
-
if not await vectordb.file_exists.remote(file_id, partition):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -252,16 +236,9 @@ async def put_file(
await indexer.delete_file.remote(file_id, partition)
save_dir = Path(DATA_DIR)
- try:
- original_filename = file.filename
- file.filename = sanitize_filename(file.filename)
- file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True)
- except Exception:
- log.exception("Failed to save file to disk.")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to save uploaded file.",
- )
+ original_filename = file.filename
+ file.filename = sanitize_filename(file.filename)
+ file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True)
metadata.update(
{
@@ -441,21 +418,13 @@ async def get_task_error(
task_state_manager=Depends(get_task_state_manager),
task_details=Depends(require_task_owner),
):
- try:
- error = await task_state_manager.get_error.remote(task_id)
- if error is None:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail=f"No error found for task '{task_id}'.",
- )
- return {"task_id": task_id, "traceback": error.splitlines()}
- except HTTPException:
- raise
- except Exception:
+ error = await task_state_manager.get_error.remote(task_id)
+ if error is None:
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve task error.",
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"No error found for task '{task_id}'.",
)
+ return {"task_id": task_id, "traceback": error.splitlines()}
@router.get(
@@ -475,32 +444,27 @@ async def get_task_error(
""",
)
async def get_task_logs(task_id: str, max_lines: int = 100, task_details=Depends(require_task_owner)):
- try:
- if not LOG_FILE.exists():
- raise HTTPException(status_code=500, detail="Log file not found.")
-
- logs = []
- with open(LOG_FILE, errors="replace") as f:
- for line in reversed(list(f)):
- try:
- record = json.loads(line).get("record", {})
- if record.get("extra", {}).get("task_id") == task_id:
- logs.append(
- f"{record['time']['repr']} - {record['level']['name']} - {record['message']} - {(record['extra'])}"
- )
- if len(logs) >= max_lines:
- break
- except json.JSONDecodeError:
- continue
-
- if not logs:
- raise HTTPException(status_code=404, detail=f"No logs found for task '{task_id}'")
-
- return JSONResponse(content={"task_id": task_id, "logs": logs[::-1]}) # restore order
- except HTTPException:
- raise
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to fetch logs: {e!s}")
+ if not LOG_FILE.exists():
+ raise HTTPException(status_code=500, detail="Log file not found.")
+
+ logs = []
+ with open(LOG_FILE, errors="replace") as f:
+ for line in reversed(list(f)):
+ try:
+ record = json.loads(line).get("record", {})
+ if record.get("extra", {}).get("task_id") == task_id:
+ logs.append(
+ f"{record['time']['repr']} - {record['level']['name']} - {record['message']} - {(record['extra'])}"
+ )
+ if len(logs) >= max_lines:
+ break
+ except json.JSONDecodeError:
+ continue
+
+ if not logs:
+ raise HTTPException(status_code=404, detail=f"No logs found for task '{task_id}'")
+
+ return JSONResponse(content={"task_id": task_id, "logs": logs[::-1]}) # restore order
@router.delete(
@@ -525,13 +489,9 @@ async def cancel_task(
task_state_manager=Depends(get_task_state_manager),
task_details=Depends(require_task_owner),
):
- try:
- obj_ref = await task_state_manager.get_object_ref.remote(task_id)
- if obj_ref is None:
- raise HTTPException(404, f"No ObjectRef stored for task {task_id}")
-
- ray.cancel(obj_ref["ref"], recursive=True)
- return {"message": f"Cancellation signal sent for task {task_id}"}
- except Exception as e:
- logger.exception("Failed to cancel task.")
- raise HTTPException(status_code=500, detail=str(e))
+ obj_ref = await task_state_manager.get_object_ref.remote(task_id)
+ if obj_ref is None:
+ raise HTTPException(404, f"No ObjectRef stored for task {task_id}")
+
+ ray.cancel(obj_ref["ref"], recursive=True)
+ return {"message": f"Cancellation signal sent for task {task_id}"}
diff --git a/openrag/routers/openai.py b/openrag/routers/openai.py
index f7e874cd..564f2dad 100644
--- a/openrag/routers/openai.py
+++ b/openrag/routers/openai.py
@@ -1,3 +1,4 @@
+import asyncio
import json
from pathlib import Path
from urllib.parse import quote
@@ -13,6 +14,7 @@
OpenAICompletionRequest,
)
from utils.dependencies import get_vectordb
+from utils.exceptions.base import OpenRAGError
from utils.logger import get_logger
from .utils import (
@@ -31,6 +33,19 @@
ragpipe = RagPipeline(config=config)
+def _make_sse_error(message: str, code: str) -> str:
+    """Return the SSE text for a streaming failure: one error event followed by the ``[DONE]`` event."""
+    error_body = {
+        "message": message,
+        "type": "error",
+        "param": None,
+        "code": code,
+    }
+    payload = json.dumps({"error": error_body})
+    # Both events are returned as one string, so callers yield a single chunk.
+    return f"data: {payload}\n\ndata: [DONE]\n\n"
+
+
@router.get(
"/models",
summary="OpenAI-compatible model listing endpoint",
@@ -163,25 +178,14 @@ async def openai_chat_completion(
truncate(str(request.messages)),
)
- try:
- if is_direct_llm_model(request):
- partitions = None
- else:
- partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"])
- log.debug(f"Using partitions: {partitions}")
- except Exception as e:
- log.warning("Invalid model or partition", error=str(e))
- raise
-
- try:
- llm_output, docs = await ragpipe.chat_completion(partition=partitions, payload=request.model_dump())
- log.debug("RAG chat completion pipeline executed.")
- except Exception as e:
- log.exception("Chat completion failed.", error=str(e))
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Chat completion failed: {e!s}",
- )
+ if is_direct_llm_model(request):
+ partitions = None
+ else:
+ partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"])
+ log.debug(f"Using partitions: {partitions}")
+
+ llm_output, docs = await ragpipe.chat_completion(partition=partitions, payload=request.model_dump())
+ log.debug("RAG chat completion pipeline executed.")
metadata = __prepare_sources(request2, docs)
metadata_json = json.dumps({"sources": metadata})
@@ -204,33 +208,23 @@ async def stream_response():
except json.JSONDecodeError as e:
log.error("Failed to decode streamed chunk.", error=str(e))
raise
+ except asyncio.CancelledError:
+ log.info("Client disconnected during streaming")
+ return
+ except OpenRAGError as e:
+ log.warning("OpenRAG error during streaming", code=e.code, error=e.message)
+ yield _make_sse_error(e.message, e.code)
except Exception as e:
- log.warning("Error while generating streaming answer", error=str(e))
- error_chunk = {
- "error": {
- "message": f"Error while generating answer: {str(e)}",
- "type": "error",
- "param": None,
- "code": "ERROR_ANSWER_GENERATION",
- }
- }
- yield f"data: {json.dumps(error_chunk)}\n\n"
- yield "data: [DONE]\n\n"
+ log.warning("Error during streaming", error=str(e))
+ yield _make_sse_error("An unexpected error occurred during streaming", "UNEXPECTED_ERROR")
return StreamingResponse(stream_response(), media_type="text/event-stream")
else:
- try:
- chunk = await llm_output.__anext__()
- chunk["model"] = model_name
- chunk["extra"] = metadata_json
- log.debug("Returning non-streaming completion chunk.")
- return JSONResponse(content=chunk)
- except Exception as e:
- log.warning("Error while generating answer", error=str(e))
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Error while generating answer: {e!s}",
- )
+ chunk = await llm_output.__anext__()
+ chunk["model"] = model_name
+ chunk["extra"] = metadata_json
+ log.debug("Returning non-streaming completion chunk.")
+ return JSONResponse(content=chunk)
@router.post(
@@ -285,37 +279,18 @@ async def openai_completion(
detail="Streaming is not supported for this endpoint",
)
- try:
- if is_direct_llm_model(request):
- partitions = None
- else:
- partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"])
-
- except Exception as e:
- log.warning(f"Invalid model or partition: {e}")
- raise
-
- try:
- llm_output, docs = await ragpipe.completions(partition=partitions, payload=request.model_dump())
- log.debug("RAG completion pipeline executed.")
- except Exception as e:
- log.exception("Completion request failed.", error=str(e))
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Completion failed: {e!s}",
- )
+ if is_direct_llm_model(request):
+ partitions = None
+ else:
+ partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"])
+
+ llm_output, docs = await ragpipe.completions(partition=partitions, payload=request.model_dump())
+ log.debug("RAG completion pipeline executed.")
metadata = __prepare_sources(request2, docs)
metadata_json = json.dumps({"sources": metadata})
- try:
- complete_response = await llm_output.__anext__()
- complete_response["extra"] = metadata_json
- log.debug("Returning completion response.")
- return JSONResponse(content=complete_response)
- except Exception as e:
- log.warning("No response from LLM.", error=str(e))
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"No response from LLM: {e!s}",
- )
+ complete_response = await llm_output.__anext__()
+ complete_response["extra"] = metadata_json
+ log.debug("Returning completion response.")
+ return JSONResponse(content=complete_response)
diff --git a/openrag/routers/tools.py b/openrag/routers/tools.py
index 9d2c725f..315f394f 100644
--- a/openrag/routers/tools.py
+++ b/openrag/routers/tools.py
@@ -118,15 +118,6 @@ async def execute_tool(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Tool {tool['name']} not found",
)
-
- except HTTPException:
- raise
- except Exception as e:
- logger.exception("Failed during tool execution.", extra={"error": str(e)})
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Tool execution failed due to an internal error.",
- )
finally:
# Cleanup of the temporary file
if file_path is not None:
diff --git a/openrag/routers/utils.py b/openrag/routers/utils.py
index cf94ae6e..2aa6fed5 100644
--- a/openrag/routers/utils.py
+++ b/openrag/routers/utils.py
@@ -4,9 +4,12 @@
from typing import Any
import consts
+import httpx
from config import load_config
from fastapi import Depends, Form, HTTPException, Request, UploadFile, status
+from models.indexer import FileMetadataSchema
from openai import AsyncOpenAI
+from pydantic import ValidationError
from utils.dependencies import get_task_state_manager, get_vectordb
from utils.logger import get_logger
@@ -196,10 +199,18 @@ async def validate_file_id(file_id: str):
async def validate_metadata(metadata: Any | None = Form(None)):
+    """Parse the ``metadata`` form field and validate it against ``FileMetadataSchema``.
+
+    Args:
+        metadata: JSON text from the form; a falsy value is treated as ``"{}"``.
+
+    Returns:
+        The validated metadata as a plain dict (``model_dump()`` output).
+
+    Raises:
+        HTTPException: 400 when the text is not valid JSON, when the JSON value
+            is not an object, or when schema validation fails.
+    """
     try:
         processed_metadata = metadata or "{}"
-        processed_metadata = json.loads(processed_metadata)
-        return processed_metadata
+        parsed = json.loads(processed_metadata)
+
+        # ``**parsed`` needs a mapping: valid JSON such as ``[]``, ``1`` or ``null``
+        # would otherwise raise TypeError and surface as a 500 instead of a 400.
+        if not isinstance(parsed, dict):
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Invalid metadata: expected a JSON object",
+            )
+
+        # Validate against Pydantic schema
+        validated = FileMetadataSchema(**parsed)
+        return validated.model_dump()
+
     except json.JSONDecodeError:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON in metadata")
+    except ValidationError as e:
+        # Format Pydantic validation errors for user-friendly response.
+        # ``loc`` can be empty (e.g. errors not tied to a field), so do not
+        # index it unconditionally.
+        errors = "; ".join(
+            f"{err['loc'][0] if err['loc'] else 'metadata'}: {err['msg']}" for err in e.errors()
+        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid metadata: {errors}")
async def validate_file_format(
@@ -248,13 +259,25 @@ async def check_llm_model_availability(request: Request):
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Only these models ({available_models}) are available for your `{model_type}`. Please check your configuration file.",
)
+ except HTTPException:
+ raise
+ except httpx.TimeoutException:
+ logger.warning("LLM model availability check timed out", model=model_type)
+ raise HTTPException(
+ status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+ detail="LLM service timed out",
+ )
+ except httpx.HTTPError as e:
+ logger.warning("LLM model availability check failed", model=model_type, error=str(e))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+ detail="LLM service is unavailable",
+ )
except Exception as e:
- logger.exception("Failed to validate model", model=model_type, error=str(e))
- if isinstance(e, HTTPException):
- raise
+ logger.exception("Failed to check LLM model availability", model=model_type, error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Error while checking the `{model_type}` endpoint, it seems not available at this moment",
+ detail="Failed to check LLM model availability",
)
diff --git a/openrag/scripts/migrations/alembic/env.py b/openrag/scripts/migrations/alembic/env.py
index 23038aed..182c78e6 100644
--- a/openrag/scripts/migrations/alembic/env.py
+++ b/openrag/scripts/migrations/alembic/env.py
@@ -3,7 +3,7 @@
from alembic import context
from components.indexer.vectordb.utils import Base
from config import load_config
-from sqlalchemy import engine_from_config, pool
+from sqlalchemy import URL, engine_from_config, pool
rag_config = load_config()
@@ -25,10 +25,16 @@
collection_name = rag_config.vectordb.collection_name
-database_url = (
-    f"postgresql://{rdb_user}:{rdb_password}@{rdb_host}:{rdb_port}/partitions_for_collection_{collection_name}"
+database_url = URL.create(
+    drivername="postgresql",
+    username=rdb_user,
+    password=rdb_password,
+    host=rdb_host,
+    port=rdb_port,
+    database=f"partitions_for_collection_{collection_name}",
 )
-config.set_main_option("sqlalchemy.url", database_url)
+# config.set_main_option expects a string. str(URL) masks the password as "***"
+# in SQLAlchemy 2.x, so render the URL explicitly with the real password. The
+# value goes through ConfigParser interpolation, so any literal "%" (e.g. from
+# percent-encoded credentials) must be doubled.
+config.set_main_option(
+    "sqlalchemy.url",
+    database_url.render_as_string(hide_password=False).replace("%", "%%"),
+)
# add your model's MetaData object here
# for 'autogenerate' support
diff --git a/openrag/utils/exceptions/__init__.py b/openrag/utils/exceptions/__init__.py
index 9b5ed21c..0333b6a3 100644
--- a/openrag/utils/exceptions/__init__.py
+++ b/openrag/utils/exceptions/__init__.py
@@ -1 +1,2 @@
from .base import *
+from .common import *
diff --git a/openrag/utils/exceptions/common.py b/openrag/utils/exceptions/common.py
new file mode 100644
index 00000000..eaf1c7ce
--- /dev/null
+++ b/openrag/utils/exceptions/common.py
@@ -0,0 +1,29 @@
+from .base import OpenRAGError
+
+
+class FileStorageError(OpenRAGError):
+    """Signal a failed file I/O operation (save, read, delete).
+
+    Always carries the code ``FILE_STORAGE_ERROR`` and status code 500; any
+    extra keyword arguments are passed through to ``OpenRAGError``.
+    """
+
+    def __init__(self, message: str, **kwargs) -> None:
+        super().__init__(
+            message=message,
+            code="FILE_STORAGE_ERROR",
+            status_code=500,
+            **kwargs,
+        )
+
+
+class RayActorError(OpenRAGError):
+    """Signal a failed Ray actor operation.
+
+    Always carries the code ``RAY_ACTOR_ERROR`` and status code 500; any
+    extra keyword arguments are passed through to ``OpenRAGError``.
+    """
+
+    def __init__(self, message: str, **kwargs) -> None:
+        super().__init__(
+            message=message,
+            code="RAY_ACTOR_ERROR",
+            status_code=500,
+            **kwargs,
+        )
+
+
+class ToolExecutionError(OpenRAGError):
+    """Signal a failed tool execution.
+
+    Always carries the code ``TOOL_EXECUTION_ERROR`` and status code 500; any
+    extra keyword arguments are passed through to ``OpenRAGError``.
+    """
+
+    def __init__(self, message: str, **kwargs) -> None:
+        super().__init__(
+            message=message,
+            code="TOOL_EXECUTION_ERROR",
+            status_code=500,
+            **kwargs,
+        )
+
+
+class UnexpectedError(OpenRAGError):
+    """Signal an unexpected error that fits no more specific category.
+
+    Always carries the code ``UNEXPECTED_ERROR`` and status code 500; any
+    extra keyword arguments are passed through to ``OpenRAGError``.
+    """
+
+    def __init__(self, message: str, **kwargs) -> None:
+        super().__init__(
+            message=message,
+            code="UNEXPECTED_ERROR",
+            status_code=500,
+            **kwargs,
+        )