Skip to content
25 changes: 25 additions & 0 deletions inference/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,31 @@
WEBRTC_MODAL_USAGE_QUOTA_ENABLED = str2bool(
os.getenv("WEBRTC_MODAL_USAGE_QUOTA_ENABLED", "False")
)

#
# Workspace stream quota
#
# Redis-base rate limiting that disables more than N concurrent
# connections from a single workspace
WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED = str2bool(
os.getenv("WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED", "False")
)
WEBRTC_WORKSPACE_STREAM_QUOTA = int(os.getenv("WEBRTC_WORKSPACE_STREAM_QUOTA", "2"))
# TTL in seconds for active stream entries (auto-expire if no explicit cleanup)
WEBRTC_WORKSPACE_STREAM_TTL_SECONDS = int(
os.getenv("WEBRTC_WORKSPACE_STREAM_TTL_SECONDS", "60")
)
# URL for Modal to send session heartbeats to keep session alive
# Example: "https://serverless.roboflow.com/webrtc/session/heartbeat"
WEBRTC_SESSION_HEARTBEAT_URL = os.getenv(
"WEBRTC_SESSION_HEARTBEAT_URL",
None,
)
# How often Modal sends session heartbeats (in seconds)
WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS = int(
os.getenv("WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS", "30")
)

WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY = float(
os.getenv("WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY", "0.1")
)
Expand Down
11 changes: 11 additions & 0 deletions inference/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,14 @@ class WebRTCConfigurationError(Exception):

class CreditsExceededError(Exception):
pass


class WorkspaceStreamQuotaError(Exception):
"""Raised when the workspace stream quota has been exceeded.

This error is returned when a workspace has reached its maximum number
of concurrent WebRTC streams. This is to prevent that a single user
uses all our modal resources.
"""

pass
27 changes: 26 additions & 1 deletion inference/core/interfaces/http/error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ContentTypeMissing,
CreditsExceededError,
InferenceModelNotFound,
WorkspaceStreamQuotaError,
InputImageLoadError,
InvalidEnvironmentVariableError,
InvalidMaskDecodeArgument,
Expand Down Expand Up @@ -444,6 +445,15 @@ def wrapped_route(*args, **kwargs):
"error_type": "CreditsExceededError",
},
)
except WorkspaceStreamQuotaError as error:
logger.error("%s: %s", type(error).__name__, error)
resp = JSONResponse(
status_code=429,
content={
"message": str(error),
"error_type": "WorkspaceStreamQuotaError",
},
)
except Exception as error:
logger.exception("%s: %s", type(error).__name__, error)
resp = JSONResponse(status_code=500, content={"message": "Internal error."})
Expand All @@ -464,7 +474,6 @@ def with_route_exceptions_async(route):
Callable: The wrapped route.
"""

@wraps(route)
async def wrapped_route(*args, **kwargs):
try:
return await route(*args, **kwargs)
Expand Down Expand Up @@ -816,9 +825,25 @@ async def wrapped_route(*args, **kwargs):
"error_type": "CreditsExceededError",
},
)
except WorkspaceStreamQuotaError as error:
logger.error("%s: %s", type(error).__name__, error)
resp = JSONResponse(
status_code=429,
content={
"message": str(error),
"error_type": "WorkspaceStreamQuotaError",
},
)
except Exception as error:
logger.exception("%s: %s", type(error).__name__, error)
resp = JSONResponse(status_code=500, content={"message": "Internal error."})
return resp

wrapped_route.__wrapped__ = route
wrapped_route.__name__ = route.__name__
wrapped_route.__doc__ = route.__doc__
wrapped_route.__module__ = route.__module__
wrapped_route.__qualname__ = route.__qualname__
wrapped_route.__annotations__ = route.__annotations__

return wrapped_route
30 changes: 30 additions & 0 deletions inference/core/interfaces/http/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@
WebRTCWorkerRequest,
WebRTCWorkerResult,
)
from inference.core.interfaces.webrtc_worker.utils import refresh_webrtc_session
from inference.core.managers.base import ModelManager
from inference.core.managers.metrics import get_container_stats
from inference.core.managers.prometheus import InferenceInstrumentator
Expand Down Expand Up @@ -1645,6 +1646,35 @@ async def initialise_webrtc_worker(
type=worker_result.answer.type,
)

@app.post(
"/webrtc/session/heartbeat",
summary="WebRTC session heartbeat",
)
async def webrtc_session_heartbeat(
request: Request,
) -> dict:
"""Receive heartbeat for an active WebRTC session.

