diff --git a/inference/core/interfaces/webrtc_worker/__init__.py b/inference/core/interfaces/webrtc_worker/__init__.py
index 6529d1db56..f405c50f4f 100644
--- a/inference/core/interfaces/webrtc_worker/__init__.py
+++ b/inference/core/interfaces/webrtc_worker/__init__.py
@@ -1,5 +1,5 @@
 import asyncio
-import multiprocessing
+import threading
 import uuid
 
 from inference.core.env import (
@@ -11,7 +11,6 @@
     WEBRTC_WORKSPACE_STREAM_TTL_SECONDS,
 )
 from inference.core.exceptions import CreditsExceededError, WorkspaceStreamQuotaError
-from inference.core.interfaces.webrtc_worker.cpu import rtc_peer_connection_process
 from inference.core.interfaces.webrtc_worker.entities import (
     RTCIceServer,
     WebRTCConfig,
@@ -105,23 +104,41 @@ async def start_worker(
         )
         return result
     else:
-        ctx = multiprocessing.get_context("spawn")
-        parent_conn, child_conn = ctx.Pipe(duplex=False)
-        p = ctx.Process(
-            target=rtc_peer_connection_process,
-            kwargs={
-                "webrtc_request": webrtc_request,
-                "answer_conn": child_conn,
-            },
-            daemon=False,
+        from inference.core.interfaces.webrtc_worker.webrtc import (
+            init_rtc_peer_connection_with_loop,
         )
-        p.start()
-        child_conn.close()
-        loop = asyncio.get_running_loop()
-        answer = WebRTCWorkerResult.model_validate(
-            await loop.run_in_executor(None, parent_conn.recv)
+        main_loop = asyncio.get_running_loop()
+        answer_future: asyncio.Future[WebRTCWorkerResult] = main_loop.create_future()
+        worker_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
+
+        def _send_answer(obj: WebRTCWorkerResult) -> None:
+            main_loop.call_soon_threadsafe(answer_future.set_result, obj)
+
+        def _run_worker() -> None:
+            try:
+                worker_loop.run_until_complete(
+                    init_rtc_peer_connection_with_loop(
+                        webrtc_request=webrtc_request,
+                        send_answer=_send_answer,
+                        asyncio_loop=worker_loop,
+                    )
+                )
+            except Exception as exc:
+                logger.exception("WebRTC worker thread crashed: %s", exc)
+                try:
+                    _send_answer(
+                        WebRTCWorkerResult(
+                            exception_type=exc.__class__.__name__,
+                            error_message=str(exc),
+                        )
+                    )
+                except Exception:
+                    pass
+
+        worker_thread = threading.Thread(
+            target=_run_worker, daemon=True, name="webrtc-worker"
         )
-        parent_conn.close()
+        worker_thread.start()
 
-    return answer
+    return await answer_future
 
diff --git a/inference/core/interfaces/webrtc_worker/cpu.py b/inference/core/interfaces/webrtc_worker/cpu.py
index a9bdc45e6c..62b434eb27 100644
--- a/inference/core/interfaces/webrtc_worker/cpu.py
+++ b/inference/core/interfaces/webrtc_worker/cpu.py
@@ -19,10 +19,22 @@ def send_answer(obj: WebRTCWorkerResult):
         answer_conn.send(obj)
         answer_conn.close()
 
-    asyncio.run(
-        init_rtc_peer_connection_with_loop(
-            webrtc_request=webrtc_request,
-            send_answer=send_answer,
+    try:
+        asyncio.run(
+            init_rtc_peer_connection_with_loop(
+                webrtc_request=webrtc_request,
+                send_answer=send_answer,
+            )
         )
-    )
-    logger.info("WebRTC process terminated")
+        logger.info("WebRTC process terminated")
+    except Exception as exc:
+        logger.exception("WebRTC worker process crashed: %s", exc)
+        try:
+            send_answer(
+                WebRTCWorkerResult(
+                    exception_type=exc.__class__.__name__,
+                    error_message=str(exc),
+                )
+            )
+        except Exception:
+            pass