diff --git a/inference/core/env.py b/inference/core/env.py index f52adb0199..343e8a19f3 100644 --- a/inference/core/env.py +++ b/inference/core/env.py @@ -807,6 +807,31 @@ WEBRTC_MODAL_USAGE_QUOTA_ENABLED = str2bool( os.getenv("WEBRTC_MODAL_USAGE_QUOTA_ENABLED", "False") ) + +# +# Workspace stream quota +# +# Redis-based rate limiting that disallows more than N concurrent +# connections from a single workspace +WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED = str2bool( + os.getenv("WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED", "False") +) +WEBRTC_WORKSPACE_STREAM_QUOTA = int(os.getenv("WEBRTC_WORKSPACE_STREAM_QUOTA", "2")) +# TTL in seconds for active stream entries (auto-expire if no explicit cleanup) +WEBRTC_WORKSPACE_STREAM_TTL_SECONDS = int( + os.getenv("WEBRTC_WORKSPACE_STREAM_TTL_SECONDS", "60") +) +# URL for Modal to send session heartbeats to keep session alive +# Example: "https://serverless.roboflow.com/webrtc/session/heartbeat" +WEBRTC_SESSION_HEARTBEAT_URL = os.getenv( + "WEBRTC_SESSION_HEARTBEAT_URL", + None, +) +# How often Modal sends session heartbeats (in seconds) +WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS = int( + os.getenv("WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS", "30") +) + WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY = float( os.getenv("WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY", "0.1") ) diff --git a/inference/core/exceptions.py b/inference/core/exceptions.py index a4541c48cc..1b2fb3e7fd 100644 --- a/inference/core/exceptions.py +++ b/inference/core/exceptions.py @@ -224,3 +224,14 @@ class WebRTCConfigurationError(Exception): class CreditsExceededError(Exception): pass + + +class WorkspaceStreamQuotaError(Exception): + """Raised when the workspace stream quota has been exceeded. + + This error is returned when a workspace has reached its maximum number + of concurrent WebRTC streams. This prevents a single user from + consuming all of our Modal resources.
+ """ + + pass diff --git a/inference/core/interfaces/http/error_handlers.py b/inference/core/interfaces/http/error_handlers.py index b7f38c0012..d30a6e696f 100644 --- a/inference/core/interfaces/http/error_handlers.py +++ b/inference/core/interfaces/http/error_handlers.py @@ -9,6 +9,7 @@ ContentTypeMissing, CreditsExceededError, InferenceModelNotFound, + WorkspaceStreamQuotaError, InputImageLoadError, InvalidEnvironmentVariableError, InvalidMaskDecodeArgument, @@ -444,6 +445,15 @@ def wrapped_route(*args, **kwargs): "error_type": "CreditsExceededError", }, ) + except WorkspaceStreamQuotaError as error: + logger.error("%s: %s", type(error).__name__, error) + resp = JSONResponse( + status_code=429, + content={ + "message": str(error), + "error_type": "WorkspaceStreamQuotaError", + }, + ) except Exception as error: logger.exception("%s: %s", type(error).__name__, error) resp = JSONResponse(status_code=500, content={"message": "Internal error."}) @@ -464,7 +474,6 @@ def with_route_exceptions_async(route): Callable: The wrapped route. 
""" - @wraps(route) async def wrapped_route(*args, **kwargs): try: return await route(*args, **kwargs) @@ -816,9 +825,25 @@ async def wrapped_route(*args, **kwargs): "error_type": "CreditsExceededError", }, ) + except WorkspaceStreamQuotaError as error: + logger.error("%s: %s", type(error).__name__, error) + resp = JSONResponse( + status_code=429, + content={ + "message": str(error), + "error_type": "WorkspaceStreamQuotaError", + }, + ) except Exception as error: logger.exception("%s: %s", type(error).__name__, error) resp = JSONResponse(status_code=500, content={"message": "Internal error."}) return resp + wrapped_route.__wrapped__ = route + wrapped_route.__name__ = route.__name__ + wrapped_route.__doc__ = route.__doc__ + wrapped_route.__module__ = route.__module__ + wrapped_route.__qualname__ = route.__qualname__ + wrapped_route.__annotations__ = route.__annotations__ + return wrapped_route diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py index 16b98c1e69..373cbaf374 100644 --- a/inference/core/interfaces/http/http_api.py +++ b/inference/core/interfaces/http/http_api.py @@ -235,6 +235,7 @@ WebRTCWorkerRequest, WebRTCWorkerResult, ) +from inference.core.interfaces.webrtc_worker.utils import refresh_webrtc_session from inference.core.managers.base import ModelManager from inference.core.managers.metrics import get_container_stats from inference.core.managers.prometheus import InferenceInstrumentator @@ -1645,6 +1646,35 @@ async def initialise_webrtc_worker( type=worker_result.answer.type, ) + @app.post( + "/webrtc/session/heartbeat", + summary="WebRTC session heartbeat", + ) + async def webrtc_session_heartbeat( + request: Request, + ) -> dict: + """Receive heartbeat for an active WebRTC session. + + This endpoint is called periodically to indicate + that their session is still active. The session will be removed from + the quota count if no heartbeat is received within the TTL period. 
+ """ + body = await request.json() + workspace_id = body.get("workspace_id") + session_id = body.get("session_id") + + if not workspace_id or not session_id: + return { + "status": "error", + "message": "workspace_id and session_id required", + } + + refresh_webrtc_session( + workspace_id=workspace_id, + session_id=session_id, + ) + return {"status": "ok"} + if ENABLE_STREAM_API: @app.get( diff --git a/inference/core/interfaces/webrtc_worker/__init__.py b/inference/core/interfaces/webrtc_worker/__init__.py index a69dea8e6f..82a122afbd 100644 --- a/inference/core/interfaces/webrtc_worker/__init__.py +++ b/inference/core/interfaces/webrtc_worker/__init__.py @@ -1,12 +1,16 @@ import asyncio import multiprocessing +import uuid from inference.core.env import ( WEBRTC_MODAL_TOKEN_ID, WEBRTC_MODAL_TOKEN_SECRET, WEBRTC_MODAL_USAGE_QUOTA_ENABLED, + WEBRTC_WORKSPACE_STREAM_QUOTA, + WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED, + WEBRTC_WORKSPACE_STREAM_TTL_SECONDS, ) -from inference.core.exceptions import CreditsExceededError +from inference.core.exceptions import CreditsExceededError, WorkspaceStreamQuotaError from inference.core.interfaces.webrtc_worker.cpu import rtc_peer_connection_process from inference.core.interfaces.webrtc_worker.entities import ( RTCIceServer, @@ -15,6 +19,7 @@ WebRTCWorkerResult, ) from inference.core.logger import logger +from inference.core.roboflow_api import get_roboflow_workspace async def start_worker( @@ -36,7 +41,11 @@ async def start_worker( from inference.core.interfaces.webrtc_worker.modal import ( spawn_rtc_peer_connection_modal, ) - from inference.core.interfaces.webrtc_worker.utils import is_over_quota + from inference.core.interfaces.webrtc_worker.utils import ( + is_over_quota, + is_over_workspace_session_quota, + register_webrtc_session, + ) except ImportError: raise ImportError( "Modal not installed, please install it using 'pip install modal'" @@ -46,6 +55,39 @@ async def start_worker( logger.error("API key over quota") raise 
CreditsExceededError("API key over quota") + workspace_id = None + session_id = str(uuid.uuid4()) + if WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED: + workspace_id = get_roboflow_workspace(api_key=webrtc_request.api_key) + if workspace_id and is_over_workspace_session_quota( + workspace_id=workspace_id, + quota=WEBRTC_WORKSPACE_STREAM_QUOTA, + ttl_seconds=WEBRTC_WORKSPACE_STREAM_TTL_SECONDS, + ): + logger.warning( + "Workspace %s has exceeded the stream quota of %d", + workspace_id, + WEBRTC_WORKSPACE_STREAM_QUOTA, + ) + raise WorkspaceStreamQuotaError( + f"You have reached the maximum of {WEBRTC_WORKSPACE_STREAM_QUOTA} " + f"concurrent streams. Please contact Roboflow to increase your quota." + ) + + if workspace_id: + register_webrtc_session( + workspace_id=workspace_id, + session_id=session_id, + ) + # We need to pass the workspace/session identifiers on to Modal. + webrtc_request.workspace_id = workspace_id + webrtc_request.session_id = session_id + logger.info( + "Started WebRTC session %s for workspace %s", + session_id, + workspace_id, + ) + loop = asyncio.get_event_loop() result = await loop.run_in_executor( None,
WEBRTC_MODAL_TOKEN_SECRET, WEBRTC_MODAL_USAGE_QUOTA_ENABLED, WEBRTC_MODAL_WATCHDOG_TIMEMOUT, + WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS, + WEBRTC_SESSION_HEARTBEAT_URL, WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE, ) from inference.core.exceptions import ( @@ -210,6 +212,12 @@ def check_nvidia_smi_gpu() -> str: "WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION": str( WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION ), + "WEBRTC_SESSION_HEARTBEAT_URL": ( + WEBRTC_SESSION_HEARTBEAT_URL if WEBRTC_SESSION_HEARTBEAT_URL else "" + ), + "WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS": str( + WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS + ), }, "volumes": {MODEL_CACHE_DIR: rfcache_volume}, } @@ -388,6 +396,10 @@ def send_answer(obj: WebRTCWorkerResult): watchdog = Watchdog( api_key=webrtc_request.api_key, timeout_seconds=WEBRTC_MODAL_WATCHDOG_TIMEMOUT, + # Use getattr for backwards compatibility with older Docker images + workspace_id=getattr(webrtc_request, "workspace_id", None), + session_id=getattr(webrtc_request, "session_id", None), + heartbeat_url=WEBRTC_SESSION_HEARTBEAT_URL, ) try: diff --git a/inference/core/interfaces/webrtc_worker/utils.py b/inference/core/interfaces/webrtc_worker/utils.py index 2c0a4bd4f3..da8cf48182 100644 --- a/inference/core/interfaces/webrtc_worker/utils.py +++ b/inference/core/interfaces/webrtc_worker/utils.py @@ -10,6 +10,8 @@ from av import VideoFrame from inference.core import logger +from inference.core.cache import cache +from inference.core.cache.redis import RedisCache from inference.core.env import DEBUG_WEBRTC_PROCESSING_LATENCY from inference.core.interfaces.camera.entities import VideoFrame as InferenceVideoFrame from inference.core.interfaces.stream.inference_pipeline import InferencePipeline @@ -240,6 +242,114 @@ def is_over_quota(api_key: str) -> bool: return is_over_quota +def _get_concurrent_sessions_key(workspace_id: str) -> str: + """Get the Redis key for tracking concurrent sessions for a workspace.""" + return 
f"webrtc:concurrent_sessions:{workspace_id}" + + +def register_webrtc_session(workspace_id: str, session_id: str) -> None: + """Register a new concurrent WebRTC session for a workspace. + + Adds the session to a Redis sorted set with current timestamp as score. + Expired entries are cleaned up on read via ZREMRANGEBYSCORE (O(log N + M)). + + Args: + workspace_id: The workspace identifier + session_id: Unique identifier for this session + """ + if not isinstance(cache, RedisCache): + logger.warning( + "Redis not available, skipping session registration for rate limiting" + ) + return + + key = _get_concurrent_sessions_key(workspace_id) + try: + cache.client.zadd(key, {session_id: time.time()}) + logger.info( + "Registered session: workspace=%s, session=%s", + workspace_id, + session_id, + ) + except Exception as e: + logger.error("Failed to register session: %s", e) + + +def refresh_webrtc_session(workspace_id: str, session_id: str) -> None: + """Refresh the timestamp for a concurrent WebRTC session. + + Should be called periodically to keep the session marked as active. + If not refreshed, the session will be considered expired after TTL. + + Args: + workspace_id: The workspace identifier + session_id: The session identifier to refresh + """ + if not isinstance(cache, RedisCache): + return + + key = _get_concurrent_sessions_key(workspace_id) + try: + cache.client.zadd(key, {session_id: time.time()}) + except Exception as e: + logger.error("Failed to refresh session: %s", e) + + +def get_concurrent_session_count(workspace_id: str, ttl_seconds: int) -> int: + """Get the count of concurrent sessions for a workspace. + + Cleans up expired entries (older than TTL) before counting. 
+ + Args: + workspace_id: The workspace identifier + ttl_seconds: TTL in seconds - entries older than this are considered expired + + Returns: + Number of concurrent sessions for the workspace + """ + if not isinstance(cache, RedisCache): + logger.warning( + "Redis not available, cannot count concurrent sessions - allowing request" + ) + return 0 + + key = _get_concurrent_sessions_key(workspace_id) + cutoff = time.time() - ttl_seconds + + try: + # Step 1: we remove expired entries + # Step 2: we return what is still valid + cache.client.zremrangebyscore(key, "-inf", cutoff) + count = cache.client.zcard(key) + return count + except Exception as e: + logger.error("Failed to get concurrent session count: %s", e) + return 0 + + +def is_over_workspace_session_quota( + workspace_id: str, quota: int, ttl_seconds: int +) -> bool: + """Check if a workspace has exceeded its concurrent session quota. + + Args: + workspace_id: The workspace identifier + quota: Maximum number of concurrent sessions allowed + ttl_seconds: TTL for considering sessions as active + + Returns: + True if the workspace has reached or exceeded the quota + """ + count = get_concurrent_session_count(workspace_id, ttl_seconds) + logger.info( + "Workspace %s has %d concurrent sessions (quota: %d)", + workspace_id, + count, + quota, + ) + return count >= quota + + def get_video_fps(filepath: str) -> Optional[float]: """Detect video FPS from container metadata. 
diff --git a/inference/core/interfaces/webrtc_worker/watchdog.py b/inference/core/interfaces/webrtc_worker/watchdog.py index 1c276024a3..650beeee8f 100644 --- a/inference/core/interfaces/webrtc_worker/watchdog.py +++ b/inference/core/interfaces/webrtc_worker/watchdog.py @@ -3,7 +3,12 @@ import time from typing import Callable, Optional -from inference.core.env import WEBRTC_MODAL_USAGE_QUOTA_ENABLED +import requests + +from inference.core.env import ( + WEBRTC_MODAL_USAGE_QUOTA_ENABLED, + WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS, +) from inference.core.interfaces.webrtc_worker.utils import is_over_quota from inference.core.logger import logger @@ -14,6 +19,9 @@ def __init__( api_key: str, timeout_seconds: int, on_timeout: Optional[Callable[[], None]] = None, + workspace_id: Optional[str] = None, + session_id: Optional[str] = None, + heartbeat_url: Optional[str] = None, ): self._api_key = api_key self.timeout_seconds = timeout_seconds @@ -25,6 +33,10 @@ def __init__( self._log_interval_seconds = 10 self._heartbeats = 0 self._total_heartbeats = 0 + self._workspace_id = workspace_id + self._session_id = session_id + self._heartbeat_url = heartbeat_url + self._last_session_heartbeat_ts = datetime.datetime.now() @property def total_heartbeats(self) -> int: @@ -43,6 +55,50 @@ def stop(self): if self._thread.is_alive(): self._thread.join() + def _send_session_heartbeat(self): + """Send heartbeat to keep the session alive in the quota system. + + This is used to signal that the session is alive so the system + doesn't allow more than N concurrent sessions from a single workspace.
+ """ + if not all( + [ + self._heartbeat_url, + self._workspace_id, + self._session_id, + ] + ): + logger.info( + "Skipping session heartbeat: url=%s, workspace=%s, session=%s", + bool(self._heartbeat_url), + bool(self._workspace_id), + bool(self._session_id), + ) + return + + try: + response = requests.post( + self._heartbeat_url, + json={ + "workspace_id": self._workspace_id, + "session_id": self._session_id, + }, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + if response.status_code == 200: + logger.info( + "Session heartbeat sent for workspace=%s session=%s", + self._workspace_id, + self._session_id, + ) + else: + logger.warning( + "Failed to send session heartbeat: %s", response.status_code + ) + except Exception as e: + logger.warning("Error sending session heartbeat: %s", e) + def _watchdog_thread(self): logger.info("Watchdog thread started") while not self._stopping: @@ -62,6 +118,12 @@ def _watchdog_thread(self): message=f"API key over quota, heartbeats: {self._total_heartbeats}" ) break + + if ( + datetime.datetime.now() - self._last_session_heartbeat_ts + ).total_seconds() > WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS: + self._send_session_heartbeat() + self._last_session_heartbeat_ts = datetime.datetime.now() time.sleep(1) logger.info("Watchdog thread stopped, heartbeats: %s", self._total_heartbeats) diff --git a/inference/core/interfaces/webrtc_worker/webrtc.py b/inference/core/interfaces/webrtc_worker/webrtc.py index 2ddd883232..2cdbe77a56 100644 --- a/inference/core/interfaces/webrtc_worker/webrtc.py +++ b/inference/core/interfaces/webrtc_worker/webrtc.py @@ -927,8 +927,8 @@ async def init_rtc_peer_connection_with_loop( KeyError, NotImplementedError, ) as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=error.__class__.__name__, @@ -937,8 +937,8 @@ async def init_rtc_peer_connection_with_loop( ) return except 
WebRTCConfigurationError as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=error.__class__.__name__, @@ -947,8 +947,8 @@ async def init_rtc_peer_connection_with_loop( ) return except RoboflowAPINotAuthorizedError: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=RoboflowAPINotAuthorizedError.__name__, @@ -957,8 +957,8 @@ async def init_rtc_peer_connection_with_loop( ) return except RoboflowAPINotNotFoundError: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=RoboflowAPINotNotFoundError.__name__, @@ -967,8 +967,8 @@ async def init_rtc_peer_connection_with_loop( ) return except WorkflowSyntaxError as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=WorkflowSyntaxError.__name__, @@ -979,8 +979,8 @@ async def init_rtc_peer_connection_with_loop( ) return except WorkflowError as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=WorkflowError.__name__,