This endpoint is called periodically to indicate
that their session is still active. The session will be removed from
the quota count if no heartbeat is received within the TTL period.
"""
body = await request.json()
workspace_id = body.get("workspace_id")
session_id = body.get("session_id")

if not workspace_id or not session_id:
return {
"status": "error",
"message": "workspace_id and session_id required",
}

refresh_webrtc_session(
workspace_id=workspace_id,
session_id=session_id,
)
return {"status": "ok"}

if ENABLE_STREAM_API:

@app.get(
Expand Down
46 changes: 44 additions & 2 deletions inference/core/interfaces/webrtc_worker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import asyncio
import multiprocessing
import uuid

from inference.core.env import (
WEBRTC_MODAL_TOKEN_ID,
WEBRTC_MODAL_TOKEN_SECRET,
WEBRTC_MODAL_USAGE_QUOTA_ENABLED,
WEBRTC_WORKSPACE_STREAM_QUOTA,
WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED,
WEBRTC_WORKSPACE_STREAM_TTL_SECONDS,
)
from inference.core.exceptions import CreditsExceededError
from inference.core.exceptions import CreditsExceededError, WorkspaceStreamQuotaError
from inference.core.interfaces.webrtc_worker.cpu import rtc_peer_connection_process
from inference.core.interfaces.webrtc_worker.entities import (
RTCIceServer,
Expand All @@ -15,6 +19,7 @@
WebRTCWorkerResult,
)
from inference.core.logger import logger
from inference.core.roboflow_api import get_roboflow_workspace


async def start_worker(
Expand All @@ -36,7 +41,11 @@ async def start_worker(
from inference.core.interfaces.webrtc_worker.modal import (
spawn_rtc_peer_connection_modal,
)
from inference.core.interfaces.webrtc_worker.utils import is_over_quota
from inference.core.interfaces.webrtc_worker.utils import (
is_over_quota,
is_over_workspace_session_quota,
register_webrtc_session,
)
except ImportError:
raise ImportError(
"Modal not installed, please install it using 'pip install modal'"
Expand All @@ -46,6 +55,39 @@ async def start_worker(
logger.error("API key over quota")
raise CreditsExceededError("API key over quota")

workspace_id = None
session_id = str(uuid.uuid4())
if WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED:
workspace_id = get_roboflow_workspace(api_key=webrtc_request.api_key)
if workspace_id and is_over_workspace_session_quota(
workspace_id=workspace_id,
quota=WEBRTC_WORKSPACE_STREAM_QUOTA,
ttl_seconds=WEBRTC_WORKSPACE_STREAM_TTL_SECONDS,
):
logger.warning(
"Workspace %s has exceeded the stream quota of %d",
workspace_id,
WEBRTC_WORKSPACE_STREAM_QUOTA,
)
raise WorkspaceStreamQuotaError(
f"You have reached the maximum of {WEBRTC_WORKSPACE_STREAM_QUOTA} "
f"concurrent streams. Please contact Roboflow to increase your quota."
)

if workspace_id:
register_webrtc_session(
workspace_id=workspace_id,
session_id=session_id,
)
# we need to pass to modal how to identifier workspace/ session id.
webrtc_request.workspace_id = workspace_id
webrtc_request.session_id = session_id
logger.info(
"Started WebRTC session %s for workspace %s",
session_id,
workspace_id,
)

loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
Expand Down
3 changes: 3 additions & 0 deletions inference/core/interfaces/webrtc_worker/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class WebRTCWorkerRequest(BaseModel):
# must be valid region: https://modal.com/docs/guide/region-selection#region-options
requested_region: Optional[str] = None

workspace_id: Optional[str] = None
session_id: Optional[str] = None


class WebRTCVideoMetadata(BaseModel):
frame_id: int
Expand Down
12 changes: 12 additions & 0 deletions inference/core/interfaces/webrtc_worker/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
WEBRTC_MODAL_TOKEN_SECRET,
WEBRTC_MODAL_USAGE_QUOTA_ENABLED,
WEBRTC_MODAL_WATCHDOG_TIMEMOUT,
WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS,
WEBRTC_SESSION_HEARTBEAT_URL,
WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
)
from inference.core.exceptions import (
Expand Down Expand Up @@ -210,6 +212,12 @@ def check_nvidia_smi_gpu() -> str:
"WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION": str(
WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION
),
"WEBRTC_SESSION_HEARTBEAT_URL": (
WEBRTC_SESSION_HEARTBEAT_URL if WEBRTC_SESSION_HEARTBEAT_URL else ""
),
"WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS": str(
WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS
),
},
"volumes": {MODEL_CACHE_DIR: rfcache_volume},
}
Expand Down Expand Up @@ -388,6 +396,10 @@ def send_answer(obj: WebRTCWorkerResult):
watchdog = Watchdog(
api_key=webrtc_request.api_key,
timeout_seconds=WEBRTC_MODAL_WATCHDOG_TIMEMOUT,
# Use getattr for backwards compatibility with older Docker images
workspace_id=getattr(webrtc_request, "workspace_id", None),
session_id=getattr(webrtc_request, "session_id", None),
heartbeat_url=WEBRTC_SESSION_HEARTBEAT_URL,
)

try:
Expand Down
110 changes: 110 additions & 0 deletions inference/core/interfaces/webrtc_worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from av import VideoFrame

from inference.core import logger
from inference.core.cache import cache
from inference.core.cache.redis import RedisCache
from inference.core.env import DEBUG_WEBRTC_PROCESSING_LATENCY
from inference.core.interfaces.camera.entities import VideoFrame as InferenceVideoFrame
from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
Expand Down Expand Up @@ -240,6 +242,114 @@ def is_over_quota(api_key: str) -> bool:
return is_over_quota


def _get_concurrent_sessions_key(workspace_id: str) -> str:
"""Get the Redis key for tracking concurrent sessions for a workspace."""
return f"webrtc:concurrent_sessions:{workspace_id}"


def register_webrtc_session(workspace_id: str, session_id: str) -> None:
"""Register a new concurrent WebRTC session for a workspace.

