diff --git a/openrag/api.py b/openrag/api.py index 349069bf..922a33ea 100644 --- a/openrag/api.py +++ b/openrag/api.py @@ -167,6 +167,16 @@ async def openrag_exception_handler(request: Request, exc: OpenRAGError): return JSONResponse(status_code=exc.status_code, content=exc.to_dict()) +@app.exception_handler(Exception) +async def unhandled_exception_handler(request: Request, exc: Exception): + logger = get_logger() + logger.exception("Unhandled exception", error=str(exc)) + return JSONResponse( + status_code=500, + content={"detail": "[UNEXPECTED_ERROR]: An unexpected error occurred", "extra": {}}, + ) + + # Add CORS middleware allow_origins = [ "http://localhost:3042", diff --git a/openrag/app_front.py b/openrag/app_front.py index 5735bbc1..ebc66fc7 100644 --- a/openrag/app_front.py +++ b/openrag/app_front.py @@ -66,7 +66,7 @@ async def on_chat_resume(thread): @cl.password_auth_callback async def auth_callback(username: str, password: str): try: - async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=httpx.Timeout(4 * 60.0))) as client: + async with httpx.AsyncClient(timeout=httpx.Timeout(4 * 60.0)) as client: response = await client.get( url=f"{INTERNAL_BASE_URL}/users/info", headers=get_headers(password), @@ -131,7 +131,7 @@ async def on_chat_start(): api_key = user.metadata.get("api_key", "sk-1234") if user else "sk-1234" logger.debug("New Chat Started", internal_base_url=INTERNAL_BASE_URL) try: - async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=httpx.Timeout(4 * 60.0))) as client: + async with httpx.AsyncClient(timeout=httpx.Timeout(4 * 60.0)) as client: response = await client.get( url=f"{INTERNAL_BASE_URL}/health_check", headers=get_headers(api_key), diff --git a/openrag/components/indexer/chunker/chunker.py b/openrag/components/indexer/chunker/chunker.py index 3a8e5f1d..63431ae4 100644 --- a/openrag/components/indexer/chunker/chunker.py +++ b/openrag/components/indexer/chunker/chunker.py @@ -2,6 +2,7 @@ import openai from components.indexer.utils.text_sanitizer import sanitize_text +from utils.exceptions.base import OpenRAGError from components.prompts import CHUNK_CONTEXTUALIZER_PROMPT from components.utils import detect_language, get_vlm_semaphore, load_config from langchain_core.documents.base import Document @@ -68,18 +69,20 @@ async def _generate_context( ] output = await self.context_generator.ainvoke(messages) return output.content + except openai.APITimeoutError: - logger.warning( - f"OpenAI API timeout contextualizing chunk after {CONTEXTUALIZATION_TIMEOUT}s", - filename=filename, - ) + # VLM timeout - graceful degradation + logger.warning("VLM context generation timeout", timeout=CONTEXTUALIZATION_TIMEOUT) return "" - except Exception as e: - logger.warning( - "Error contextualizing chunk of document", - filename=filename, - error=str(e), - ) + + except openai.APIError as e: + # Other VLM API errors - log but don't fail chunking + logger.error("VLM context generation failed", error=str(e)) + return "" + + except Exception: + # Unexpected errors - log but still gracefully degrade + logger.exception("Unexpected error during context generation") return "" async def contextualize_chunks( @@ -130,8 +133,10 @@ async def contextualize_chunks( for chunk, context in zip(chunks, contexts, strict=True) ] - except Exception as e: - logger.warning(f"Error contextualizing chunks from `{filename}`: {e}") + except OpenRAGError: + raise + except Exception: + logger.exception("Error contextualizing chunks", filename=filename) return chunks diff --git a/openrag/components/indexer/embeddings/openai.py b/openrag/components/indexer/embeddings/openai.py index 9dc031e7..45a829b9 100644 --- a/openrag/components/indexer/embeddings/openai.py +++ b/openrag/components/indexer/embeddings/openai.py @@ -23,8 +23,18 @@ def embedding_dimension(self) -> int: # Test call to get embedding dimension output = self.embed_documents([Document(page_content="test")]) return len(output[0]) + + except openai.APIError as e: + logger.error("Failed to get embedding dimension", error=str(e)) + raise EmbeddingAPIError(f"API error: {e}", model_name=self.embedding_model) + + except (IndexError, AttributeError) as e: + logger.error("Invalid embedding response format", error=str(e)) + raise EmbeddingResponseError("Unexpected response format", error=str(e)) + except Exception: - raise + logger.exception("Unexpected error getting embedding dimension") + raise UnexpectedEmbeddingError("An unexpected error occurred") def embed_documents(self, texts: list[str | Document]) -> list[list[float]]: """ @@ -62,7 +72,7 @@ def embed_documents(self, texts: list[str | Document]) -> list[list[float]]: except Exception as e: logger.exception("Unexpected error while embedding documents", error=str(e)) raise UnexpectedEmbeddingError( - f"Failed to embed documents: {e!s}", + "An unexpected error occurred during document embedding", model_name=self.embedding_model, base_url=self.base_url, error=str(e), diff --git a/openrag/components/indexer/indexer.py b/openrag/components/indexer/indexer.py index 6c3930b7..3e0c16fd 100644 --- a/openrag/components/indexer/indexer.py +++ b/openrag/components/indexer/indexer.py @@ -10,6 +10,8 @@ import torch from config import load_config from langchain_core.documents.base import Document +from utils.exceptions.base import OpenRAGError +from utils.exceptions.common import UnexpectedError from .chunker import BaseChunker, ChunkerFactory from .utils import serialize_file @@ -119,25 +121,34 @@ async def add_file( # Mark task as completed await task_state_manager.set_state.remote(task_id, "COMPLETED") - except Exception as e: - log.exception(f"Task {task_id} failed in add_file") + except OpenRAGError as e: + log.error("Operation failed during file ingestion", code=e.code, error=e.message) tb = "".join(traceback.format_exception(type(e), e, e.__traceback__)) await task_state_manager.set_state.remote(task_id, "FAILED") await task_state_manager.set_error.remote(task_id, tb) raise + except Exception as e: + # Truly unexpected errors + log.exception("Unexpected error during file ingestion", task_id=task_id) + tb = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + await task_state_manager.set_state.remote(task_id, "FAILED") + await task_state_manager.set_error.remote(task_id, tb) + raise UnexpectedError("An unexpected error occurred during file processing") from e + finally: + # GPU cleanup if torch.cuda.is_available(): gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() + + # File cleanup - nest try/except per Phase 2 decision try: - # Cleanup input file if not save_uploaded_files: Path(path).unlink(missing_ok=True) - log.debug(f"Deleted input file: {path}") except Exception as cleanup_err: - log.warning(f"Failed to delete input file {path}: {cleanup_err}") + log.warning("Failed to delete input file", path=path, error=str(cleanup_err)) return True @ray.method(concurrency_group="insert") @@ -157,10 +168,13 @@ async def delete_file(self, file_id: str, partition: str) -> bool: await vectordb.delete_file.remote(file_id, partition) log.info("Deleted file from partition.", file_id=file_id, partition=partition) - except Exception as e: - log.exception("Error in delete_file", error=str(e)) + except OpenRAGError: raise + except Exception as e: + log.exception("Unexpected error in delete_file") + raise UnexpectedError("An unexpected error occurred during file deletion") from e + @ray.method(concurrency_group="update") async def update_file_metadata( self, @@ -184,10 +198,14 @@ async def update_file_metadata( await vectordb.async_add_documents.remote(docs, user=user) log.info("Metadata updated for file.") - except Exception as e: - log.exception("Error in update_file_metadata", error=str(e)) + + except OpenRAGError: raise + except Exception as e: + log.exception("Unexpected error in update_file_metadata") + raise UnexpectedError("An unexpected error occurred during metadata update") from e + @ray.method(concurrency_group="update") async def copy_file( self, @@ -216,10 +234,14 @@ async def copy_file( new_file_id=metadata.get("file_id"), new_partition=metadata.get("partition"), ) - except Exception as e: - log.exception("Error in copy_file", error=str(e)) + + except OpenRAGError: raise + except Exception as e: + log.exception("Unexpected error in copy_file") + raise UnexpectedError("An unexpected error occurred during file copy") from e + @ray.method(concurrency_group="search") async def asearch( self, diff --git a/openrag/components/indexer/loaders/base.py b/openrag/components/indexer/loaders/base.py index 7a9e4e40..7071a19e 100644 --- a/openrag/components/indexer/loaders/base.py +++ b/openrag/components/indexer/loaders/base.py @@ -116,8 +116,13 @@ async def get_image_description( base64.b64decode(image_data) image_url = f"data:image/png;base64,{image_data}" logger.debug("Processing raw base64 string") - except Exception: - logger.error(f"Invalid image data type or format: {type(image_data)}") + except (ValueError, base64.binascii.Error) as e: + # Invalid base64 data + logger.warning("Failed to decode base64 image", error=str(e)[:100]) + return """\n\nInvalid image data format\n\n""" + except Exception as e: + # PIL image opening errors or other unexpected issues + logger.warning("Failed to process image data", error=str(e)[:100]) return """\n\nInvalid image data format\n\n""" else: logger.error(f"Unsupported image data type: {type(image_data)}") diff --git a/openrag/components/indexer/loaders/eml_loader.py b/openrag/components/indexer/loaders/eml_loader.py index c23bcca7..742a0a2f 100644 --- a/openrag/components/indexer/loaders/eml_loader.py +++ b/openrag/components/indexer/loaders/eml_loader.py @@ -1,5 +1,6 @@ import datetime import email +import email.errors import io import os import tempfile @@ -7,7 +8,8 @@ from pathlib import Path from langchain_core.documents.base import Document -from PIL import Image +from PIL import Image, UnidentifiedImageError +from utils.exceptions.common import FileStorageError, UnexpectedError from . import get_loader_classes from .base import BaseLoader @@ -52,7 +54,8 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar if email_data["header"]["date"]: try: email_data["header"]["date"] = parsedate_to_datetime(email_data["header"]["date"]).isoformat() - except Exception: + except (ValueError, TypeError, email.errors.MessageError): + # Invalid date format - keep original string pass # Extract body content and attachments @@ -98,7 +101,8 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar # Use plain text as primary body content if content_type == "text/plain" or not body_content: body_content = text_content - except Exception as e: + except (UnicodeDecodeError, email.errors.MessageError) as e: + # Failed to decode email text part - skip this part print(f"Failed to decode text content: {e}") # Extract body content @@ -140,7 +144,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar metadata={"source": f"attachment:{filename}"}, ) attachments_text += f"Content:\n{attachment_doc.page_content}\n" + except OSError as e: + # File I/O error - skip this attachment + attachments_text += f"Cannot read attachment file: {str(e)[:200]}...\n" except Exception as e: + # Loader-specific errors - try fallback loaders attachments_text += f"Failed to process attachment with loader ({loader_cls.__name__}): {str(e)[:200]}...\n" # Special fallback handling for PDFs with alternative loaders @@ -179,7 +187,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar attachments_text += f"Content (via {fallback_loader_name}):\n{attachment_doc.page_content}\n" fallback_success = True break + except OSError as fallback_e: + # File I/O error - skip fallback + attachments_text += f"Fallback {fallback_loader_name} file error: {str(fallback_e)[:100]}...\n" except Exception as fallback_e: + # Fallback loader failed - try next attachments_text += f"Fallback {fallback_loader_name} also failed: {str(fallback_e)[:100]}...\n" if not fallback_success: @@ -197,7 +209,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar attachments_text += ( "Image attachment present but image captioning disabled\n" ) + except (OSError, UnidentifiedImageError) as img_e: + # Invalid or unreadable image + attachments_text += f"Image fallback failed (invalid image): {str(img_e)[:100]}...\n" except Exception as img_e: + # Unexpected image processing error attachments_text += f"Image fallback also failed: {str(img_e)[:100]}...\n" # Try text fallback for other text-based formats @@ -213,7 +229,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar ) else: attachments_text += "No readable text found in attachment\n" + except UnicodeDecodeError as text_e: + # Text decoding failed + attachments_text += f"Text fallback failed (encoding error): {str(text_e)[:100]}...\n" except Exception as text_e: + # Unexpected text extraction error attachments_text += f"Text fallback failed: {str(text_e)[:100]}...\n" finally: # Clean up temporary file @@ -237,8 +257,9 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar # Generate caption using the base loader's method caption = await self.get_image_description(image_data=image) attachments_text += f"Image Description:\n{caption}\n" - except Exception as e: - attachments_text += f"Failed to generate image caption: {str(e)[:200]}...\n" + except (OSError, UnidentifiedImageError) as e: + # Invalid or corrupted image + attachments_text += f"Failed to generate image caption (invalid image): {str(e)[:200]}...\n" # Try to show basic image info if available try: size_info = f"Image size: {len(attachment['raw'])} bytes" @@ -247,6 +268,9 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar ) except Exception: attachments_text += "Image attachment present but corrupted or unreadable\n" + except Exception as e: + # Unexpected image captioning error (VLM errors handled in base.py) + attachments_text += f"Failed to generate image caption: {str(e)[:200]}...\n" elif content_type.startswith("text/"): # For text attachments, decode directly @@ -255,7 +279,11 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar else: # For other binary content, just show metadata attachments_text += f"Binary content (size: {len(attachment['raw'])} bytes)\n" + except OSError as e: + # File I/O error creating temp file + attachments_text += f"Cannot create temp file for attachment: {e}\n" except Exception as e: + # Unexpected error processing attachment attachments_text += f"Content could not be processed: {e}\n" attachments_text += "---\n" @@ -298,8 +326,15 @@ async def aload_document(self, file_path, metadata: dict | None = None, save_mar with open(markdown_path, "w", encoding="utf-8") as md_file: md_file.write(content_body) metadata["markdown_path"] = str(markdown_path) + except OSError as e: + # File I/O error reading email file + raise FileStorageError(f"Cannot read email file: {e}") from e + except email.errors.MessageError as e: + # Email parsing error + raise UnexpectedError(f"Invalid email format: {e}") from e except Exception as e: - raise ValueError(f"Failed to parse the EML file {file_path}: {e}") + # Unexpected error + raise UnexpectedError(f"Failed to parse the EML file {file_path}: {e}") from e document = Document(page_content=content_body, metadata=metadata) return document diff --git a/openrag/components/indexer/loaders/image.py b/openrag/components/indexer/loaders/image.py index a1356c80..0bf7a3fe 100644 --- a/openrag/components/indexer/loaders/image.py +++ b/openrag/components/indexer/loaders/image.py @@ -3,7 +3,7 @@ import cairosvg from langchain_core.documents import Document -from PIL import Image +from PIL import Image, UnidentifiedImageError from utils.logger import get_logger from .base import BaseLoader @@ -29,7 +29,16 @@ async def aload_document(self, file_path, metadata=None, save_markdown=False): img = Image.open(BytesIO(png_data)) else: img = Image.open(path) + except OSError as e: + # File not found, permission denied, etc. + log.error("Cannot read image file", file_path=str(path), error=str(e)) + raise ImageLoadError(f"Cannot read image file: {e}") from e + except UnidentifiedImageError as e: + # Invalid image format + log.error("Invalid image format", file_path=str(path), error=str(e)) + raise ImageLoadError(f"Invalid image format: {e}") from e except Exception as e: + # SVG conversion errors or other unexpected issues log.error( "Failed to load image file", file_path=str(path), diff --git a/openrag/components/indexer/loaders/media_loader.py b/openrag/components/indexer/loaders/media_loader.py index 41c9f635..e90b4871 100644 --- a/openrag/components/indexer/loaders/media_loader.py +++ b/openrag/components/indexer/loaders/media_loader.py @@ -7,7 +7,7 @@ import numpy as np from components.utils import get_audio_semaphore from langchain_core.documents.base import Document -from openai import AsyncOpenAI +from openai import APIError, AsyncOpenAI from pydub import AudioSegment, silence from tqdm.asyncio import tqdm from utils.logger import get_logger @@ -58,8 +58,13 @@ async def _process_chunk(self, index: int, segment: AudioSegment, wav_path: Path try: result = await self._transcribe_chunk(tmp_path, language) return result + except APIError as e: + # OpenAI API errors - gracefully degrade by returning empty transcript + logger.warning("Audio transcription API error", chunk=tmp_path.name, error=str(e)[:200]) + return "" except Exception as e: - logger.exception(f"Error transcribing chunk {tmp_path.name}", error=str(e)) + # Unexpected errors - log and return empty transcript for graceful degradation + logger.exception("Error transcribing chunk", chunk=tmp_path.name, error=str(e)) return "" finally: tmp_path.unlink(missing_ok=True) @@ -77,8 +82,13 @@ async def _transcribe_chunk(self, wav_path: Path, language: str = None) -> str: kwargs["language"] = language result = await self.client.audio.transcriptions.create(**kwargs) return result.text.strip() + except APIError as e: + # OpenAI API errors - gracefully degrade by returning empty transcript + logger.warning("Audio transcription API error", chunk=wav_path.name, error=str(e)[:200]) + return "" except Exception as e: - logger.exception(f"Error transcribing chunk {wav_path.name}", error=str(e)) + # Unexpected errors - log and return empty transcript for graceful degradation + logger.exception("Error transcribing chunk", chunk=wav_path.name, error=str(e)) return "" async def _get_audio_chunks(self, sound: AudioSegment) -> list[AudioSegment]: @@ -122,7 +132,12 @@ async def _detect_language(self, sound: AudioSegment, wav_path, fallback_languag return fallback_language # Fallback to English try: return langdetect.detect(text) + except langdetect.LangDetectException as e: + # Expected failure for non-textual content or too-short text + logger.warning("Language detection failed", error=str(e)) + return fallback_language except Exception as e: + # Unexpected language detection errors logger.exception("Language detection failed", error=str(e)) return fallback_language finally: diff --git a/openrag/components/indexer/loaders/pdf_loaders/marker.py b/openrag/components/indexer/loaders/pdf_loaders/marker.py index 7e5561aa..0e095f77 100644 --- a/openrag/components/indexer/loaders/pdf_loaders/marker.py +++ b/openrag/components/indexer/loaders/pdf_loaders/marker.py @@ -9,6 +9,7 @@ from config import load_config from langchain_core.documents.base import Document from marker.converters.pdf import PdfConverter +from utils.exceptions.common import FileStorageError, UnexpectedError from utils.logger import get_logger from ..base import BaseLoader @@ -102,9 +103,16 @@ def _process_pdf(file_path, config): ) render = converter(file_path) return render + except asyncio.CancelledError: + # Cancellation request - propagate immediately + logger.info("PDF processing cancelled", path=file_path) + raise + except OSError as e: + logger.error("Cannot read PDF file", path=file_path, error=str(e)) + raise FileStorageError(f"Cannot read PDF file: {e}") from e except Exception as e: logger.exception("Error processing PDF", path=file_path, error=str(e)) - raise + raise UnexpectedError("Failed to process PDF document") from e finally: gc.collect() if torch.cuda.is_available(): @@ -124,6 +132,10 @@ def run_with_timeout(): return result except MPTimeoutError: self.logger.exception("MarkerWorker child process timed out", path=file_path) + raise UnexpectedError("PDF processing timed out") + except asyncio.CancelledError: + # Cancellation - propagate + self.logger.info("PDF processing cancelled", path=file_path) raise except Exception: self.logger.exception("Error processing with MarkerWorker", path=file_path) @@ -229,7 +241,7 @@ async def aload_document( ) if not markdown: - raise RuntimeError(f"Conversion failed for {file_path_str}") + raise UnexpectedError(f"Conversion failed for {file_path_str}") if self.image_captioning: keys = list(images.keys()) @@ -253,6 +265,14 @@ async def aload_document( logger.info(f"Processed {file_path_str} in {duration:.2f}s") return doc + except asyncio.CancelledError: + # Cancellation - propagate immediately (MUST be first) + logger.info("PDF loading cancelled", path=file_path_str) + raise + except OSError as e: + logger.error("Cannot read PDF file", path=file_path_str, error=str(e)) + raise FileStorageError(f"Cannot read PDF file: {e}") from e except Exception: + # Ray actor errors or PDF processing failures logger.exception("Error in aload_document", path=file_path_str) raise diff --git a/openrag/components/indexer/loaders/pptx_loader.py b/openrag/components/indexer/loaders/pptx_loader.py index 69fa45be..ea46c18f 100644 --- a/openrag/components/indexer/loaders/pptx_loader.py +++ b/openrag/components/indexer/loaders/pptx_loader.py @@ -125,11 +125,20 @@ def _convert_chart_to_markdown(self, chart): separator = "|" + "|".join(["---"] * len(data[0])) + "|" return md + "\n".join([header, separator] + markdown_table[1:]) except ValueError as e: - # Handle the specific error for unsupported chart types + # Handle unsupported chart types (expected error) if "unsupported plot type" in str(e): + logger.debug("Unsupported chart type encountered") return "\n\n[unsupported chart]\n\n" - except Exception: - # Catch any other exceptions that might occur + # Other ValueError - log and return placeholder + logger.warning("Chart conversion value error", error=str(e)) + return "\n\n[unsupported chart]\n\n" + except (AttributeError, IndexError) as e: + # Missing chart data or unexpected structure + logger.warning("Chart structure error", error=str(e)) + return "\n\n[unsupported chart]\n\n" + except Exception as e: + # Unexpected errors - log and gracefully degrade + logger.warning("Chart conversion failed", error=str(e)) return "\n\n[unsupported chart]\n\n" diff --git a/openrag/components/indexer/loaders/serializer.py b/openrag/components/indexer/loaders/serializer.py index 9f687a86..e620f5c6 100644 --- a/openrag/components/indexer/loaders/serializer.py +++ b/openrag/components/indexer/loaders/serializer.py @@ -5,6 +5,7 @@ import torch from config import load_config from langchain_core.documents.base import Document +from utils.exceptions.common import FileStorageError, UnexpectedError from . import get_loader_classes @@ -84,6 +85,9 @@ async def serialize_document( torch.cuda.ipc_collect() log.info("Document serialized successfully") return doc + except OSError as e: + log.error("File operation failed during serialization", path=str(path), error=str(e)) + raise FileStorageError(f"Cannot read file: {e}") from e except Exception as e: - log.exception("Failed to serialize document", error=str(e)) - raise + log.exception("Failed to serialize document", path=str(path), file_type=file_ext, error=str(e)) + raise UnexpectedError("Failed to serialize document: unsupported format or corrupted file") from e diff --git a/openrag/components/indexer/vectordb/utils.py b/openrag/components/indexer/vectordb/utils.py index a7717b39..9137ff78 100644 --- a/openrag/components/indexer/vectordb/utils.py +++ b/openrag/components/indexer/vectordb/utils.py @@ -97,6 +97,24 @@ class User(Base): memberships = relationship("PartitionMembership", back_populates="user", cascade="all, delete-orphan") +class FileDomain(Base): + __tablename__ = "file_domains" + + id = Column(Integer, primary_key=True) + file_id = Column(String, nullable=False) + partition_name = Column( + String, + ForeignKey("partitions.partition", ondelete="CASCADE"), + nullable=False, + ) + domain = Column(String, nullable=False) + + __table_args__ = ( + UniqueConstraint("file_id", "partition_name", "domain", name="uix_file_domain"), + Index("ix_partition_domain", "partition_name", "domain"), + ) + + class PartitionMembership(Base): __tablename__ = "partition_memberships" @@ -134,10 +152,10 @@ def __init__(self, database_url: str, logger=logger): AUTH_TOKEN = os.getenv("AUTH_TOKEN") self._ensure_admin_user(AUTH_TOKEN) - except Exception as e: + except Exception: raise VDBConnectionError( - f"Failed to connect to database: {e!s}", - db_url=database_url, + "An unexpected database error occurred", + db_url=str(database_url), db_type="SQLAlchemy", ) @@ -225,10 +243,14 @@ def add_file_to_partition( session.commit() log.info("Added file successfully") return True - except Exception: + except Exception as e: session.rollback() - log.exception("Error adding file to partition") - raise + log.exception("Error adding file to partition", error=str(e)) + raise VDBInsertError( + "An unexpected database error occurred", + file_id=file_id, + partition=partition, + ) def remove_file_from_partition(self, file_id: str, partition: str): """Remove a file from its partition - Optimized without join""" @@ -246,8 +268,12 @@ def remove_file_from_partition(self, file_id: str, partition: str): return False except Exception as e: session.rollback() - log.error(f"Error removing file: {e}") - raise e + log.exception("Error removing file", error=str(e)) + raise VDBDeleteError( + "An unexpected database error occurred", + file_id=file_id, + partition=partition, + ) def delete_partition(self, partition: str): """Delete a partition and all its files""" @@ -293,6 +319,58 @@ def file_exists_in_partition(self, file_id: str, partition: str): session.query(File).filter(File.file_id == file_id, File.partition_name == partition).exists() ).scalar() + # Domains + + def get_file_ids_by_domains(self, partition: str | None, domains: list[str]) -> list[str]: + """Get file_ids matching ANY of the given domains in a partition (or all partitions if None).""" + with self.Session() as session: + query = session.query(FileDomain.file_id).filter(FileDomain.domain.in_(domains)) + if partition is not None: + query = query.filter(FileDomain.partition_name == partition) + rows = query.distinct().all() + return [row[0] for row in rows] + + def set_file_domains(self, file_id: str, partition: str, domains: list[str]): + """Replace all domains for a file with the given list.""" + with self.Session() as session: + try: + session.query(FileDomain).filter( + FileDomain.file_id == file_id, + FileDomain.partition_name == partition, + ).delete() + for domain in domains: + session.add(FileDomain(file_id=file_id, partition_name=partition, domain=domain)) + session.commit() + except Exception as e: + session.rollback() + self.logger.exception("Error setting file domains", error=str(e)) + raise VDBInsertError( + "An unexpected database error occurred", + file_id=file_id, + partition=partition, + ) + + def get_file_domains(self, file_id: str, partition: str) -> list[str]: + """Get domains for a file in a partition.""" + with self.Session() as session: + rows = ( + session.query(FileDomain.domain) + .filter(FileDomain.file_id == file_id, FileDomain.partition_name == partition) + .all() + ) + return [row[0] for row in rows] + + def list_partition_domains(self, partition: str) -> list[str]: + """List all unique domains in a partition.""" + with self.Session() as session: + rows = ( + session.query(FileDomain.domain) + .filter(FileDomain.partition_name == partition) + .distinct() + .all() + ) + return [row[0] for row in rows] + # Users def create_user( diff --git a/openrag/components/indexer/vectordb/vectordb.py b/openrag/components/indexer/vectordb/vectordb.py index 291f8ac0..666e6931 100644 --- a/openrag/components/indexer/vectordb/vectordb.py +++ b/openrag/components/indexer/vectordb/vectordb.py @@ -16,6 +16,7 @@ MilvusException, RRFRanker, ) +from sqlalchemy import URL from utils.exceptions.base import EmbeddingError from utils.exceptions.vectordb import * from utils.logger import get_logger @@ -177,7 +178,7 @@ def __init__(self): except Exception as e: self.logger.exception("Unexpected error initializing Milvus clients", error=str(e)) raise VDBConnectionError( - f"Unexpected error initializing Milvus clients: {e!s}", + "An unexpected database error occurred", db_url=uri, db_type="Milvus", ) @@ -225,8 +226,16 @@ def load_collection(self): operation="load_collection", ) + database_url = URL.create( + drivername="postgresql", + username=self.rdb_user, + password=self.rdb_password, + host=self.rdb_host, + port=self.rdb_port, + database=f"partitions_for_collection_{self.collection_name}", + ) self.partition_file_manager = PartitionFileManager( - database_url=f"postgresql://{self.rdb_user}:{self.rdb_password}@{self.rdb_host}:{self.rdb_port}/partitions_for_collection_{self.collection_name}", + database_url=database_url, logger=self.logger, ) self.logger.info("Milvus collection loaded.") @@ -238,7 +247,7 @@ def load_collection(self): error=str(e), ) raise UnexpectedVDBError( - f"Unexpected error setting collection name `{self.collection_name}`: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, ) @@ -361,6 +370,9 @@ async def async_add_documents(self, chunks: list[Document], user: dict) -> None: file_id=file_id, ) + # Extract domains before inserting into Milvus (not stored as chunk metadata) + domains = file_metadata.pop("domains", None) + entities = [] vectors = await self.embedder.aembed_documents(chunks) order_metadata_l: list[dict] = _gen_chunk_order_metadata(n=len(chunks)) @@ -386,6 +398,10 @@ async def async_add_documents(self, chunks: list[Document], user: dict) -> None: file_metadata=file_metadata, user_id=user.get("id"), ) + + if domains: + self.partition_file_manager.set_file_domains(file_id, partition, domains) + self.logger.info(f"File '{file_id}' added to partition '{partition}'") except EmbeddingError as e: self.logger.exception("Embedding failed", error=str(e)) @@ -394,10 +410,16 @@ async def async_add_documents(self, chunks: list[Document], user: dict) -> None: self.logger.exception("VectorDB operation failed", error=str(e)) raise + except MilvusException as e: + self.logger.exception("Milvus insert operation failed", error=str(e)) + raise VDBInsertError( + "Failed to insert document into collection", + collection_name=self.collection_name, + ) except Exception as e: self.logger.exception("Unexpected error while adding a document", error=str(e)) raise UnexpectedVDBError( - f"Unexpected error while adding a document: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, ) @@ -431,6 +453,8 @@ async def async_multi_query_search( retrieved_chunks[document.metadata["_id"]] = document return list(retrieved_chunks.values()) + DOMAIN_BATCH_SIZE = 1000 + async def async_search( self, query: str, @@ -439,6 +463,51 @@ async def async_search( partition: list[str] = None, filter: dict | None = None, with_surrounding_chunks: bool = False, + ) -> list[Document]: + filter = dict(filter) if filter else {} + + # Resolve domains to file_ids via PostgreSQL + if "domains" in filter: + domains = filter.pop("domains") + partition_for_lookup = partition[0] if partition != ["all"] else None + file_ids = self.partition_file_manager.get_file_ids_by_domains(partition_for_lookup, domains) + if not file_ids: + return [] + + if len(file_ids) <= self.DOMAIN_BATCH_SIZE: + filter["file_id"] = file_ids + else: + # Split into batches, run parallel searches, merge results + batches = [ + file_ids[i : i + self.DOMAIN_BATCH_SIZE] for i in range(0, len(file_ids), self.DOMAIN_BATCH_SIZE) + ] + tasks = [] + for batch in batches: + batch_filter = {**filter, "file_id": batch} + tasks.append( + self._search_with_filter( + query, top_k, similarity_threshold, partition, batch_filter, with_surrounding_chunks + ) + ) + batch_results = await asyncio.gather(*tasks) + merged = {} + for results in batch_results: + for doc in results: + merged[doc.metadata["_id"]] = doc + return sorted(merged.values(), key=lambda d: d.metadata.get("_id", 0), reverse=True)[:top_k] + + return await self._search_with_filter( + query, top_k, similarity_threshold, partition, filter, with_surrounding_chunks + ) + + async def _search_with_filter( + self, + query: str, + top_k: int, + similarity_threshold: float, + partition: list[str], + filter: dict, + with_surrounding_chunks: bool, ) -> list[Document]: expr_parts = [] if partition != ["all"]: @@ -446,7 +515,11 @@ async def async_search( if filter: for key, value in filter.items(): - expr_parts.append(f"{key} == '{value}'") + if isinstance(value, list): + formatted = ", ".join(f"'{v}'" for v in value) + expr_parts.append(f"{key} in [{formatted}]") + else: + expr_parts.append(f"{key} == '{value}'") # Join all parts with " and " only if there are multiple conditions expr = " and ".join(expr_parts) if expr_parts else "" @@ -520,7 +593,7 @@ async def async_search( except Exception as e: self.logger.exception("Unexpected error occurred", error=str(e)) raise UnexpectedVDBError( - f"Unexpected error occurred: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, partition=partition, ) @@ -594,7 +667,7 @@ async def delete_file(self, file_id: str, partition: str): except Exception as e: log.exception("Unexpected error while deleting file chunks", error=str(e)) raise UnexpectedVDBError( - f"Unexpected error while deleting file chunks {file_id}: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, partition=partition, file_id=file_id, @@ -651,8 +724,8 @@ async def get_file_chunks(self, file_id: str, partition: str, include_id: bool = except Exception as e: log.exception("Unexpected error while getting file chunks", error=str(e)) - raise VDBSearchError( - f"Unexpected error while getting file chunks {file_id}: {e!s}", + raise UnexpectedVDBError( + "An unexpected database error occurred", collection_name=self.collection_name, partition=partition, file_id=file_id, @@ -696,7 +769,7 @@ async def get_chunk_by_id(self, chunk_id: str): except Exception as e: log.exception("Unexpected error while retrieving chunk", error=str(e)) raise UnexpectedVDBError( - f"Unexpected error while retrieving chunk {chunk_id}: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, ) @@ -706,6 +779,8 @@ def file_exists(self, file_id: str, partition: str): """ try: return self.partition_file_manager.file_exists_in_partition(file_id=file_id, partition=partition) + except VDBError: + raise except Exception as e: self.logger.exception( "File existence check failed.", @@ -713,7 +788,12 @@ def file_exists(self, file_id: str, partition: str): partition=partition, error=str(e), ) - return False + raise UnexpectedVDBError( + "An unexpected database error occurred", + collection_name=self.collection_name, + partition=partition, + file_id=file_id, + ) def list_partition_files(self, partition: str, limit: int | None = None): try: @@ -729,7 +809,7 @@ def list_partition_files(self, partition: str, limit: int | None = None): error=str(e), ) raise UnexpectedVDBError( - f"Unexpected error while listing files in partition {partition}: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, partition=partition, ) @@ -737,9 +817,14 @@ def list_partition_files(self, partition: str, limit: int | None = None): def list_partitions(self): try: return self.partition_file_manager.list_partitions() + except VDBError: + raise except Exception as e: self.logger.exception("Failed to list partitions", error=str(e)) - raise + raise UnexpectedVDBError( + "An unexpected database error occurred", + collection_name=self.collection_name, + ) def collection_exists(self, collection_name: str): """ @@ -773,7 +858,7 @@ async def delete_partition(self, partition: str): except Exception as e: log.exception("Unexpected error while deleting partition", error=str(e)) raise UnexpectedVDBError( - f"Unexpected error while deleting partition {partition}: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, partition=partition, ) @@ -785,9 +870,15 @@ def partition_exists(self, partition: str): log = self.logger.bind(partition=partition) try: return self.partition_file_manager.partition_exists(partition=partition) + except VDBError: + raise except Exception as e: log.exception("Partition existence check failed.", error=str(e)) - return False + raise UnexpectedVDBError( + "An unexpected database error occurred", + collection_name=self.collection_name, + partition=partition, + ) async def list_all_chunk(self, partition: str, include_embedding: bool = True): """ @@ -855,7 +946,7 @@ def prepare_metadata(res: dict): error=str(e), ) raise UnexpectedVDBError( - f"Unexpected error while listing all chunks in partition {partition}: {e!s}", + "An unexpected database error occurred", collection_name=self.collection_name, partition=partition, ) @@ -920,6 +1011,20 @@ async def remove_partition_member(self, partition: str, user_id: int) -> bool: self.partition_file_manager.remove_partition_member(partition, user_id) self.logger.info(f"User_id {user_id} removed from partition '{partition}'.") + # Domain management (pure PostgreSQL, no Milvus interaction) + + async def set_file_domains(self, file_id: str, partition: str, domains: list[str]): + self._check_file_exists(file_id, partition) + self.partition_file_manager.set_file_domains(file_id, partition, domains) + + async def get_file_domains(self, file_id: str, partition: str) -> list[str]: + self._check_file_exists(file_id, partition) + return self.partition_file_manager.get_file_domains(file_id, partition) + + async def list_partition_domains(self, partition: str) -> list[str]: + self._check_partition_exists(partition) + return self.partition_file_manager.list_partition_domains(partition) + def _check_user_exists(self, user_id: int): if not self.partition_file_manager.user_exists(user_id): self.logger.warning(f"User with ID {user_id} does not exist.") diff --git a/openrag/components/llm.py b/openrag/components/llm.py index 634663c4..0d06b549 100644 --- a/openrag/components/llm.py +++ b/openrag/components/llm.py @@ -1,7 +1,9 @@ +import asyncio import copy import json import httpx +from utils.exceptions.common import UnexpectedError from utils.logger import get_logger logger = get_logger() @@ -37,10 +39,10 @@ async def completions(self, request: dict): data = response.json() yield data except httpx.HTTPStatusError as e: - error_detail = e.response.text - raise ValueError(f"LLM API error ({e.response.status_code}): {error_detail}") + logger.error("LLM API returned error", status_code=e.response.status_code) + raise UnexpectedError(f"LLM API error ({e.response.status_code})") from e except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in API response: {str(e)}") + raise UnexpectedError("Invalid JSON in LLM API response") from e async def chat_completion(self, request: dict): request.pop("model") @@ -58,17 +60,28 @@ async def chat_completion(self, request: dict): headers=self.headers, json=payload, ) as response: - if response.status_code >= 400: - await response.aread() - error_detail = response.text - raise ValueError(f"LLM API error ({response.status_code}): {error_detail}") + response.raise_for_status() async for line in response.aiter_lines(): yield line - except ValueError: + + except asyncio.CancelledError: + # MUST be first per Phase 2 decision + logger.info("LLM streaming cancelled by client") raise + + except httpx.HTTPStatusError as e: + # 4xx/5xx responses + logger.error("LLM API returned error", status_code=e.response.status_code) + raise UnexpectedError(f"LLM API error ({e.response.status_code})") from e + + except httpx.RequestError as e: + # Network/connection failures + logger.error("Network error during LLM streaming", error=str(e)) + raise UnexpectedError("Network error during LLM streaming") from e + except Exception as e: - logger.error(f"Error while streaming chat completion: {str(e)}") - raise + logger.exception("Unexpected error during LLM streaming") + raise UnexpectedError("An unexpected error occurred during streaming") from e else: # Handle non-streaming response try: @@ -81,7 +94,7 @@ async def chat_completion(self, request: dict): data = response.json() yield data except httpx.HTTPStatusError as e: - error_detail = e.response.text - raise ValueError(f"LLM API error ({e.response.status_code}): {error_detail}") + logger.error("LLM API returned error", status_code=e.response.status_code) + raise UnexpectedError(f"LLM API error ({e.response.status_code})") from e except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in API response: {str(e)}") + raise UnexpectedError("Invalid JSON in LLM API response") from e diff --git a/openrag/components/map_reduce.py b/openrag/components/map_reduce.py index e9f89f34..9572a7c4 100644 --- a/openrag/components/map_reduce.py +++ b/openrag/components/map_reduce.py @@ -80,7 +80,8 @@ async def infer_chunk_relevancy(self, query, chunk: Document) -> SummarizedChunk ) return output_chunk except Exception as e: - logger.error("Error during chunk relevancy inference", error=str(e)) + # Graceful degradation - mark chunk as irrelevant on error + logger.warning("Failed to infer chunk relevancy", chunk_id=chunk.metadata.get("id"), error=str(e)[:200]) return SummarizedChunk(relevancy=False, summary="") async def map_batch( diff --git a/openrag/components/pipeline.py b/openrag/components/pipeline.py index 3c20a3fc..36af5ae9 100644 --- a/openrag/components/pipeline.py +++ b/openrag/components/pipeline.py @@ -44,8 +44,10 @@ def __init__(self, config) -> None: if self.reranker_enabled: self.reranker = Reranker(logger, config) - async def retrieve_docs(self, partition: list[str], query: str, use_map_reduce: bool = False) -> list[Document]: - docs = await self.retriever.retrieve(partition=partition, query=query) + async def retrieve_docs( + self, partition: list[str], query: str, use_map_reduce: bool = False, filter: dict | None = None + ) -> list[Document]: + docs = await self.retriever.retrieve(partition=partition, query=query, filter=filter) top_k = max(self.map_reduce_max_docs, self.reranker_top_k) if use_map_reduce else self.reranker_top_k logger.debug("Documents retreived", document_count=len(docs)) if docs: @@ -117,20 +119,25 @@ async def _prepare_for_chat_completion(self, partition: list[str], payload: dict query = await self.generate_query(messages) logger.debug("Prepared query for chat completion", query=query) - metadata = payload.get("metadata", {}) + metadata = payload.get("metadata", {}) or {} use_map_reduce = metadata.get("use_map_reduce", False) spoken_style_answer = metadata.get("spoken_style_answer", False) + domains = metadata.get("domains") logger.debug( "Metadata parameters", use_map_reduce=use_map_reduce, spoken_style_answer=spoken_style_answer, + domains=domains, ) + # Build filter from metadata + search_filter = {"domains": domains} if domains else None + # 2. get docs docs = await self.retriever_pipeline.retrieve_docs( - partition=partition, query=query, use_map_reduce=use_map_reduce + partition=partition, query=query, use_map_reduce=use_map_reduce, filter=search_filter ) if use_map_reduce and docs: @@ -157,11 +164,14 @@ async def _prepare_for_chat_completion(self, partition: list[str], payload: dict async def _prepare_for_completions(self, partition: list[str], payload: dict): prompt = payload["prompt"] + metadata = payload.get("metadata", {}) or {} + domains = metadata.get("domains") + search_filter = {"domains": domains} if domains else None # 1. get the query query = await self.generate_query(messages=[{"role": "user", "content": prompt}]) # 2. get docs - docs = await self.retriever_pipeline.retrieve_docs(partition=partition, query=query) + docs = await self.retriever_pipeline.retrieve_docs(partition=partition, query=query, filter=search_filter) # 3. Format the retrieved docs context = format_context(docs, max_context_tokens=self.max_context_tokens) @@ -178,25 +188,17 @@ async def _prepare_for_completions(self, partition: list[str], payload: dict): return payload, docs async def completions(self, partition: list[str], payload: dict): - try: - if partition is None: - docs = [] - else: - payload, docs = await self._prepare_for_completions(partition=partition, payload=payload) - llm_output = self.llm_client.completions(request=payload) - return llm_output, docs - except Exception as e: - logger.error(f"Error during chat completion: {e!s}") - raise e + if partition is None: + docs = [] + else: + payload, docs = await self._prepare_for_completions(partition=partition, payload=payload) + llm_output = self.llm_client.completions(request=payload) + return llm_output, docs async def chat_completion(self, partition: list[str] | None, payload: dict): - try: - if partition is None: - docs = [] - else: - payload, docs = await self._prepare_for_chat_completion(partition=partition, payload=payload) - llm_output = self.llm_client.chat_completion(request=payload) - return llm_output, docs - except Exception as e: - logger.error(f"Error during chat completion: {e!s}") - raise e + if partition is None: + docs = [] + else: + payload, docs = await self._prepare_for_chat_completion(partition=partition, payload=payload) + llm_output = self.llm_client.chat_completion(request=payload) + return llm_output, docs diff --git a/openrag/components/ray_utils.py b/openrag/components/ray_utils.py index 9d1c31a5..06bba133 100644 --- a/openrag/components/ray_utils.py +++ b/openrag/components/ray_utils.py @@ -3,6 +3,7 @@ import ray from ray.exceptions import RayTaskError, TaskCancelledError +from utils.exceptions.common import RayActorError from utils.logger import get_logger logger = get_logger() @@ -33,7 +34,7 @@ async def call_ray_actor_with_timeout( TimeoutError: If the task exceeds the timeout asyncio.CancelledError: If the calling coroutine is cancelled TaskCancelledError: If the Ray task was cancelled - RuntimeError: If the Ray task failed with an error + RayActorError: If the Ray task failed with an error """ try: result = await asyncio.wait_for(asyncio.gather(future), timeout=timeout) @@ -54,4 +55,4 @@ async def call_ray_actor_with_timeout( raise except RayTaskError as e: - raise RuntimeError(f"{task_description} failed") from e + raise RayActorError(f"{task_description} failed") from e diff --git a/openrag/components/reranker.py b/openrag/components/reranker.py index 3768720f..3e19c315 100644 --- a/openrag/components/reranker.py +++ b/openrag/components/reranker.py @@ -4,6 +4,7 @@ from infinity_client.api.default import rerank from infinity_client.models import RerankInput, ReRankResult from langchain_core.documents.base import Document +from utils.exceptions.common import UnexpectedError class Reranker: @@ -38,10 +39,5 @@ async def rerank(self, query: str, documents: list[Document], top_k: int) -> lis return output except Exception as e: - self.logger.error( - "Reranking failed", - error=str(e), - model_name=self.model_name, - documents_count=len(documents), - ) - raise e + self.logger.exception("Reranking failed", query=query[:100], doc_count=len(documents)) + raise UnexpectedError("An unexpected error occurred during reranking") from e diff --git a/openrag/components/retriever.py b/openrag/components/retriever.py index 3c5caf32..8bf2c140 100644 --- a/openrag/components/retriever.py +++ b/openrag/components/retriever.py @@ -27,7 +27,7 @@ def __init__( pass @abstractmethod - async def retrieve(self, partition: list[str], query: str) -> list[Document]: + async def retrieve(self, partition: list[str], query: str, filter: dict | None = None) -> list[Document]: pass @@ -43,6 +43,7 @@ async def retrieve( self, partition: list[str], query: str, + filter: dict | None = None, ) -> list[Document]: db = get_vectordb() chunks = await db.async_search.remote( @@ -50,6 +51,7 @@ async def retrieve( partition=partition, top_k=self.top_k, similarity_threshold=self.similarity_threshold, + filter=filter, with_surrounding_chunks=self.with_surrounding_chunks, ) return chunks @@ -78,7 +80,7 @@ def __init__( prompt: ChatPromptTemplate = ChatPromptTemplate.from_template(MULTI_QUERY_PROMPT) self.generate_queries = prompt | llm | StrOutputParser() | (lambda x: x.split("[SEP]")) - async def retrieve(self, partition: list[str], query: str) -> list[Document]: + async def retrieve(self, partition: list[str], query: str, filter: dict | None = None) -> list[Document]: db = get_vectordb() logger.debug("Generating multiple queries", k_queries=self.k_queries) generated_queries = await self.generate_queries.ainvoke( @@ -92,6 +94,7 @@ async def retrieve(self, partition: list[str], query: str) -> list[Document]: partition=partition, top_k_per_query=self.top_k, similarity_threshold=self.similarity_threshold, + filter=filter, with_surrounding_chunks=self.with_surrounding_chunks, ) return chunks @@ -121,7 +124,7 @@ async def get_hyde(self, query: str): hyde_document = await self.hyde_generator.ainvoke({"query": query}) return hyde_document - async def retrieve(self, partition: list[str], query: str) -> list[Document]: + async def retrieve(self, partition: list[str], query: str, filter: dict | None = None) -> list[Document]: db = get_vectordb() hyde = await self.get_hyde(query) queries = [hyde] @@ -133,6 +136,7 @@ async def retrieve(self, partition: list[str], query: str) -> list[Document]: partition=partition, top_k_per_query=self.top_k, similarity_threshold=self.similarity_threshold, + filter=filter, with_surrounding_chunks=self.with_surrounding_chunks, ) diff --git a/openrag/models/indexer.py b/openrag/models/indexer.py index 61a14227..8dfccf81 100644 --- a/openrag/models/indexer.py +++ b/openrag/models/indexer.py @@ -1,6 +1,35 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator class SearchRequest(BaseModel): query: str top_k: int | None = 5 # default to 5 if not provided + + +class FileMetadataSchema(BaseModel): + """Schema for validating file upload metadata. + + Metadata is passed as JSON in the file upload form and contains + optional file processing hints and domain filtering configuration. + """ + + mimetype: str | None = None + domains: list[str] = Field(default_factory=list) + + # Allow additional fields for backward compatibility + # (existing code may pass extra fields we don't validate) + model_config = {"extra": "allow"} + + @field_validator("domains") + @classmethod + def validate_domains(cls, v): + """Ensure domains is a list of non-empty strings.""" + if v: + if not isinstance(v, list): + raise ValueError("domains must be a list") + for domain in v: + if not isinstance(domain, str): + raise ValueError("All domains must be strings") + if not domain.strip(): + raise ValueError("Domain names cannot be empty") + return v diff --git a/openrag/routers/actors.py b/openrag/routers/actors.py index 49ab1dfa..9f30b884 100644 --- a/openrag/routers/actors.py +++ b/openrag/routers/actors.py @@ -51,24 +51,17 @@ ) async def list_ray_actors(): """List all known Ray actors and their status.""" - try: - actors = [ - { - "actor_id": a.actor_id, - "name": a.name, - "class_name": a.class_name, - "state": a.state, - "namespace": a.ray_namespace, - } - for a in list_actors() - ] - return JSONResponse(status_code=status.HTTP_200_OK, content={"actors": actors}) - except Exception: - logger.exception("Error getting actor summaries") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve actor summaries.", - ) + actors = [ + { + "actor_id": a.actor_id, + "name": a.name, + "class_name": a.class_name, + "state": a.state, + "namespace": a.ray_namespace, + } + for a in list_actors() + ] + return JSONResponse(status_code=status.HTTP_200_OK, content={"actors": actors}) @router.post( @@ -116,29 +109,16 @@ async def restart_actor( logger.info(f"Killed actor: {actor_name}") except ValueError: logger.warning("Actor not found. Creating new instance.", actor=actor_name) - except Exception as e: - logger.exception("Failed to kill actor", actor=actor_name) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to kill actor {actor_name}: {e!s}", - ) - try: - new_actor = actor_creation_map[actor_name]() - if "Semaphore" in actor_name: - new_actor = new_actor._actor - logger.info(f"Restarted actor: {actor_name}") - return JSONResponse( - status_code=status.HTTP_200_OK, - content={ - "message": f"Actor {actor_name} restarted successfully.", - "actor_name": actor_name, - "actor_id": new_actor._actor_id.hex(), - }, - ) - except Exception as e: - logger.exception("Failed to restart actor", actor=actor_name) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to restart actor {actor_name}: {e!s}", - ) + new_actor = actor_creation_map[actor_name]() + if "Semaphore" in actor_name: + new_actor = new_actor._actor + logger.info(f"Restarted actor: {actor_name}") + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": f"Actor {actor_name} restarted successfully.", + "actor_name": actor_name, + "actor_id": new_actor._actor_id.hex(), + }, + ) diff --git a/openrag/routers/extract.py b/openrag/routers/extract.py index 32f51e5a..5179a146 100644 --- a/openrag/routers/extract.py +++ b/openrag/routers/extract.py @@ -48,31 +48,23 @@ async def get_extract( user_partitions=Depends(current_user_or_admin_partitions_list), ): log = logger.bind(extract_id=extract_id) - try: - chunk = await vectordb.get_chunk_by_id.remote(extract_id) - if chunk is None: - log.warning("Extract not found.") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Extract '{extract_id}' not found.", - ) - chunk_partition = chunk.metadata["partition"] - log.info(f"User partitions: {user_partitions}, Chunk partition: {chunk_partition}") - if chunk_partition not in user_partitions and user_partitions != ["all"]: - log.warning("User does not have access to this extract.") - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"User does not have access to extract '{extract_id}'.", - ) - log.info("Extract successfully retrieved.") - except HTTPException: - raise - except Exception as e: - log.exception("Failed to retrieve extract.", error=str(e)) + + chunk = await vectordb.get_chunk_by_id.remote(extract_id) + if chunk is None: + log.warning("Extract not found.") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Extract '{extract_id}' not found.", + ) + chunk_partition = chunk.metadata["partition"] + log.info(f"User partitions: {user_partitions}, Chunk partition: {chunk_partition}") + if chunk_partition not in user_partitions and user_partitions != ["all"]: + log.warning("User does not have access to this extract.") raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve extract: {e!s}", + status_code=status.HTTP_403_FORBIDDEN, + detail=f"User does not have access to extract '{extract_id}'.", ) + log.info("Extract successfully retrieved.") return JSONResponse( status_code=status.HTTP_200_OK, diff --git a/openrag/routers/indexer.py b/openrag/routers/indexer.py index 8f023d54..5a17c1fd 100644 --- a/openrag/routers/indexer.py +++ b/openrag/routers/indexer.py @@ -123,13 +123,6 @@ async def add_file( vectordb=Depends(get_vectordb), user=Depends(require_partition_editor), ): - log = logger.bind( - file_id=file_id, - partition=partition, - filename=file.filename, - user=user.get("display_name"), - ) - if await vectordb.file_exists.remote(file_id, partition): raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -137,16 +130,9 @@ async def add_file( ) save_dir = Path(DATA_DIR) - try: - original_filename = file.filename - file.filename = sanitize_filename(file.filename) - file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True) - except Exception as e: - log.exception("Failed to save file to disk.", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e), - ) + original_filename = file.filename + file.filename = sanitize_filename(file.filename) + file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True) metadata.update( { @@ -240,8 +226,6 @@ async def put_file( vectordb=Depends(get_vectordb), user=Depends(require_partition_editor), ): - log = logger.bind(file_id=file_id, partition=partition, filename=file.filename) - if not await vectordb.file_exists.remote(file_id, partition): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -252,16 +236,9 @@ async def put_file( await indexer.delete_file.remote(file_id, partition) save_dir = Path(DATA_DIR) - try: - original_filename = file.filename - file.filename = sanitize_filename(file.filename) - file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True) - except Exception: - log.exception("Failed to save file to disk.") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to save uploaded file.", - ) + original_filename = file.filename + file.filename = sanitize_filename(file.filename) + file_path = await save_file_to_disk(file, save_dir, with_random_prefix=True) metadata.update( { @@ -441,21 +418,13 @@ async def get_task_error( task_state_manager=Depends(get_task_state_manager), task_details=Depends(require_task_owner), ): - try: - error = await task_state_manager.get_error.remote(task_id) - if error is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No error found for task '{task_id}'.", - ) - return {"task_id": task_id, "traceback": error.splitlines()} - except HTTPException: - raise - except Exception: + error = await task_state_manager.get_error.remote(task_id) + if error is None: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve task error.", + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No error found for task '{task_id}'.", ) + return {"task_id": task_id, "traceback": error.splitlines()} @router.get( @@ -475,32 +444,27 @@ async def get_task_error( """, ) async def get_task_logs(task_id: str, max_lines: int = 100, task_details=Depends(require_task_owner)): - try: - if not LOG_FILE.exists(): - raise HTTPException(status_code=500, detail="Log file not found.") - - logs = [] - with open(LOG_FILE, errors="replace") as f: - for line in reversed(list(f)): - try: - record = json.loads(line).get("record", {}) - if record.get("extra", {}).get("task_id") == task_id: - logs.append( - f"{record['time']['repr']} - {record['level']['name']} - {record['message']} - {(record['extra'])}" - ) - if len(logs) >= max_lines: - break - except json.JSONDecodeError: - continue - - if not logs: - raise HTTPException(status_code=404, detail=f"No logs found for task '{task_id}'") - - return JSONResponse(content={"task_id": task_id, "logs": logs[::-1]}) # restore order - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to fetch logs: {e!s}") + if not LOG_FILE.exists(): + raise HTTPException(status_code=500, detail="Log file not found.") + + logs = [] + with open(LOG_FILE, errors="replace") as f: + for line in reversed(list(f)): + try: + record = json.loads(line).get("record", {}) + if record.get("extra", {}).get("task_id") == task_id: + logs.append( + f"{record['time']['repr']} - {record['level']['name']} - {record['message']} - {(record['extra'])}" + ) + if len(logs) >= max_lines: + break + except json.JSONDecodeError: + continue + + if not logs: + raise HTTPException(status_code=404, detail=f"No logs found for task '{task_id}'") + + return JSONResponse(content={"task_id": task_id, "logs": logs[::-1]}) # restore order @router.delete( @@ -525,13 +489,9 @@ async def cancel_task( task_state_manager=Depends(get_task_state_manager), task_details=Depends(require_task_owner), ): - try: - obj_ref = await task_state_manager.get_object_ref.remote(task_id) - if obj_ref is None: - raise HTTPException(404, f"No ObjectRef stored for task {task_id}") - - ray.cancel(obj_ref["ref"], recursive=True) - return {"message": f"Cancellation signal sent for task {task_id}"} - except Exception as e: - logger.exception("Failed to cancel task.") - raise HTTPException(status_code=500, detail=str(e)) + obj_ref = await task_state_manager.get_object_ref.remote(task_id) + if obj_ref is None: + raise HTTPException(404, f"No ObjectRef stored for task {task_id}") + + ray.cancel(obj_ref["ref"], recursive=True) + return {"message": f"Cancellation signal sent for task {task_id}"} diff --git a/openrag/routers/openai.py b/openrag/routers/openai.py index f7e874cd..564f2dad 100644 --- a/openrag/routers/openai.py +++ b/openrag/routers/openai.py @@ -1,3 +1,4 @@ +import asyncio import json from pathlib import Path from urllib.parse import quote @@ -13,6 +14,7 @@ OpenAICompletionRequest, ) from utils.dependencies import get_vectordb +from utils.exceptions.base import OpenRAGError from utils.logger import get_logger from .utils import ( @@ -31,6 +33,19 @@ ragpipe = RagPipeline(config=config) +def _make_sse_error(message: str, code: str) -> str: + """Format an error as an SSE data chunk for streaming responses.""" + chunk = { + "error": { + "message": message, + "type": "error", + "param": None, + "code": code, + } + } + return f"data: {json.dumps(chunk)}\n\ndata: [DONE]\n\n" + + @router.get( "/models", summary="OpenAI-compatible model listing endpoint", @@ -163,25 +178,14 @@ async def openai_chat_completion( truncate(str(request.messages)), ) - try: - if is_direct_llm_model(request): - partitions = None - else: - partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"]) - log.debug(f"Using partitions: {partitions}") - except Exception as e: - log.warning("Invalid model or partition", error=str(e)) - raise - - try: - llm_output, docs = await ragpipe.chat_completion(partition=partitions, payload=request.model_dump()) - log.debug("RAG chat completion pipeline executed.") - except Exception as e: - log.exception("Chat completion failed.", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Chat completion failed: {e!s}", - ) + if is_direct_llm_model(request): + partitions = None + else: + partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"]) + log.debug(f"Using partitions: {partitions}") + + llm_output, docs = await ragpipe.chat_completion(partition=partitions, payload=request.model_dump()) + log.debug("RAG chat completion pipeline executed.") metadata = __prepare_sources(request2, docs) metadata_json = json.dumps({"sources": metadata}) @@ -204,33 +208,23 @@ async def stream_response(): except json.JSONDecodeError as e: log.error("Failed to decode streamed chunk.", error=str(e)) raise + except asyncio.CancelledError: + log.info("Client disconnected during streaming") + return + except OpenRAGError as e: + log.warning("OpenRAG error during streaming", code=e.code, error=e.message) + yield _make_sse_error(e.message, e.code) except Exception as e: - log.warning("Error while generating streaming answer", error=str(e)) - error_chunk = { - "error": { - "message": f"Error while generating answer: {str(e)}", - "type": "error", - "param": None, - "code": "ERROR_ANSWER_GENERATION", - } - } - yield f"data: {json.dumps(error_chunk)}\n\n" - yield "data: [DONE]\n\n" + log.warning("Error during streaming", error=str(e)) + yield _make_sse_error("An unexpected error occurred during streaming", "UNEXPECTED_ERROR") return StreamingResponse(stream_response(), media_type="text/event-stream") else: - try: - chunk = await llm_output.__anext__() - chunk["model"] = model_name - chunk["extra"] = metadata_json - log.debug("Returning non-streaming completion chunk.") - return JSONResponse(content=chunk) - except Exception as e: - log.warning("Error while generating answer", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error while generating answer: {e!s}", - ) + chunk = await llm_output.__anext__() + chunk["model"] = model_name + chunk["extra"] = metadata_json + log.debug("Returning non-streaming completion chunk.") + return JSONResponse(content=chunk) @router.post( @@ -285,37 +279,18 @@ async def openai_completion( detail="Streaming is not supported for this endpoint", ) - try: - if is_direct_llm_model(request): - partitions = None - else: - partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"]) - - except Exception as e: - log.warning(f"Invalid model or partition: {e}") - raise - - try: - llm_output, docs = await ragpipe.completions(partition=partitions, payload=request.model_dump()) - log.debug("RAG completion pipeline executed.") - except Exception as e: - log.exception("Completion request failed.", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Completion failed: {e!s}", - ) + if is_direct_llm_model(request): + partitions = None + else: + partitions = await get_partition_name(model_name, user_partitions, is_admin=user["is_admin"]) + + llm_output, docs = await ragpipe.completions(partition=partitions, payload=request.model_dump()) + log.debug("RAG completion pipeline executed.") metadata = __prepare_sources(request2, docs) metadata_json = json.dumps({"sources": metadata}) - try: - complete_response = await llm_output.__anext__() - complete_response["extra"] = metadata_json - log.debug("Returning completion response.") - return JSONResponse(content=complete_response) - except Exception as e: - log.warning("No response from LLM.", error=str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"No response from LLM: {e!s}", - ) + complete_response = await llm_output.__anext__() + complete_response["extra"] = metadata_json + log.debug("Returning completion response.") + return JSONResponse(content=complete_response) diff --git a/openrag/routers/tools.py b/openrag/routers/tools.py index 9d2c725f..315f394f 100644 --- a/openrag/routers/tools.py +++ b/openrag/routers/tools.py @@ -118,15 +118,6 @@ async def execute_tool( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Tool {tool['name']} not found", ) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed during tool execution.", extra={"error": str(e)}) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Tool execution failed due to an internal error.", - ) finally: # Cleanup of the temporary file if file_path is not None: diff --git a/openrag/routers/utils.py b/openrag/routers/utils.py index cf94ae6e..2aa6fed5 100644 --- a/openrag/routers/utils.py +++ b/openrag/routers/utils.py @@ -4,9 +4,12 @@ from typing import Any import consts +import httpx from config import load_config from fastapi import Depends, Form, HTTPException, Request, UploadFile, status +from models.indexer import FileMetadataSchema from openai import AsyncOpenAI +from pydantic import ValidationError from utils.dependencies import get_task_state_manager, get_vectordb from utils.logger import get_logger @@ -196,10 +199,18 @@ async def validate_file_id(file_id: str): async def validate_metadata(metadata: Any | None = Form(None)): try: processed_metadata = metadata or "{}" - processed_metadata = json.loads(processed_metadata) - return processed_metadata + parsed = json.loads(processed_metadata) + + # Validate against Pydantic schema + validated = FileMetadataSchema(**parsed) + return validated.model_dump() + except json.JSONDecodeError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON in metadata") + except ValidationError as e: + # Format Pydantic validation errors for user-friendly response + errors = "; ".join(f"{err['loc'][0]}: {err['msg']}" for err in e.errors()) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid metadata: {errors}") async def validate_file_format( @@ -248,13 +259,25 @@ async def check_llm_model_availability(request: Request): status_code=status.HTTP_404_NOT_FOUND, detail=f"Only these models ({available_models}) are available for your `{model_type}`. Please check your configuration file.", ) + except HTTPException: + raise + except httpx.TimeoutException: + logger.warning("LLM model availability check timed out", model=model_type) + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="LLM service timed out", + ) + except httpx.HTTPError as e: + logger.warning("LLM model availability check failed", model=model_type, error=str(e)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="LLM service is unavailable", + ) except Exception as e: - logger.exception("Failed to validate model", model=model_type, error=str(e)) - if isinstance(e, HTTPException): - raise + logger.exception("Failed to check LLM model availability", model=model_type, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error while checking the `{model_type}` endpoint, it seems not available at this moment", + detail="Failed to check LLM model availability", ) diff --git a/openrag/scripts/migrations/alembic/env.py b/openrag/scripts/migrations/alembic/env.py index 23038aed..182c78e6 100644 --- a/openrag/scripts/migrations/alembic/env.py +++ b/openrag/scripts/migrations/alembic/env.py @@ -3,7 +3,7 @@ from alembic import context from components.indexer.vectordb.utils import Base from config import load_config -from sqlalchemy import engine_from_config, pool +from sqlalchemy import URL, engine_from_config, pool rag_config = load_config() @@ -25,10 +25,16 @@ collection_name = rag_config.vectordb.collection_name -database_url = ( - f"postgresql://{rdb_user}:{rdb_password}@{rdb_host}:{rdb_port}/partitions_for_collection_{collection_name}" +database_url = URL.create( + drivername="postgresql", + username=rdb_user, + password=rdb_password, + host=rdb_host, + port=rdb_port, + database=f"partitions_for_collection_{collection_name}", ) -config.set_main_option("sqlalchemy.url", database_url) +# config.set_main_option expects a string, so convert URL object +config.set_main_option("sqlalchemy.url", str(database_url)) # add your model's MetaData object here # for 'autogenerate' support diff --git a/openrag/utils/exceptions/__init__.py b/openrag/utils/exceptions/__init__.py index 9b5ed21c..0333b6a3 100644 --- a/openrag/utils/exceptions/__init__.py +++ b/openrag/utils/exceptions/__init__.py @@ -1 +1,2 @@ from .base import * +from .common import * diff --git a/openrag/utils/exceptions/common.py b/openrag/utils/exceptions/common.py new file mode 100644 index 00000000..eaf1c7ce --- /dev/null +++ b/openrag/utils/exceptions/common.py @@ -0,0 +1,29 @@ +from .base import OpenRAGError + + +class FileStorageError(OpenRAGError): + """Raised when file I/O operations fail (save, read, delete).""" + + def __init__(self, message: str, **kwargs): + super().__init__(message=message, code="FILE_STORAGE_ERROR", status_code=500, **kwargs) + + +class RayActorError(OpenRAGError): + """Raised when a Ray actor operation fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__(message=message, code="RAY_ACTOR_ERROR", status_code=500, **kwargs) + + +class ToolExecutionError(OpenRAGError): + """Raised when tool execution fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__(message=message, code="TOOL_EXECUTION_ERROR", status_code=500, **kwargs) + + +class UnexpectedError(OpenRAGError): + """Raised for unexpected errors that don't match any specific category.""" + + def __init__(self, message: str, **kwargs): + super().__init__(message=message, code="UNEXPECTED_ERROR", status_code=500, **kwargs)