diff --git a/docs/guides/environments.md b/docs/guides/environments.md index 64421f7cb6..987c6ffd10 100644 --- a/docs/guides/environments.md +++ b/docs/guides/environments.md @@ -100,6 +100,8 @@ from nemo_rl.environments.code_environment import CodeEnvironment env_config = { "num_workers": 2, "terminate_on_evaluation": True, # Terminate after code execution + "default_timeout_seconds": 1.0, # Optional default wall-clock timeout per step + "default_memory_limit_bytes": 268435456, # Optional default memory limit per step } code_env = CodeEnvironment.remote(env_config) @@ -108,6 +110,32 @@ code_env = CodeEnvironment.remote(env_config) ### Configuration - `num_workers`: Number of parallel workers for code execution - `terminate_on_evaluation`: Whether to terminate after code execution (True for single-turn, False for multi-turn). +- `default_timeout_seconds`: Optional default wall-clock timeout for each code execution step. +- `default_memory_limit_bytes`: Optional default virtual-memory cap for each code execution step. + +Timeouts are enforced in the child process, with a small parent-side startup grace so worker bootstrapping does not immediately trip very small limits. + +### Per-sample execution limits + +The code environment reads limits from `extra_env_info`, so you can set them per sample: + +```python +{ + "context": {}, + "working_dir": "/tmp/code-env/sample-123", + "timeout_seconds": 1.0, + "memory_limit_bytes": 268435456, +} +``` + +If both per-sample limits and environment defaults are provided, the per-sample values win. + +`context` and `working_dir` still behave the same as before: + +- `context` stores variables and functions defined by prior turns. +- `working_dir` bounds file access for `open()`. + +The memory limit uses `resource.RLIMIT_AS` and is currently supported on Linux only. On other platforms, configuring `memory_limit_bytes` returns an error observation instead of enforcing the limit. We are tracking an end-to-end example of this environment in [#858](https://github.com/NVIDIA-NeMo/RL/issues/858). Add a 👍 to show your interest. diff --git a/nemo_rl/environments/code_environment.py b/nemo_rl/environments/code_environment.py index a1047d081b..ca582ca76a 100644 --- a/nemo_rl/environments/code_environment.py +++ b/nemo_rl/environments/code_environment.py @@ -13,16 +13,23 @@ # limitations under the License. import ast import builtins +import multiprocessing as mp import os import re +import signal +import sys +import time from collections.abc import Mapping, Sequence from contextlib import contextmanager from copy import copy +from dataclasses import dataclass, replace from io import IOBase +from multiprocessing.connection import Connection from pprint import pformat from types import ModuleType -from typing import Any, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Dict, List, NamedTuple, NotRequired, Optional, Tuple, TypedDict +import cloudpickle import ray import torch @@ -32,6 +39,14 @@ from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn from nemo_rl.environments.utils import chunk_list_to_workers +try: + import resource +except ImportError: # pragma: no cover + resource = None + + +SUBPROCESS_STARTUP_GRACE_SECONDS = 2.0 + class CodeEnvConfig(TypedDict): num_workers: int @@ -39,152 +54,449 @@ class CodeEnvConfig(TypedDict): # if you want to execute multiple rounds of code, set this to False # and wrap CodeEnvironment in another environment that terminates the generation terminate_on_evaluation: bool + default_timeout_seconds: NotRequired[float | None] + default_memory_limit_bytes: NotRequired[int | None] class CodeEnvMetadata(TypedDict): context: Dict[str, Any] # Hold functions and variables defined in the code working_dir: str # Working directory for file operations + timeout_seconds: NotRequired[float | None] + memory_limit_bytes: NotRequired[int | None] + + +class ExecutionLimits(NamedTuple): + timeout_seconds: float | None + memory_limit_bytes: int | None + + +@dataclass +class CodeExecutionRequest: + code: str + context: dict[str, Any] + working_dir: str + lookahead: str | None = None + timeout_seconds: float | None = None + memory_limit_bytes: int | None = None + + +@dataclass +class CodeExecutionResponse: + formatted_result: str + terminated: bool + context: dict[str, Any] + + +def sanitize_object(obj: Any) -> Any: + """Sanitize objects that are not safe to return through Ray.""" + if isinstance(obj, (IOBase, ModuleType)): + return repr(obj) + if isinstance(obj, Mapping): + return obj.__class__( + {sanitize_object(k): sanitize_object(v) for k, v in obj.items()} + ) + if isinstance(obj, Sequence) and not isinstance(obj, str): + return obj.__class__(sanitize_object(v) for v in obj) + if hasattr(obj, "__dict__"): + new_obj = copy(obj) + new_obj.__dict__ = { + sanitize_object(k): sanitize_object(v) for k, v in obj.__dict__.items() + } + return new_obj + return obj + + +def format_result( + result: Any, code: str | None = None, lookahead: str | None = None +) -> str: + """Format a code execution result as an environment observation.""" + if result is None: + return "" + + result = pformat(result) + multiline = (code and "\n" in code) or "\n" in result + if multiline: + formatted_result = f"\n\n\n{result}\n" + else: + formatted_result = f"{result}" + + if lookahead and formatted_result.startswith(lookahead): + # The generation may look like "\n" if ">\n" is a single token. + # We trim \n from the result if the model has already generated it. + formatted_result = formatted_result[len(lookahead) :] + + return formatted_result + + +def _validate_timeout_seconds(timeout_seconds: float | None) -> float | None: + if timeout_seconds is None: + return None + if isinstance(timeout_seconds, bool) or not isinstance( + timeout_seconds, (int, float) + ): + raise TypeError( + "timeout_seconds must be a positive number or None, " + f"got {type(timeout_seconds)}" + ) + + timeout_value = float(timeout_seconds) + if timeout_value <= 0: + raise ValueError("timeout_seconds must be greater than 0") + + return timeout_value + + +def _validate_memory_limit_bytes(memory_limit_bytes: int | None) -> int | None: + if memory_limit_bytes is None: + return None + if isinstance(memory_limit_bytes, bool) or not isinstance(memory_limit_bytes, int): + raise TypeError( + "memory_limit_bytes must be a positive integer number of bytes or None" + ) + if memory_limit_bytes <= 0: + raise ValueError("memory_limit_bytes must be greater than 0") + + return memory_limit_bytes + + +def _resolve_execution_limits( + metadata: CodeEnvMetadata, + *, + default_timeout_seconds: float | None, + default_memory_limit_bytes: int | None, +) -> ExecutionLimits: + timeout_seconds = _validate_timeout_seconds( + metadata.get("timeout_seconds", default_timeout_seconds) + ) + memory_limit_bytes = _validate_memory_limit_bytes( + metadata.get("memory_limit_bytes", default_memory_limit_bytes) + ) + return ExecutionLimits( + timeout_seconds=timeout_seconds, + memory_limit_bytes=memory_limit_bytes, + ) + + +def _supports_memory_limit() -> bool: + return ( + resource is not None + and hasattr(resource, "RLIMIT_AS") + and sys.platform.startswith("linux") + ) + + +def _supports_signal_timeout() -> bool: + return hasattr(signal, "SIGALRM") and hasattr(signal, "setitimer") + + +def _timeout_error(timeout_seconds: float) -> TimeoutError: + return TimeoutError( + "Code execution exceeded the configured timeout " + f"of {timeout_seconds} seconds" + ) + + +@contextmanager +def _execution_timeout(timeout_seconds: float | None): + if timeout_seconds is None or not _supports_signal_timeout(): + yield + return + + previous_handler = signal.getsignal(signal.SIGALRM) + + def _handle_timeout(signum, frame): + raise _timeout_error(timeout_seconds) + + signal.signal(signal.SIGALRM, _handle_timeout) + signal.setitimer(signal.ITIMER_REAL, timeout_seconds) + try: + yield + finally: + signal.setitimer(signal.ITIMER_REAL, 0) + signal.signal(signal.SIGALRM, previous_handler) + + +@contextmanager +def _chdir(dir_path: str): + """Change to a temporary directory for file operations.""" + current_dir = os.getcwd() + os.chdir(dir_path) + try: + yield + finally: + os.chdir(current_dir) + + +def _safe_open(file: str, *args, **kwargs): + """Safe version of open() that only allows access to the current directory.""" + real_file = os.path.realpath(file) + working_dir = os.path.realpath(os.getcwd()) + if os.path.commonpath([real_file, working_dir]) != working_dir: + raise PermissionError( + "Access beyond the temporary working directory is blocked" + ) + return open(file, *args, **kwargs) + + +def _safe_import(name: str, *args, **kwargs): + """Safe version of import that blocks risky modules.""" + risky_modules = { + "os", + "shutil", # erase filesystem + "sys", + "signal", # exit the current program + "socket", # network communication + "subprocess", + "threading", + "multiprocessing", # spawn threads or processes + "builtins", + "importlib", # bypass current blockers + } + if name in risky_modules: + raise PermissionError("Importing system and network modules is blocked") + return builtins.__import__(name, *args, **kwargs) + + +def _create_sandbox() -> dict[str, Any]: + builtin_dict = {k: getattr(builtins, k) for k in dir(builtins)} + builtin_dict["open"] = _safe_open + builtin_dict["__import__"] = _safe_import + return {"__builtins__": builtin_dict} + + +def _apply_memory_limit(memory_limit_bytes: int | None) -> None: + if memory_limit_bytes is None: + return + if not _supports_memory_limit(): + raise RuntimeError( + "memory_limit_bytes is not supported on this platform because " + "resource.RLIMIT_AS is unavailable" + ) + + assert resource is not None # for type checkers + resource.setrlimit( + resource.RLIMIT_AS, + (memory_limit_bytes, memory_limit_bytes), + ) + + +def _execute_code_request(request: CodeExecutionRequest) -> CodeExecutionResponse: + sandbox = _create_sandbox() + result: Any = None + terminated = False + + try: + with _execution_timeout(request.timeout_seconds): + tree = ast.parse(request.code) + if tree.body and isinstance(tree.body[-1], ast.Expr): + exec_code = ast.unparse(tree.body[:-1]) + eval_code = ast.unparse(tree.body[-1]) + else: + exec_code = request.code + eval_code = None + + with _chdir(request.working_dir): + _apply_memory_limit(request.memory_limit_bytes) + exec(exec_code, sandbox, request.context) + if eval_code: + result = eval(eval_code, sandbox, request.context) + terminated = True + except Exception as err: + result = err + + return CodeExecutionResponse( + formatted_result=format_result(result, request.code, request.lookahead), + terminated=terminated, + context=request.context, + ) + + +def _serialize_request( + request: CodeExecutionRequest, +) -> tuple[bytes, CodeExecutionRequest]: + try: + return cloudpickle.dumps(request), request + except Exception: + transport_request = replace(request, context=sanitize_object(request.context)) + return cloudpickle.dumps(transport_request), transport_request + + +def _serialize_response(response: CodeExecutionResponse) -> bytes: + try: + return cloudpickle.dumps(response) + except Exception: + sanitized_response = replace( + response, context=sanitize_object(response.context) + ) + return cloudpickle.dumps(sanitized_response) + + +def _subprocess_main( + send_conn: Connection, request_bytes: bytes +) -> None: # pragma: no cover + try: + request = cloudpickle.loads(request_bytes) + response = _execute_code_request(request) + send_conn.send_bytes(_serialize_response(response)) + finally: + send_conn.close() + + +def _wait_for_subprocess_response( + recv_conn: Connection, + process: mp.Process, + timeout_seconds: float | None, +) -> tuple[CodeExecutionResponse | None, bool]: + wait_timeout_seconds = None + if timeout_seconds is not None: + wait_timeout_seconds = timeout_seconds + SUBPROCESS_STARTUP_GRACE_SECONDS + deadline = ( + None + if wait_timeout_seconds is None + else time.monotonic() + wait_timeout_seconds + ) + + while True: + poll_timeout = 0.1 + if deadline is not None: + remaining = deadline - time.monotonic() + if remaining <= 0: + return None, True + poll_timeout = min(poll_timeout, remaining) + + if recv_conn.poll(poll_timeout): + try: + response = cloudpickle.loads(recv_conn.recv_bytes()) + except EOFError: + return None, False + return response, False + + if not process.is_alive(): + return None, False + + +def _execute_code_in_subprocess( + request: CodeExecutionRequest, +) -> CodeExecutionResponse: + ctx = mp.get_context("spawn") + recv_conn, send_conn = ctx.Pipe(duplex=False) + request_bytes, transport_request = _serialize_request(request) + process = ctx.Process( + target=_subprocess_main, + args=(send_conn, request_bytes), + daemon=True, + ) + + try: + process.start() + send_conn.close() + response, timed_out = _wait_for_subprocess_response( + recv_conn, + process, + transport_request.timeout_seconds, + ) + + if response is not None: + process.join() + return replace(response, context=sanitize_object(response.context)) + + if timed_out and process.is_alive(): + process.kill() + process.join() + return CodeExecutionResponse( + formatted_result=format_result( + _timeout_error(transport_request.timeout_seconds) + ), + terminated=False, + context=sanitize_object(transport_request.context), + ) + + process.join() + + if ( + transport_request.memory_limit_bytes is not None + and process.exitcode is not None + and process.exitcode < 0 + ): + error: Exception = MemoryError( + "Code execution exceeded the configured memory limit " + f"of {transport_request.memory_limit_bytes} bytes" + ) + else: + error = RuntimeError( + "Code execution subprocess exited unexpectedly " + f"with exit code {process.exitcode}" + ) + return CodeExecutionResponse( + formatted_result=format_result(error), + terminated=False, + context=sanitize_object(transport_request.context), + ) + finally: + recv_conn.close() + if process.is_alive(): + process.kill() + process.join() @ray.remote # pragma: no cover class CodeExecutionWorker: """Helper class to process individual code execution steps.""" - def __init__(self): - # Create sandbox with safe builtins - builtin_dict = {k: getattr(builtins, k) for k in dir(builtins)} - builtin_dict["open"] = self.safe_open - builtin_dict["__import__"] = self.safe_import - self.sandbox = {"__builtins__": builtin_dict} - - def sanitize(self, obj: Any) -> Any: - # TODO: better handling of unpicklable objects: custom __getstate__ and __setstate__ - # recursively remove all file objects as they are not picklable by ray - if isinstance(obj, (IOBase, ModuleType)): - # replace unpickable objects with a string representation - return repr(obj) - if isinstance(obj, Mapping): - return obj.__class__( - {self.sanitize(k): self.sanitize(v) for k, v in obj.items()} - ) - if isinstance(obj, Sequence) and not isinstance(obj, str): - return obj.__class__(self.sanitize(v) for v in obj) - if hasattr(obj, "__dict__"): - new_obj = copy(obj) - new_obj.__dict__ = { - self.sanitize(k): self.sanitize(v) for k, v in obj.__dict__.items() - } - return new_obj - return obj - - def format_result( - self, result: Any, code: Optional[str] = None, lookahead: Optional[str] = None - ) -> str: - if result is None: - # no return value - return "" - result = pformat(result) - multiline = (code and "\n" in code) or "\n" in result - if multiline: - # multi-line format - result = f"\n\n\n{result}\n" - else: - # inline format - result = f"{result}" - if lookahead: - if result.startswith(lookahead): - # The generation may look like "\n" if ">\n" is a single token. - # We trim \n from the result if the model has already generated it. - result = result[len(lookahead) :] - return result + def __init__( + self, + *, + default_timeout_seconds: float | None, + default_memory_limit_bytes: int | None, + ): + self.default_timeout_seconds = default_timeout_seconds + self.default_memory_limit_bytes = default_memory_limit_bytes def execute( - self, message_batch: str, metadata_batch: List[CodeEnvMetadata] + self, message_batch: list[str], metadata_batch: List[CodeEnvMetadata] ) -> Tuple[List[Dict[str, str]], List[bool], List[Any]]: """Execute code in a sandboxed environment.""" results = [] terminateds = [] + updated_metadata_batch: list[CodeEnvMetadata] = [] for message, metadata in zip(message_batch, metadata_batch): match = re.search(r"(.*)(.*)", message, re.DOTALL) if not match: results.append("") terminateds.append(False) + updated_metadata_batch.append(metadata) continue code, lookahead = match.groups() - tree = ast.parse(code) - - if tree.body and isinstance(tree.body[-1], ast.Expr): - # Interactive mode - exec_code = ast.unparse(tree.body[:-1]) - eval_code = ast.unparse(tree.body[-1]) - else: - # Silent mode - exec_code = code - eval_code = None + execution_limits = _resolve_execution_limits( + metadata, + default_timeout_seconds=self.default_timeout_seconds, + default_memory_limit_bytes=self.default_memory_limit_bytes, + ) + response = _execute_code_in_subprocess( + CodeExecutionRequest( + code=code, + context=metadata["context"], + working_dir=metadata["working_dir"], + lookahead=lookahead, + timeout_seconds=execution_limits.timeout_seconds, + memory_limit_bytes=execution_limits.memory_limit_bytes, + ) + ) + updated_metadata = dict(metadata) + updated_metadata["context"] = response.context - result = None - terminated = False - with self.chdir(metadata["working_dir"]): - try: - # isolate the code in a sandbox - # capture local variables in metadata["context"] - exec(exec_code, self.sandbox, metadata["context"]) - if eval_code: - result = eval(eval_code, self.sandbox, metadata["context"]) - terminated = True - except Exception as err: - result = err - - result = self.format_result(result, code, lookahead) - results.append(result) - terminateds.append(terminated) + results.append(response.formatted_result) + terminateds.append(response.terminated) + updated_metadata_batch.append(updated_metadata) observations = [ {"role": "environment", "content": result} for result in results ] - metadata_batch = self.sanitize(metadata_batch) - - return observations, terminateds, metadata_batch - - @contextmanager - def chdir(self, dir: str): - """Change to temporary directory for file operations.""" - current_dir = os.getcwd() - os.chdir(dir) - try: - yield - finally: - os.chdir(current_dir) - - def safe_open(self, file: str, *args, **kwargs): - """Safe version of open() that only allows access to temporary directory.""" - real_file = os.path.realpath(file) - working_dir = os.path.realpath(os.getcwd()) - if os.path.commonpath([real_file, working_dir]) != working_dir: - raise PermissionError( - "Access beyond the temporary working directory is blocked" - ) - return open(file, *args, **kwargs) - - def safe_import(self, name: str, *args, **kwargs): - """Safe version of import that blocks risky modules.""" - risky_modules = { - "os", - "shutil", # erase filesystem - "sys", - "signal", # exit the current program - "socket", # network communication - "subprocess", - "threading", - "multiprocessing", # spawn threads or processes - "builtins", - "importlib", # bypass current blockers - } - if name in risky_modules: - raise PermissionError("Importing system and network modules is blocked") - return builtins.__import__(name, *args, **kwargs) + updated_metadata_batch = sanitize_object(updated_metadata_batch) + + return observations, terminateds, updated_metadata_batch @ray.remote # pragma: no cover @@ -195,10 +507,19 @@ def __init__(self, cfg: CodeEnvConfig): self.cfg = cfg self.num_workers = cfg["num_workers"] self.terminate_on_evaluation = cfg["terminate_on_evaluation"] + self.default_timeout_seconds = _validate_timeout_seconds( + cfg.get("default_timeout_seconds") + ) + self.default_memory_limit_bytes = _validate_memory_limit_bytes( + cfg.get("default_memory_limit_bytes") + ) self.workers = [ CodeExecutionWorker.options( runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM} - ).remote() + ).remote( + default_timeout_seconds=self.default_timeout_seconds, + default_memory_limit_bytes=self.default_memory_limit_bytes, + ) for _ in range(self.num_workers) ] @@ -213,7 +534,6 @@ def step( chunked_message_batch = chunk_list_to_workers(message_batch, self.num_workers) chunked_metadata_batch = chunk_list_to_workers(metadata_batch, self.num_workers) - # Process each chunk in parallel futures = [ self.workers[i].execute.remote(message_chunk, metadata_chunk) for i, (message_chunk, metadata_chunk) in enumerate( @@ -223,7 +543,6 @@ def step( results = ray.get(futures) - # Unpack results observations = [] terminateds = [] new_metadata_batch = [] @@ -242,7 +561,8 @@ def step( next_stop_strings = [[""]] * len(message_log_batch) assert return_extracted_answer == False, ( - "return_extracted_answer is not supported in CodeEnvironment. Please set it to False." + "return_extracted_answer is not supported in CodeEnvironment. " + "Please set it to False." ) extracted_answers = None @@ -256,7 +576,6 @@ def step( ) def shutdown(self): - # shutdown all workers for worker in self.workers: ray.kill(worker) @@ -264,5 +583,4 @@ def global_post_process_and_metrics( self, batch: BatchedDataDict ) -> Tuple[BatchedDataDict, dict]: """Compute metrics for the batch.""" - # No specific metrics for code execution return batch, {} diff --git a/tests/unit/environments/test_code_environment.py b/tests/unit/environments/test_code_environment.py index d32550aba1..bc6cdaa2eb 100644 --- a/tests/unit/environments/test_code_environment.py +++ b/tests/unit/environments/test_code_environment.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import ExitStack from tempfile import TemporaryDirectory import pytest import ray from transformers import AutoTokenizer -from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.environments.code_environment import ( CodeEnvConfig, CodeEnvironment, CodeEnvMetadata, + _resolve_execution_limits, + _supports_memory_limit, ) -from nemo_rl.experience.rollouts import run_multi_turn_rollout -from nemo_rl.models.generation import configure_generation_config -from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration MODEL_NAME = "meta-llama/Llama-3.2-1B" @@ -36,43 +34,41 @@ "terminate_on_evaluation": True, } -# Define basic vLLM test config -basic_vllm_test_config: VllmConfig = { - "backend": "vllm", - "model_name": MODEL_NAME, - "tokenizer_name": None, - "dtype": "bfloat16", - "max_new_tokens": 100, - "temperature": 1.0, - "top_p": 1.0, - "top_k": None, - "stop_token_ids": None, - "stop_strings": None, - "vllm_cfg": { - "async_engine": False, - "precision": "bfloat16", - "tensor_parallel_size": 1, - "pipeline_parallel_size": 1, - "expert_parallel_size": 1, - "max_model_len": 1024, - "disable_log_stats": True, - "disable_log_requests": True, - "gpu_memory_utilization": 0.6, - "enforce_eager": "False", - }, - "colocated": { - "enabled": True, - "resources": { - "gpus_per_node": None, - "num_nodes": None, - }, - }, -} + +def _make_metadata( + working_dir: str, + *, + timeout_seconds: float | None = None, + memory_limit_bytes: int | None = None, +) -> CodeEnvMetadata: + metadata: CodeEnvMetadata = { + "context": {}, + "working_dir": working_dir, + } + if timeout_seconds is not None: + metadata["timeout_seconds"] = timeout_seconds + if memory_limit_bytes is not None: + metadata["memory_limit_bytes"] = memory_limit_bytes + return metadata + + +def _step_code( + env_actor, + code: str, + metadata: CodeEnvMetadata, +): + return ray.get( + env_actor.step.remote( + [[{"role": "user", "content": f"{code}"}]], + [metadata], + ) + ) @pytest.fixture(scope="function") def code_env(): """Create a code environment for testing.""" + env_actor = None try: env_actor = CodeEnvironment.remote(cfg) yield env_actor @@ -89,7 +85,9 @@ def tokenizer(): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print( - f"Tokenizer loaded. Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})" + "Tokenizer loaded. " + f"Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), " + f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})" ) return tokenizer @@ -97,6 +95,8 @@ def tokenizer(): @pytest.fixture(scope="function") def cluster(): """Create a virtual cluster for testing.""" + from nemo_rl.distributed.virtual_cluster import RayVirtualCluster + cluster_instance = None cluster_name = f"test-code-cluster-{id(cluster_instance)}" print(f"\nCreating virtual cluster '{cluster_name}'...") @@ -115,6 +115,40 @@ def cluster(): cluster_instance.shutdown() +def test_resolve_execution_limits_uses_metadata_overrides(): + metadata = CodeEnvMetadata( + context={}, + working_dir="/tmp/code-env", + timeout_seconds=2.0, + memory_limit_bytes=2048, + ) + + limits = _resolve_execution_limits( + metadata, + default_timeout_seconds=1.0, + default_memory_limit_bytes=1024, + ) + + assert limits.timeout_seconds == 2.0 + assert limits.memory_limit_bytes == 2048 + + +def test_resolve_execution_limits_uses_defaults_when_metadata_missing(): + metadata = CodeEnvMetadata( + context={}, + working_dir="/tmp/code-env", + ) + + limits = _resolve_execution_limits( + metadata, + default_timeout_seconds=1.5, + default_memory_limit_bytes=4096, + ) + + assert limits.timeout_seconds == 1.5 + assert limits.memory_limit_bytes == 4096 + + def test_untrusted_code(code_env): """Test whether the code environment can block untrusted code.""" codes = [ @@ -129,46 +163,170 @@ def test_untrusted_code(code_env): ] results = [ "\n\n\n'some content'\n", - "\n\n\nPermissionError('Access beyond the temporary working directory is blocked')\n", + ( + "\n\n\n" + "PermissionError(" + "'Access beyond the temporary working directory is blocked'" + ")\n" + "" + ), "\n\n\n3\n", - "PermissionError('Importing system and network modules is blocked')", + ( + "" + "PermissionError('Importing system and network modules is blocked')" + "" + ), ] - message_log_batch = [ - [{"role": "user", "content": f"{code}"}] for code in codes - ] - temp_dirs = [TemporaryDirectory() for _ in codes] - metadata_batch = [ - CodeEnvMetadata( - context={}, - working_dir=temp_dir.name, - ) - for temp_dir in temp_dirs - ] + with ExitStack() as stack: + temp_dirs = [stack.enter_context(TemporaryDirectory()) for _ in codes] + metadata_batch = [ + _make_metadata(temp_dir) + for temp_dir in temp_dirs + ] - # Execute the code - output = ray.get(code_env.step.remote(message_log_batch, metadata_batch)) - responses = [obs["content"] for obs in output.observations] + output = ray.get( + code_env.step.remote( + [ + [{"role": "user", "content": f"{code}"}] + for code in codes + ], + metadata_batch, + ) + ) + responses = [obs["content"] for obs in output.observations] assert responses == results, f"Got wrong output {responses}" +def test_syntax_error_returns_observation_and_actor_stays_healthy(code_env): + with TemporaryDirectory() as temp_dir: + metadata = _make_metadata(temp_dir) + output = _step_code(code_env, "def broken(:\n pass", metadata) + + assert "SyntaxError" in output.observations[0]["content"] + follow_up = _step_code(code_env, "1 + 1", output.metadata[0]) + assert follow_up.observations[0]["content"] == "2" + + +def test_timeout_returns_observation_and_actor_stays_healthy(code_env): + with TemporaryDirectory() as temp_dir: + metadata = _make_metadata(temp_dir, timeout_seconds=0.5) + output = _step_code(code_env, "while True:\n pass", metadata) + + assert "TimeoutError" in output.observations[0]["content"] + follow_up = _step_code(code_env, "40 + 2", output.metadata[0]) + assert follow_up.observations[0]["content"] == "42" + + +@pytest.mark.skipif( + not _supports_memory_limit(), + reason="Memory limits require resource.RLIMIT_AS support", +) +def test_memory_limit_returns_observation_and_actor_stays_healthy(code_env): + with TemporaryDirectory() as temp_dir: + metadata = _make_metadata( + temp_dir, + memory_limit_bytes=256 * 1024 * 1024, + ) + output = _step_code( + code_env, + "blob = bytearray(1024 * 1024 * 1024)\nlen(blob)", + metadata, + ) + + assert "MemoryError" in output.observations[0]["content"] + follow_up = _step_code(code_env, "6 * 7", output.metadata[0]) + assert follow_up.observations[0]["content"] == "42" + + +def test_default_timeout_applies_when_metadata_omits_limit(): + env_actor = None + try: + env_actor = CodeEnvironment.remote( + CodeEnvConfig( + num_workers=1, + terminate_on_evaluation=True, + default_timeout_seconds=0.5, + ) + ) + with TemporaryDirectory() as temp_dir: + metadata = _make_metadata(temp_dir) + output = _step_code(env_actor, "while True:\n pass", metadata) + + assert "TimeoutError" in output.observations[0]["content"] + finally: + if env_actor: + ray.kill(env_actor) + + +def test_multiturn_context_survives_subprocess_boundary(code_env): + with TemporaryDirectory() as temp_dir: + metadata = _make_metadata(temp_dir) + first_output = _step_code( + code_env, + "def square(x):\n return x * x\nvalue = 7", + metadata, + ) + second_output = _step_code( + code_env, + "square(value)", + first_output.metadata[0], + ) + + assert second_output.observations[0]["content"] == "49" + + @pytest.mark.hf_gated def test_vllm_execute_code(cluster, tokenizer, code_env): """Test that vLLM can call the code executor.""" - # Prepare test data + from nemo_rl.distributed.batched_data_dict import BatchedDataDict + from nemo_rl.experience.rollouts import run_multi_turn_rollout + from nemo_rl.models.generation import configure_generation_config + from nemo_rl.models.generation.vllm import VllmGeneration + + basic_vllm_test_config = { + "backend": "vllm", + "model_name": MODEL_NAME, + "tokenizer_name": None, + "dtype": "bfloat16", + "max_new_tokens": 100, + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "stop_token_ids": None, + "stop_strings": None, + "vllm_cfg": { + "async_engine": False, + "precision": "bfloat16", + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "expert_parallel_size": 1, + "max_model_len": 1024, + "disable_log_stats": True, + "disable_log_requests": True, + "gpu_memory_utilization": 0.6, + "enforce_eager": "False", + }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, + } + codes = [ "x = 3; y = 4\nThis is some regular text.\nx + y\n", "\ndef f(x):\n return x * x\n\nf(2)\n\n", ] results = ["7", "\n\n4\n"] - # Create message logs message_logs = [] metadata_batch = [] temp_dirs = [] for code in codes: - # Tokenize the message content prompt = code * 4 token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)[ "input_ids" @@ -177,10 +335,9 @@ def test_vllm_execute_code(cluster, tokenizer, code_env): message_logs.append( [{"role": "user", "content": prompt, "token_ids": token_ids}] ) - metadata_batch.append(CodeEnvMetadata(context={}, working_dir=temp_dir.name)) + metadata_batch.append(_make_metadata(temp_dir.name)) temp_dirs.append(temp_dir) - # Create initial batch initial_batch = BatchedDataDict( { "message_log": message_logs, @@ -190,30 +347,29 @@ def test_vllm_execute_code(cluster, tokenizer, code_env): } ) - # Create vLLM generation vllm_config = basic_vllm_test_config.copy() vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) vllm_generation = VllmGeneration(cluster, vllm_config) - # Create code environment task_to_env = {"code_execution": code_env} - # Run rollout - vllm_generation.prepare_for_generation() - final_batch, _ = run_multi_turn_rollout( - policy_generation=vllm_generation, - input_batch=initial_batch, - tokenizer=tokenizer, - task_to_env=task_to_env, - max_seq_len=256, - max_rollout_turns=2, - greedy=True, - ) - vllm_generation.finish_generation() + try: + vllm_generation.prepare_for_generation() + final_batch, _ = run_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=initial_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=256, + max_rollout_turns=2, + greedy=True, + ) + finally: + vllm_generation.finish_generation() + for temp_dir in temp_dirs: + temp_dir.cleanup() - # Check results for i, msg_log in enumerate(final_batch["message_log"]): - # Get the last message which should contain the result last_msg = msg_log[-1] assert last_msg["role"] == "environment" assert last_msg["content"] == results[i], (