Adds the session to a Redis sorted set with current timestamp as score.
Expired entries are cleaned up on read via ZREMRANGEBYSCORE (O(log N + M)).

Args:
workspace_id: The workspace identifier
session_id: Unique identifier for this session
"""
if not isinstance(cache, RedisCache):
logger.warning(
"Redis not available, skipping session registration for rate limiting"
)
return

key = _get_concurrent_sessions_key(workspace_id)
try:
cache.client.zadd(key, {session_id: time.time()})
logger.info(
"Registered session: workspace=%s, session=%s",
workspace_id,
session_id,
)
except Exception as e:
logger.error("Failed to register session: %s", e)


def refresh_webrtc_session(workspace_id: str, session_id: str) -> None:
"""Refresh the timestamp for a concurrent WebRTC session.

Should be called periodically to keep the session marked as active.
If not refreshed, the session will be considered expired after TTL.

Args:
workspace_id: The workspace identifier
session_id: The session identifier to refresh
"""
if not isinstance(cache, RedisCache):
return

key = _get_concurrent_sessions_key(workspace_id)
try:
cache.client.zadd(key, {session_id: time.time()})
except Exception as e:
logger.error("Failed to refresh session: %s", e)


def get_concurrent_session_count(workspace_id: str, ttl_seconds: int) -> int:
"""Get the count of concurrent sessions for a workspace.

Cleans up expired entries (older than TTL) before counting.

Args:
workspace_id: The workspace identifier
ttl_seconds: TTL in seconds - entries older than this are considered expired

Returns:
Number of concurrent sessions for the workspace
"""
if not isinstance(cache, RedisCache):
logger.warning(
"Redis not available, cannot count concurrent sessions - allowing request"
)
return 0

key = _get_concurrent_sessions_key(workspace_id)
cutoff = time.time() - ttl_seconds

try:
# Step 1: we remove expired entries
# Step 2: we return what is still valid
cache.client.zremrangebyscore(key, "-inf", cutoff)
count = cache.client.zcard(key)
return count
except Exception as e:
logger.error("Failed to get concurrent session count: %s", e)
return 0


def is_over_workspace_session_quota(
workspace_id: str, quota: int, ttl_seconds: int
) -> bool:
"""Check if a workspace has exceeded its concurrent session quota.

Args:
workspace_id: The workspace identifier
quota: Maximum number of concurrent sessions allowed
ttl_seconds: TTL for considering sessions as active

Returns:
True if the workspace has reached or exceeded the quota
"""
count = get_concurrent_session_count(workspace_id, ttl_seconds)
logger.info(
"Workspace %s has %d concurrent sessions (quota: %d)",
workspace_id,
count,
quota,
)
return count >= quota


def get_video_fps(filepath: str) -> Optional[float]:
"""Detect video FPS from container metadata.

Expand Down
Loading