diff --git a/inference/core/env.py b/inference/core/env.py index 017f476e4c..66177dca5d 100644 --- a/inference/core/env.py +++ b/inference/core/env.py @@ -844,6 +844,31 @@ WEBRTC_MODAL_USAGE_QUOTA_ENABLED = str2bool( os.getenv("WEBRTC_MODAL_USAGE_QUOTA_ENABLED", "False") ) + +# +# Workspace stream quota +# +# Redis-based rate limiting that rejects more than N concurrent +# connections from a single workspace +WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED = str2bool( + os.getenv("WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED", "False") +) +WEBRTC_WORKSPACE_STREAM_QUOTA = int(os.getenv("WEBRTC_WORKSPACE_STREAM_QUOTA", "10")) +# TTL in seconds for active stream entries (auto-expire if no explicit cleanup) +WEBRTC_WORKSPACE_STREAM_TTL_SECONDS = int( + os.getenv("WEBRTC_WORKSPACE_STREAM_TTL_SECONDS", "60") +) +# URL for Modal to send session heartbeats to keep session alive +# Example: "https://serverless.roboflow.com/webrtc/session/heartbeat" +WEBRTC_SESSION_HEARTBEAT_URL = os.getenv( + "WEBRTC_SESSION_HEARTBEAT_URL", + None, +) +# How often Modal sends session heartbeats (in seconds) +WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS = int( + os.getenv("WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS", "30") +) + WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY = float( os.getenv("WEBRTC_DATA_CHANNEL_BUFFER_DRAINING_DELAY", "0.1") ) diff --git a/inference/core/exceptions.py b/inference/core/exceptions.py index 8ec1549152..4555f49765 100644 --- a/inference/core/exceptions.py +++ b/inference/core/exceptions.py @@ -232,3 +232,14 @@ class WebRTCConfigurationError(Exception): class CreditsExceededError(Exception): pass + + +class WorkspaceStreamQuotaError(Exception): + """Raised when the workspace stream quota has been exceeded. + + This error is returned when a workspace has reached its maximum number + of concurrent WebRTC streams. This prevents a single user + from consuming all of our Modal resources. 
+ """ + + pass diff --git a/inference/core/interfaces/http/error_handlers.py b/inference/core/interfaces/http/error_handlers.py index a5ccbcf76d..94be13c370 100644 --- a/inference/core/interfaces/http/error_handlers.py +++ b/inference/core/interfaces/http/error_handlers.py @@ -33,6 +33,7 @@ ServiceConfigurationError, WebRTCConfigurationError, WorkspaceLoadError, + WorkspaceStreamQuotaError, ) from inference.core.interfaces.stream_manager.api.errors import ( ProcessesManagerAuthorisationError, @@ -462,6 +463,15 @@ def wrapped_route(*args, **kwargs): "error_type": "CreditsExceededError", }, ) + except WorkspaceStreamQuotaError as error: + logger.error("%s: %s", type(error).__name__, error) + resp = JSONResponse( + status_code=429, + content={ + "message": str(error), + "error_type": "WorkspaceStreamQuotaError", + }, + ) except Exception as error: logger.exception("%s: %s", type(error).__name__, error) resp = JSONResponse(status_code=500, content={"message": "Internal error."}) @@ -850,6 +860,15 @@ async def wrapped_route(*args, **kwargs): "error_type": "CreditsExceededError", }, ) + except WorkspaceStreamQuotaError as error: + logger.error("%s: %s", type(error).__name__, error) + resp = JSONResponse( + status_code=429, + content={ + "message": str(error), + "error_type": "WorkspaceStreamQuotaError", + }, + ) except Exception as error: logger.exception("%s: %s", type(error).__name__, error) resp = JSONResponse(status_code=500, content={"message": "Internal error."}) diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py index 0171cf7d03..7961d79650 100644 --- a/inference/core/interfaces/http/http_api.py +++ b/inference/core/interfaces/http/http_api.py @@ -244,9 +244,14 @@ ) from inference.core.interfaces.webrtc_worker import start_worker from inference.core.interfaces.webrtc_worker.entities import ( + WebRTCSessionHeartbeatRequest, WebRTCWorkerRequest, WebRTCWorkerResult, ) +from 
inference.core.interfaces.webrtc_worker.utils import ( + deregister_webrtc_session, + refresh_webrtc_session, +) from inference.core.managers.base import ModelManager from inference.core.managers.metrics import get_container_stats from inference.core.managers.model_load_collector import ( @@ -1786,6 +1791,87 @@ async def initialise_webrtc_worker( type=worker_result.answer.type, ) + @app.post( + "/webrtc/session/heartbeat", + summary="WebRTC session heartbeat", + ) + @with_route_exceptions_async + async def webrtc_session_heartbeat( + request: WebRTCSessionHeartbeatRequest, + ) -> dict: + """Receive heartbeat for an active WebRTC session. + + This endpoint is called periodically to indicate + that their session is still active. The session will be removed from + the quota count if no heartbeat is received within the TTL period. + + Requires api_key for authentication. + """ + try: + workspace_id = await get_roboflow_workspace_async( + api_key=request.api_key + ) + except (RoboflowAPINotAuthorizedError, WorkspaceLoadError): + raise HTTPException( + status_code=401, + detail={"status": "error", "message": "unauthorized"}, + ) + if not workspace_id: + raise HTTPException( + status_code=500, + detail={ + "status": "error", + "message": "failed to retrieve workspace", + }, + ) + + session_refreshed = refresh_webrtc_session( + workspace_id=workspace_id, + session_id=request.session_id, + ) + if not session_refreshed: + raise HTTPException( + status_code=404, + detail={"status": "error", "message": "session not found"}, + ) + return {"status": "ok"} + + @app.post( + "/webrtc/session/heartbeat/end", + summary="End WebRTC session", + ) + @with_route_exceptions_async + async def webrtc_session_end( + request: WebRTCSessionHeartbeatRequest, + ) -> dict: + """End a WebRTC session and immediately free the quota slot. + + Requires api_key for authentication. 
+ """ + try: + workspace_id = await get_roboflow_workspace_async( + api_key=request.api_key + ) + except (RoboflowAPINotAuthorizedError, WorkspaceLoadError): + raise HTTPException( + status_code=401, + detail={"status": "error", "message": "unauthorized"}, + ) + if not workspace_id: + raise HTTPException( + status_code=500, + detail={ + "status": "error", + "message": "failed to retrieve workspace", + }, + ) + + deregister_webrtc_session( + workspace_id=workspace_id, + session_id=request.session_id, + ) + return {"status": "ok"} + if ENABLE_STREAM_API: @app.get( diff --git a/inference/core/interfaces/webrtc_worker/__init__.py b/inference/core/interfaces/webrtc_worker/__init__.py index a69dea8e6f..6529d1db56 100644 --- a/inference/core/interfaces/webrtc_worker/__init__.py +++ b/inference/core/interfaces/webrtc_worker/__init__.py @@ -1,12 +1,16 @@ import asyncio import multiprocessing +import uuid from inference.core.env import ( WEBRTC_MODAL_TOKEN_ID, WEBRTC_MODAL_TOKEN_SECRET, WEBRTC_MODAL_USAGE_QUOTA_ENABLED, + WEBRTC_WORKSPACE_STREAM_QUOTA, + WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED, + WEBRTC_WORKSPACE_STREAM_TTL_SECONDS, ) -from inference.core.exceptions import CreditsExceededError +from inference.core.exceptions import CreditsExceededError, WorkspaceStreamQuotaError from inference.core.interfaces.webrtc_worker.cpu import rtc_peer_connection_process from inference.core.interfaces.webrtc_worker.entities import ( RTCIceServer, @@ -15,6 +19,7 @@ WebRTCWorkerResult, ) from inference.core.logger import logger +from inference.core.roboflow_api import get_roboflow_workspace async def start_worker( @@ -36,7 +41,12 @@ async def start_worker( from inference.core.interfaces.webrtc_worker.modal import ( spawn_rtc_peer_connection_modal, ) - from inference.core.interfaces.webrtc_worker.utils import is_over_quota + from inference.core.interfaces.webrtc_worker.utils import ( + get_total_concurrent_sessions, + is_over_quota, + is_over_workspace_session_quota, + 
register_webrtc_session, + ) except ImportError: raise ImportError( "Modal not installed, please install it using 'pip install modal'" @@ -46,6 +56,47 @@ async def start_worker( logger.error("API key over quota") raise CreditsExceededError("API key over quota") + session_id = str(uuid.uuid4()) + workspace_id = get_roboflow_workspace(api_key=webrtc_request.api_key) + webrtc_request.workspace_id = workspace_id + webrtc_request.session_id = session_id + + if WEBRTC_WORKSPACE_STREAM_QUOTA_ENABLED: + if workspace_id and is_over_workspace_session_quota( + workspace_id=workspace_id, + quota=WEBRTC_WORKSPACE_STREAM_QUOTA, + ttl_seconds=WEBRTC_WORKSPACE_STREAM_TTL_SECONDS, + ): + logger.warning( + "Workspace %s has exceeded the stream quota of %d", + workspace_id, + WEBRTC_WORKSPACE_STREAM_QUOTA, + ) + raise WorkspaceStreamQuotaError( + f"You have reached the maximum of {WEBRTC_WORKSPACE_STREAM_QUOTA} " + f"concurrent streams." + ) + + if workspace_id: + register_webrtc_session( + workspace_id=workspace_id, + session_id=session_id, + ) + + total_sessions = get_total_concurrent_sessions( + ttl_seconds=WEBRTC_WORKSPACE_STREAM_TTL_SECONDS + ) + logger.info( + "Total concurrent WebRTC sessions: %d", + total_sessions, + ) + + logger.info( + "Started WebRTC session %s for workspace %s", + session_id, + workspace_id, + ) + loop = asyncio.get_event_loop() result = await loop.run_in_executor( None, diff --git a/inference/core/interfaces/webrtc_worker/entities.py b/inference/core/interfaces/webrtc_worker/entities.py index b705c4864c..2285d32a0d 100644 --- a/inference/core/interfaces/webrtc_worker/entities.py +++ b/inference/core/interfaces/webrtc_worker/entities.py @@ -49,6 +49,9 @@ class WebRTCWorkerRequest(BaseModel): # must be valid region: https://modal.com/docs/guide/region-selection#region-options requested_region: Optional[str] = None + workspace_id: Optional[str] = None + session_id: Optional[str] = None + class WebRTCVideoMetadata(BaseModel): frame_id: int @@ -86,6 +89,13 @@ 
class WebRTCWorkerResult(BaseModel): inner_error: Optional[str] = None +class WebRTCSessionHeartbeatRequest(BaseModel): + """Request body for WebRTC session heartbeat and end endpoints.""" + + session_id: str + api_key: str + + class StreamOutputMode(str, Enum): AUTO_DETECT = "auto_detect" # None -> auto-detect first image NO_VIDEO = "no_video" # [] -> no video track diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py index 59e19aa3b7..fa78435c19 100644 --- a/inference/core/interfaces/webrtc_worker/modal.py +++ b/inference/core/interfaces/webrtc_worker/modal.py @@ -50,6 +50,8 @@ WEBRTC_MODAL_TOKEN_SECRET, WEBRTC_MODAL_USAGE_QUOTA_ENABLED, WEBRTC_MODAL_WATCHDOG_TIMEMOUT, + WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS, + WEBRTC_SESSION_HEARTBEAT_URL, WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE, ) from inference.core.exceptions import ( @@ -210,6 +212,12 @@ def check_nvidia_smi_gpu() -> str: "WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION": str( WEBRTC_GZIP_PREVIEW_FRAME_COMPRESSION ), + "WEBRTC_SESSION_HEARTBEAT_URL": ( + WEBRTC_SESSION_HEARTBEAT_URL if WEBRTC_SESSION_HEARTBEAT_URL else "" + ), + "WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS": str( + WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS + ), }, "volumes": {MODEL_CACHE_DIR: rfcache_volume}, } @@ -388,6 +396,9 @@ def send_answer(obj: WebRTCWorkerResult): watchdog = Watchdog( api_key=webrtc_request.api_key, timeout_seconds=WEBRTC_MODAL_WATCHDOG_TIMEMOUT, + workspace_id=getattr(webrtc_request, "workspace_id", None), + session_id=getattr(webrtc_request, "session_id", None), + heartbeat_url=WEBRTC_SESSION_HEARTBEAT_URL, ) try: diff --git a/inference/core/interfaces/webrtc_worker/utils.py b/inference/core/interfaces/webrtc_worker/utils.py index 2c0a4bd4f3..0d4c90d2ba 100644 --- a/inference/core/interfaces/webrtc_worker/utils.py +++ b/inference/core/interfaces/webrtc_worker/utils.py @@ -10,6 +10,8 @@ from av import VideoFrame from inference.core import logger +from 
inference.core.cache import cache +from inference.core.cache.redis import RedisCache from inference.core.env import DEBUG_WEBRTC_PROCESSING_LATENCY from inference.core.interfaces.camera.entities import VideoFrame as InferenceVideoFrame from inference.core.interfaces.stream.inference_pipeline import InferencePipeline @@ -240,6 +242,213 @@ def is_over_quota(api_key: str) -> bool: return is_over_quota +def _get_concurrent_sessions_key(workspace_id: str) -> str: + """Get the Redis key for tracking concurrent sessions for a workspace.""" + return f"webrtc:concurrent_sessions:{workspace_id}" + + +def register_webrtc_session(workspace_id: str, session_id: str) -> None: + """Register a new concurrent WebRTC session for a workspace. + + Adds the session to a Redis sorted set with current timestamp as score. + Expired entries are cleaned up on read via ZREMRANGEBYSCORE (O(log N + M)). + + Args: + workspace_id: The workspace identifier + session_id: Unique identifier for this session + """ + if not isinstance(cache, RedisCache): + logger.warning( + "[REDIS] Redis not available (cache is %s), skipping session registration", + type(cache).__name__, + ) + return + + key = _get_concurrent_sessions_key(workspace_id) + try: + cache.client.zadd(key, {session_id: time.time()}) + cache.client.expire(key, 600) # TTL 600 seconds, extended on each heartbeat + logger.info( + "Registered session: workspace=%s, session=%s", + workspace_id, + session_id, + ) + except Exception as e: + logger.error("Failed to register session: %s", e) + + +def deregister_webrtc_session(workspace_id: str, session_id: str) -> None: + """Remove a WebRTC session from the concurrent sessions set. + + Should be called when a session ends to immediately free the quota slot, + rather than waiting for TTL expiry. 
+ + Args: + workspace_id: The workspace identifier + session_id: The session identifier to remove + """ + if not isinstance(cache, RedisCache): + logger.warning( + "[REDIS] Redis not available (cache is %s), skipping session deregistration", + type(cache).__name__, + ) + return + + key = _get_concurrent_sessions_key(workspace_id) + try: + result = cache.client.zrem(key, session_id) + logger.info( + "Deregistered session: workspace=%s, session=%s, removed=%s", + workspace_id, + session_id, + result, + ) + except Exception as e: + logger.error("Failed to deregister session: %s", e) + + +def refresh_webrtc_session(workspace_id: str, session_id: str) -> bool: + """Refresh the timestamp for a concurrent WebRTC session. + + Should be called periodically to keep the session marked as active. + If not refreshed, the session will be considered expired after TTL. + + Args: + workspace_id: The workspace identifier + session_id: The session identifier to refresh + + Returns: + True if session was refreshed (existed), False otherwise + """ + logger.debug( + "[REDIS] refresh_webrtc_session called: workspace=%s, session=%s, cache_type=%s", + workspace_id, + session_id, + type(cache).__name__, + ) + if not isinstance(cache, RedisCache): + logger.warning( + "[REDIS] Redis not available (cache is %s), cannot refresh session", + type(cache).__name__, + ) + return False + + key = _get_concurrent_sessions_key(workspace_id) + timestamp = time.time() + try: + # Only refresh sessions that already exist: we want to avoid attacks + # where an attacker injects arbitrary session IDs via an authenticated + # heartbeat endpoint + if cache.client.zscore(key, session_id) is None: + logger.warning( + "[REDIS] Session not found: workspace=%s, session=%s", + workspace_id, + session_id, + ) + return False + + cache.client.zadd(key, {session_id: timestamp}) + cache.client.expire(key, 600) # Extend TTL on each heartbeat + logger.info( + "[REDIS] Refreshed session: workspace=%s, session=%s", + 
workspace_id, + session_id, + ) + return True + except Exception as e: + logger.error("[REDIS] Failed to refresh session: %s", e, exc_info=True) + return False + + +def get_concurrent_session_count(workspace_id: str, ttl_seconds: int) -> int: + """Get the count of concurrent sessions for a workspace. + + Cleans up expired entries (older than TTL) before counting. + + Args: + workspace_id: The workspace identifier + ttl_seconds: TTL in seconds - entries older than this are considered expired + + Returns: + Number of concurrent sessions for the workspace + """ + if not isinstance(cache, RedisCache): + logger.warning( + "Redis not available, cannot count concurrent sessions - allowing request" + ) + return 0 + + key = _get_concurrent_sessions_key(workspace_id) + cutoff = time.time() - ttl_seconds + + try: + # Step 1: we remove expired entries + removed = cache.client.zremrangebyscore(key, "-inf", cutoff) + logger.info("[REDIS] Removed %s expired entries from %s", removed, key) + # Step 2: we return what is still valid + count = cache.client.zcard(key) + return count + except Exception as e: + logger.error( + "[REDIS] Failed to get concurrent session count: %s", e, exc_info=True + ) + return 0 + + +def is_over_workspace_session_quota( + workspace_id: str, quota: int, ttl_seconds: int +) -> bool: + """Check if a workspace has exceeded its concurrent session quota. + + Args: + workspace_id: The workspace identifier + quota: Maximum number of concurrent sessions allowed + ttl_seconds: TTL for considering sessions as active + + Returns: + True if the workspace has reached or exceeded the quota + """ + count = get_concurrent_session_count(workspace_id, ttl_seconds) + logger.info( + "Workspace %s has %d concurrent sessions (quota: %d)", + workspace_id, + count, + quota, + ) + return count >= quota + + +def get_total_concurrent_sessions(ttl_seconds: int) -> int: + """Get total concurrent WebRTC sessions across all workspaces. 
+ + Args: + ttl_seconds: TTL in seconds - entries older than this are considered expired + + Returns: + Total number of active sessions + """ + if not isinstance(cache, RedisCache): + logger.warning( + "[REDIS] Redis not available, cannot count total concurrent sessions" + ) + return 0 + + pattern = "webrtc:concurrent_sessions:*" + cutoff = time.time() - ttl_seconds + total = 0 + + try: + for key in cache.client.scan_iter(match=pattern): + cache.client.zremrangebyscore(key, "-inf", cutoff) + total += cache.client.zcard(key) + return total + except Exception as e: + logger.error( + "[REDIS] Failed to get total concurrent sessions: %s", e, exc_info=True + ) + return 0 + + def get_video_fps(filepath: str) -> Optional[float]: """Detect video FPS from container metadata. diff --git a/inference/core/interfaces/webrtc_worker/watchdog.py b/inference/core/interfaces/webrtc_worker/watchdog.py index 1c276024a3..90d082810c 100644 --- a/inference/core/interfaces/webrtc_worker/watchdog.py +++ b/inference/core/interfaces/webrtc_worker/watchdog.py @@ -3,7 +3,12 @@ import time from typing import Callable, Optional -from inference.core.env import WEBRTC_MODAL_USAGE_QUOTA_ENABLED +import requests + +from inference.core.env import ( + WEBRTC_MODAL_USAGE_QUOTA_ENABLED, + WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS, +) from inference.core.interfaces.webrtc_worker.utils import is_over_quota from inference.core.logger import logger @@ -14,6 +19,9 @@ def __init__( api_key: str, timeout_seconds: int, on_timeout: Optional[Callable[[], None]] = None, + workspace_id: Optional[str] = None, + session_id: Optional[str] = None, + heartbeat_url: Optional[str] = None, ): self._api_key = api_key self.timeout_seconds = timeout_seconds @@ -25,6 +33,10 @@ def __init__( self._log_interval_seconds = 10 self._heartbeats = 0 self._total_heartbeats = 0 + self._workspace_id = workspace_id + self._session_id = session_id + self._heartbeat_url = heartbeat_url + self._last_session_heartbeat_ts = 
datetime.datetime.now() @property def total_heartbeats(self) -> int: @@ -42,9 +54,86 @@ def stop(self): self._stopping = True if self._thread.is_alive(): self._thread.join() + self._send_session_heartbeat_stop() + + def _send_session_heartbeat(self): + """Send heartbeat to keep the session alive in the quota system. + + This is used to signal that the session is alive so the system + doesn't allow more than N concurrent sessions from a single workspace. + """ + if not all( + [ + self._heartbeat_url, + self._workspace_id, + self._session_id, + ] + ): + logger.info( + "Skipping session heartbeat: url=%s, workspace=%s, session=%s", + bool(self._heartbeat_url), + bool(self._workspace_id), + bool(self._session_id), + ) + return + + try: + response = requests.post( + self._heartbeat_url, + json={ + "session_id": self._session_id, + "api_key": self._api_key, + }, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + if response.status_code == 200: + logger.info( + "Session heartbeat sent for workspace=%s session=%s", + self._workspace_id, + self._session_id, + ) + else: + logger.warning( + "Failed to send session heartbeat: %s", response.status_code + ) + except Exception as e: + logger.warning("Error sending session heartbeat: %s", e) + + def _send_session_heartbeat_stop(self): + """Send session end to immediately free the quota slot.""" + if not all([self._heartbeat_url, self._session_id]): + return + + url = self._heartbeat_url + "/end" + try: + response = requests.post( + url, + json={ + "session_id": self._session_id, + "api_key": self._api_key, + }, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + if response.status_code == 200: + logger.info( + "Session ended for workspace=%s session=%s", + self._workspace_id, + self._session_id, + ) + else: + logger.warning("Failed to send session end: %s", response.status_code) + except Exception as e: + logger.warning("Error sending session end: %s", e) def _watchdog_thread(self): 
logger.info("Watchdog thread started") + + # Send first heartbeat immediately to prevent session expiry if we have a cold start + self._send_session_heartbeat() + self._last_session_heartbeat_ts = datetime.datetime.now() + while not self._stopping: if not self.is_alive(): logger.error( @@ -62,6 +151,12 @@ def _watchdog_thread(self): message=f"API key over quota, heartbeats: {self._total_heartbeats}" ) break + + if ( + datetime.datetime.now() - self._last_session_heartbeat_ts + ).total_seconds() > WEBRTC_SESSION_HEARTBEAT_INTERVAL_SECONDS: + self._send_session_heartbeat() + self._last_session_heartbeat_ts = datetime.datetime.now() time.sleep(1) logger.info("Watchdog thread stopped, heartbeats: %s", self._total_heartbeats) diff --git a/inference/core/interfaces/webrtc_worker/webrtc.py b/inference/core/interfaces/webrtc_worker/webrtc.py index 978fbffbc5..44cb6a1ab7 100644 --- a/inference/core/interfaces/webrtc_worker/webrtc.py +++ b/inference/core/interfaces/webrtc_worker/webrtc.py @@ -937,8 +937,8 @@ async def init_rtc_peer_connection_with_loop( KeyError, NotImplementedError, ) as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=error.__class__.__name__, @@ -947,8 +947,8 @@ async def init_rtc_peer_connection_with_loop( ) return except WebRTCConfigurationError as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=error.__class__.__name__, @@ -957,8 +957,8 @@ async def init_rtc_peer_connection_with_loop( ) return except RoboflowAPINotAuthorizedError: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=RoboflowAPINotAuthorizedError.__name__, @@ -967,8 +967,8 @@ async def init_rtc_peer_connection_with_loop( ) return except 
RoboflowAPINotNotFoundError: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=RoboflowAPINotNotFoundError.__name__, @@ -977,8 +977,8 @@ async def init_rtc_peer_connection_with_loop( ) return except WorkflowSyntaxError as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=WorkflowSyntaxError.__name__, @@ -989,8 +989,8 @@ async def init_rtc_peer_connection_with_loop( ) return except WorkflowError as error: - # heartbeat to indicate caller error - heartbeat_callback() + if heartbeat_callback: + heartbeat_callback() send_answer( WebRTCWorkerResult( exception_type=WorkflowError.__name__,