From 4107156450693ebf83188d581a5a9eaf1cfc56f4 Mon Sep 17 00:00:00 2001
From: taivu1998 <46636857+taivu1998@users.noreply.github.com>
Date: Fri, 3 Apr 2026 10:49:48 -0700
Subject: [PATCH] Add CodeEnvironment execution limits
---
docs/guides/environments.md | 28 +
nemo_rl/environments/code_environment.py | 562 ++++++++++++++----
.../environments/test_code_environment.py | 306 +++++++---
3 files changed, 699 insertions(+), 197 deletions(-)
diff --git a/docs/guides/environments.md b/docs/guides/environments.md
index 64421f7cb6..987c6ffd10 100644
--- a/docs/guides/environments.md
+++ b/docs/guides/environments.md
@@ -100,6 +100,8 @@ from nemo_rl.environments.code_environment import CodeEnvironment
env_config = {
"num_workers": 2,
"terminate_on_evaluation": True, # Terminate after code execution
+ "default_timeout_seconds": 1.0, # Optional default wall-clock timeout per step
+ "default_memory_limit_bytes": 268435456, # Optional default memory limit per step
}
code_env = CodeEnvironment.remote(env_config)
@@ -108,6 +110,32 @@ code_env = CodeEnvironment.remote(env_config)
### Configuration
- `num_workers`: Number of parallel workers for code execution
- `terminate_on_evaluation`: Whether to terminate after code execution (True for single-turn, False for multi-turn).
+- `default_timeout_seconds`: Optional default wall-clock timeout for each code execution step.
+- `default_memory_limit_bytes`: Optional default virtual-memory cap for each code execution step.
+
+Timeouts are enforced in the child process, with a small parent-side startup grace so worker bootstrapping does not immediately trip very small limits.
+
+### Per-sample execution limits
+
+The code environment reads limits from `extra_env_info`, so you can set them per sample:
+
+```python
+{
+ "context": {},
+ "working_dir": "/tmp/code-env/sample-123",
+ "timeout_seconds": 1.0,
+ "memory_limit_bytes": 268435456,
+}
+```
+
+If both per-sample limits and environment defaults are provided, the per-sample values win.
+
+`context` and `working_dir` still behave the same as before:
+
+- `context` stores variables and functions defined by prior turns.
+- `working_dir` bounds file access for `open()`.
+
+The memory limit uses `resource.RLIMIT_AS` and is currently supported on Linux only. On other platforms, configuring `memory_limit_bytes` returns an error observation instead of enforcing the limit.
We are tracking an end-to-end example of this environment in [#858](https://github.com/NVIDIA-NeMo/RL/issues/858). Add a 👍 to show your interest.
diff --git a/nemo_rl/environments/code_environment.py b/nemo_rl/environments/code_environment.py
index a1047d081b..ca582ca76a 100644
--- a/nemo_rl/environments/code_environment.py
+++ b/nemo_rl/environments/code_environment.py
@@ -13,16 +13,23 @@
# limitations under the License.
import ast
import builtins
+import multiprocessing as mp
import os
import re
+import signal
+import sys
+import time
from collections.abc import Mapping, Sequence
from contextlib import contextmanager
from copy import copy
+from dataclasses import dataclass, replace
from io import IOBase
+from multiprocessing.connection import Connection
from pprint import pformat
from types import ModuleType
-from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Dict, List, NamedTuple, NotRequired, Optional, Tuple, TypedDict
+import cloudpickle
import ray
import torch
@@ -32,6 +39,14 @@
from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn
from nemo_rl.environments.utils import chunk_list_to_workers
+try:
+ import resource
+except ImportError: # pragma: no cover
+ resource = None
+
+
+SUBPROCESS_STARTUP_GRACE_SECONDS = 2.0
+
class CodeEnvConfig(TypedDict):
num_workers: int
@@ -39,152 +54,449 @@ class CodeEnvConfig(TypedDict):
# if you want to execute multiple rounds of code, set this to False
# and wrap CodeEnvironment in another environment that terminates the generation
terminate_on_evaluation: bool
+ default_timeout_seconds: NotRequired[float | None]
+ default_memory_limit_bytes: NotRequired[int | None]
class CodeEnvMetadata(TypedDict):
context: Dict[str, Any] # Hold functions and variables defined in the code
working_dir: str # Working directory for file operations
+ timeout_seconds: NotRequired[float | None]
+ memory_limit_bytes: NotRequired[int | None]
+
+
+class ExecutionLimits(NamedTuple):
+ timeout_seconds: float | None
+ memory_limit_bytes: int | None
+
+
+@dataclass
+class CodeExecutionRequest:
+ code: str
+ context: dict[str, Any]
+ working_dir: str
+ lookahead: str | None = None
+ timeout_seconds: float | None = None
+ memory_limit_bytes: int | None = None
+
+
+@dataclass
+class CodeExecutionResponse:
+ formatted_result: str
+ terminated: bool
+ context: dict[str, Any]
+
+
+def sanitize_object(obj: Any) -> Any:
+ """Sanitize objects that are not safe to return through Ray."""
+ if isinstance(obj, (IOBase, ModuleType)):
+ return repr(obj)
+ if isinstance(obj, Mapping):
+ return obj.__class__(
+ {sanitize_object(k): sanitize_object(v) for k, v in obj.items()}
+ )
+ if isinstance(obj, Sequence) and not isinstance(obj, str):
+ return obj.__class__(sanitize_object(v) for v in obj)
+ if hasattr(obj, "__dict__"):
+ new_obj = copy(obj)
+ new_obj.__dict__ = {
+ sanitize_object(k): sanitize_object(v) for k, v in obj.__dict__.items()
+ }
+ return new_obj
+ return obj
+
+
+def format_result(
+ result: Any, code: str | None = None, lookahead: str | None = None
+) -> str:
+ """Format a code execution result as an environment observation."""
+ if result is None:
+ return ""
+
+ result = pformat(result)
+ multiline = (code and "\n" in code) or "\n" in result
+ if multiline:
+ formatted_result = f"</code>\n\n<output>\n{result}\n</output>"
+ else:
+ formatted_result = f"</code><output>{result}</output>"
+
+ if lookahead and formatted_result.startswith(lookahead):
+ # The generation may look like "\n" if ">\n" is a single token.
+ # We trim \n from the result if the model has already generated it.
+ formatted_result = formatted_result[len(lookahead) :]
+
+ return formatted_result
+
+
+def _validate_timeout_seconds(timeout_seconds: float | None) -> float | None:
+ if timeout_seconds is None:
+ return None
+ if isinstance(timeout_seconds, bool) or not isinstance(
+ timeout_seconds, (int, float)
+ ):
+ raise TypeError(
+ "timeout_seconds must be a positive number or None, "
+ f"got {type(timeout_seconds)}"
+ )
+
+ timeout_value = float(timeout_seconds)
+ if timeout_value <= 0:
+ raise ValueError("timeout_seconds must be greater than 0")
+
+ return timeout_value
+
+
+def _validate_memory_limit_bytes(memory_limit_bytes: int | None) -> int | None:
+ if memory_limit_bytes is None:
+ return None
+ if isinstance(memory_limit_bytes, bool) or not isinstance(memory_limit_bytes, int):
+ raise TypeError(
+ "memory_limit_bytes must be a positive integer number of bytes or None"
+ )
+ if memory_limit_bytes <= 0:
+ raise ValueError("memory_limit_bytes must be greater than 0")
+
+ return memory_limit_bytes
+
+
+def _resolve_execution_limits(
+ metadata: CodeEnvMetadata,
+ *,
+ default_timeout_seconds: float | None,
+ default_memory_limit_bytes: int | None,
+) -> ExecutionLimits:
+ timeout_seconds = _validate_timeout_seconds(
+ metadata.get("timeout_seconds", default_timeout_seconds)
+ )
+ memory_limit_bytes = _validate_memory_limit_bytes(
+ metadata.get("memory_limit_bytes", default_memory_limit_bytes)
+ )
+ return ExecutionLimits(
+ timeout_seconds=timeout_seconds,
+ memory_limit_bytes=memory_limit_bytes,
+ )
+
+
+def _supports_memory_limit() -> bool:
+ return (
+ resource is not None
+ and hasattr(resource, "RLIMIT_AS")
+ and sys.platform.startswith("linux")
+ )
+
+
+def _supports_signal_timeout() -> bool:
+ return hasattr(signal, "SIGALRM") and hasattr(signal, "setitimer")
+
+
+def _timeout_error(timeout_seconds: float) -> TimeoutError:
+ return TimeoutError(
+ "Code execution exceeded the configured timeout "
+ f"of {timeout_seconds} seconds"
+ )
+
+
+@contextmanager
+def _execution_timeout(timeout_seconds: float | None):
+ if timeout_seconds is None or not _supports_signal_timeout():
+ yield
+ return
+
+ previous_handler = signal.getsignal(signal.SIGALRM)
+
+ def _handle_timeout(signum, frame):
+ raise _timeout_error(timeout_seconds)
+
+ signal.signal(signal.SIGALRM, _handle_timeout)
+ signal.setitimer(signal.ITIMER_REAL, timeout_seconds)
+ try:
+ yield
+ finally:
+ signal.setitimer(signal.ITIMER_REAL, 0)
+ signal.signal(signal.SIGALRM, previous_handler)
+
+
+@contextmanager
+def _chdir(dir_path: str):
+ """Change to a temporary directory for file operations."""
+ current_dir = os.getcwd()
+ os.chdir(dir_path)
+ try:
+ yield
+ finally:
+ os.chdir(current_dir)
+
+
+def _safe_open(file: str, *args, **kwargs):
+ """Safe version of open() that only allows access to the current directory."""
+ real_file = os.path.realpath(file)
+ working_dir = os.path.realpath(os.getcwd())
+ if os.path.commonpath([real_file, working_dir]) != working_dir:
+ raise PermissionError(
+ "Access beyond the temporary working directory is blocked"
+ )
+ return open(file, *args, **kwargs)
+
+
+def _safe_import(name: str, *args, **kwargs):
+ """Safe version of import that blocks risky modules."""
+ risky_modules = {
+ "os",
+ "shutil", # erase filesystem
+ "sys",
+ "signal", # exit the current program
+ "socket", # network communication
+ "subprocess",
+ "threading",
+ "multiprocessing", # spawn threads or processes
+ "builtins",
+ "importlib", # bypass current blockers
+ }
+ if name in risky_modules:
+ raise PermissionError("Importing system and network modules is blocked")
+ return builtins.__import__(name, *args, **kwargs)
+
+
+def _create_sandbox() -> dict[str, Any]:
+ builtin_dict = {k: getattr(builtins, k) for k in dir(builtins)}
+ builtin_dict["open"] = _safe_open
+ builtin_dict["__import__"] = _safe_import
+ return {"__builtins__": builtin_dict}
+
+
+def _apply_memory_limit(memory_limit_bytes: int | None) -> None:
+ if memory_limit_bytes is None:
+ return
+ if not _supports_memory_limit():
+ raise RuntimeError(
+ "memory_limit_bytes is not supported on this platform because "
+ "resource.RLIMIT_AS is unavailable"
+ )
+
+ assert resource is not None # for type checkers
+ resource.setrlimit(
+ resource.RLIMIT_AS,
+ (memory_limit_bytes, memory_limit_bytes),
+ )
+
+
+def _execute_code_request(request: CodeExecutionRequest) -> CodeExecutionResponse:
+ sandbox = _create_sandbox()
+ result: Any = None
+ terminated = False
+
+ try:
+ with _execution_timeout(request.timeout_seconds):
+ tree = ast.parse(request.code)
+ if tree.body and isinstance(tree.body[-1], ast.Expr):
+ exec_code = ast.unparse(tree.body[:-1])
+ eval_code = ast.unparse(tree.body[-1])
+ else:
+ exec_code = request.code
+ eval_code = None
+
+ with _chdir(request.working_dir):
+ _apply_memory_limit(request.memory_limit_bytes)
+ exec(exec_code, sandbox, request.context)
+ if eval_code:
+ result = eval(eval_code, sandbox, request.context)
+ terminated = True
+ except Exception as err:
+ result = err
+
+ return CodeExecutionResponse(
+ formatted_result=format_result(result, request.code, request.lookahead),
+ terminated=terminated,
+ context=request.context,
+ )
+
+
+def _serialize_request(
+ request: CodeExecutionRequest,
+) -> tuple[bytes, CodeExecutionRequest]:
+ try:
+ return cloudpickle.dumps(request), request
+ except Exception:
+ transport_request = replace(request, context=sanitize_object(request.context))
+ return cloudpickle.dumps(transport_request), transport_request
+
+
+def _serialize_response(response: CodeExecutionResponse) -> bytes:
+ try:
+ return cloudpickle.dumps(response)
+ except Exception:
+ sanitized_response = replace(
+ response, context=sanitize_object(response.context)
+ )
+ return cloudpickle.dumps(sanitized_response)
+
+
+def _subprocess_main(
+ send_conn: Connection, request_bytes: bytes
+) -> None: # pragma: no cover
+ try:
+ request = cloudpickle.loads(request_bytes)
+ response = _execute_code_request(request)
+ send_conn.send_bytes(_serialize_response(response))
+ finally:
+ send_conn.close()
+
+
+def _wait_for_subprocess_response(
+ recv_conn: Connection,
+ process: mp.Process,
+ timeout_seconds: float | None,
+) -> tuple[CodeExecutionResponse | None, bool]:
+ wait_timeout_seconds = None
+ if timeout_seconds is not None:
+ wait_timeout_seconds = timeout_seconds + SUBPROCESS_STARTUP_GRACE_SECONDS
+ deadline = (
+ None
+ if wait_timeout_seconds is None
+ else time.monotonic() + wait_timeout_seconds
+ )
+
+ while True:
+ poll_timeout = 0.1
+ if deadline is not None:
+ remaining = deadline - time.monotonic()
+ if remaining <= 0:
+ return None, True
+ poll_timeout = min(poll_timeout, remaining)
+
+ if recv_conn.poll(poll_timeout):
+ try:
+ response = cloudpickle.loads(recv_conn.recv_bytes())
+ except EOFError:
+ return None, False
+ return response, False
+
+ if not process.is_alive():
+ return None, False
+
+
+def _execute_code_in_subprocess(
+ request: CodeExecutionRequest,
+) -> CodeExecutionResponse:
+ ctx = mp.get_context("spawn")
+ recv_conn, send_conn = ctx.Pipe(duplex=False)
+ request_bytes, transport_request = _serialize_request(request)
+ process = ctx.Process(
+ target=_subprocess_main,
+ args=(send_conn, request_bytes),
+ daemon=True,
+ )
+
+ try:
+ process.start()
+ send_conn.close()
+ response, timed_out = _wait_for_subprocess_response(
+ recv_conn,
+ process,
+ transport_request.timeout_seconds,
+ )
+
+ if response is not None:
+ process.join()
+ return replace(response, context=sanitize_object(response.context))
+
+ if timed_out and process.is_alive():
+ process.kill()
+ process.join()
+ return CodeExecutionResponse(
+ formatted_result=format_result(
+ _timeout_error(transport_request.timeout_seconds)
+ ),
+ terminated=False,
+ context=sanitize_object(transport_request.context),
+ )
+
+ process.join()
+
+ if (
+ transport_request.memory_limit_bytes is not None
+ and process.exitcode is not None
+ and process.exitcode < 0
+ ):
+ error: Exception = MemoryError(
+ "Code execution exceeded the configured memory limit "
+ f"of {transport_request.memory_limit_bytes} bytes"
+ )
+ else:
+ error = RuntimeError(
+ "Code execution subprocess exited unexpectedly "
+ f"with exit code {process.exitcode}"
+ )
+ return CodeExecutionResponse(
+ formatted_result=format_result(error),
+ terminated=False,
+ context=sanitize_object(transport_request.context),
+ )
+ finally:
+ recv_conn.close()
+ if process.is_alive():
+ process.kill()
+ process.join()
@ray.remote # pragma: no cover
class CodeExecutionWorker:
"""Helper class to process individual code execution steps."""
- def __init__(self):
- # Create sandbox with safe builtins
- builtin_dict = {k: getattr(builtins, k) for k in dir(builtins)}
- builtin_dict["open"] = self.safe_open
- builtin_dict["__import__"] = self.safe_import
- self.sandbox = {"__builtins__": builtin_dict}
-
- def sanitize(self, obj: Any) -> Any:
- # TODO: better handling of unpicklable objects: custom __getstate__ and __setstate__
- # recursively remove all file objects as they are not picklable by ray
- if isinstance(obj, (IOBase, ModuleType)):
- # replace unpickable objects with a string representation
- return repr(obj)
- if isinstance(obj, Mapping):
- return obj.__class__(
- {self.sanitize(k): self.sanitize(v) for k, v in obj.items()}
- )
- if isinstance(obj, Sequence) and not isinstance(obj, str):
- return obj.__class__(self.sanitize(v) for v in obj)
- if hasattr(obj, "__dict__"):
- new_obj = copy(obj)
- new_obj.__dict__ = {
- self.sanitize(k): self.sanitize(v) for k, v in obj.__dict__.items()
- }
- return new_obj
- return obj
-
- def format_result(
- self, result: Any, code: Optional[str] = None, lookahead: Optional[str] = None
- ) -> str:
- if result is None:
- # no return value
- return ""
- result = pformat(result)
- multiline = (code and "\n" in code) or "\n" in result
- if multiline:
- # multi-line format
- result = f"\n\n\n{result}\n"
- else:
- # inline format
- result = f"{result}"
- if lookahead:
- if result.startswith(lookahead):
- # The generation may look like "\n" if ">\n" is a single token.
- # We trim \n from the result if the model has already generated it.
- result = result[len(lookahead) :]
- return result
+ def __init__(
+ self,
+ *,
+ default_timeout_seconds: float | None,
+ default_memory_limit_bytes: int | None,
+ ):
+ self.default_timeout_seconds = default_timeout_seconds
+ self.default_memory_limit_bytes = default_memory_limit_bytes
def execute(
- self, message_batch: str, metadata_batch: List[CodeEnvMetadata]
+ self, message_batch: list[str], metadata_batch: List[CodeEnvMetadata]
) -> Tuple[List[Dict[str, str]], List[bool], List[Any]]:
"""Execute code in a sandboxed environment."""
results = []
terminateds = []
+ updated_metadata_batch: list[CodeEnvMetadata] = []
for message, metadata in zip(message_batch, metadata_batch):
match = re.search(r"<code>(.*)</code>(.*)", message, re.DOTALL)
if not match:
results.append("")
terminateds.append(False)
+ updated_metadata_batch.append(metadata)
continue
code, lookahead = match.groups()
- tree = ast.parse(code)
-
- if tree.body and isinstance(tree.body[-1], ast.Expr):
- # Interactive mode
- exec_code = ast.unparse(tree.body[:-1])
- eval_code = ast.unparse(tree.body[-1])
- else:
- # Silent mode
- exec_code = code
- eval_code = None
+ execution_limits = _resolve_execution_limits(
+ metadata,
+ default_timeout_seconds=self.default_timeout_seconds,
+ default_memory_limit_bytes=self.default_memory_limit_bytes,
+ )
+ response = _execute_code_in_subprocess(
+ CodeExecutionRequest(
+ code=code,
+ context=metadata["context"],
+ working_dir=metadata["working_dir"],
+ lookahead=lookahead,
+ timeout_seconds=execution_limits.timeout_seconds,
+ memory_limit_bytes=execution_limits.memory_limit_bytes,
+ )
+ )
+ updated_metadata = dict(metadata)
+ updated_metadata["context"] = response.context
- result = None
- terminated = False
- with self.chdir(metadata["working_dir"]):
- try:
- # isolate the code in a sandbox
- # capture local variables in metadata["context"]
- exec(exec_code, self.sandbox, metadata["context"])
- if eval_code:
- result = eval(eval_code, self.sandbox, metadata["context"])
- terminated = True
- except Exception as err:
- result = err
-
- result = self.format_result(result, code, lookahead)
- results.append(result)
- terminateds.append(terminated)
+ results.append(response.formatted_result)
+ terminateds.append(response.terminated)
+ updated_metadata_batch.append(updated_metadata)
observations = [
{"role": "environment", "content": result} for result in results
]
- metadata_batch = self.sanitize(metadata_batch)
-
- return observations, terminateds, metadata_batch
-
- @contextmanager
- def chdir(self, dir: str):
- """Change to temporary directory for file operations."""
- current_dir = os.getcwd()
- os.chdir(dir)
- try:
- yield
- finally:
- os.chdir(current_dir)
-
- def safe_open(self, file: str, *args, **kwargs):
- """Safe version of open() that only allows access to temporary directory."""
- real_file = os.path.realpath(file)
- working_dir = os.path.realpath(os.getcwd())
- if os.path.commonpath([real_file, working_dir]) != working_dir:
- raise PermissionError(
- "Access beyond the temporary working directory is blocked"
- )
- return open(file, *args, **kwargs)
-
- def safe_import(self, name: str, *args, **kwargs):
- """Safe version of import that blocks risky modules."""
- risky_modules = {
- "os",
- "shutil", # erase filesystem
- "sys",
- "signal", # exit the current program
- "socket", # network communication
- "subprocess",
- "threading",
- "multiprocessing", # spawn threads or processes
- "builtins",
- "importlib", # bypass current blockers
- }
- if name in risky_modules:
- raise PermissionError("Importing system and network modules is blocked")
- return builtins.__import__(name, *args, **kwargs)
+ updated_metadata_batch = sanitize_object(updated_metadata_batch)
+
+ return observations, terminateds, updated_metadata_batch
@ray.remote # pragma: no cover
@@ -195,10 +507,19 @@ def __init__(self, cfg: CodeEnvConfig):
self.cfg = cfg
self.num_workers = cfg["num_workers"]
self.terminate_on_evaluation = cfg["terminate_on_evaluation"]
+ self.default_timeout_seconds = _validate_timeout_seconds(
+ cfg.get("default_timeout_seconds")
+ )
+ self.default_memory_limit_bytes = _validate_memory_limit_bytes(
+ cfg.get("default_memory_limit_bytes")
+ )
self.workers = [
CodeExecutionWorker.options(
runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM}
- ).remote()
+ ).remote(
+ default_timeout_seconds=self.default_timeout_seconds,
+ default_memory_limit_bytes=self.default_memory_limit_bytes,
+ )
for _ in range(self.num_workers)
]
@@ -213,7 +534,6 @@ def step(
chunked_message_batch = chunk_list_to_workers(message_batch, self.num_workers)
chunked_metadata_batch = chunk_list_to_workers(metadata_batch, self.num_workers)
- # Process each chunk in parallel
futures = [
self.workers[i].execute.remote(message_chunk, metadata_chunk)
for i, (message_chunk, metadata_chunk) in enumerate(
@@ -223,7 +543,6 @@ def step(
results = ray.get(futures)
- # Unpack results
observations = []
terminateds = []
new_metadata_batch = []
@@ -242,7 +561,8 @@ def step(
next_stop_strings = [["</code>"]] * len(message_log_batch)
assert return_extracted_answer == False, (
- "return_extracted_answer is not supported in CodeEnvironment. Please set it to False."
+ "return_extracted_answer is not supported in CodeEnvironment. "
+ "Please set it to False."
)
extracted_answers = None
@@ -256,7 +576,6 @@ def step(
)
def shutdown(self):
- # shutdown all workers
for worker in self.workers:
ray.kill(worker)
@@ -264,5 +583,4 @@ def global_post_process_and_metrics(
self, batch: BatchedDataDict
) -> Tuple[BatchedDataDict, dict]:
"""Compute metrics for the batch."""
- # No specific metrics for code execution
return batch, {}
diff --git a/tests/unit/environments/test_code_environment.py b/tests/unit/environments/test_code_environment.py
index d32550aba1..bc6cdaa2eb 100644
--- a/tests/unit/environments/test_code_environment.py
+++ b/tests/unit/environments/test_code_environment.py
@@ -12,22 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from contextlib import ExitStack
from tempfile import TemporaryDirectory
import pytest
import ray
from transformers import AutoTokenizer
-from nemo_rl.distributed.batched_data_dict import BatchedDataDict
-from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.environments.code_environment import (
CodeEnvConfig,
CodeEnvironment,
CodeEnvMetadata,
+ _resolve_execution_limits,
+ _supports_memory_limit,
)
-from nemo_rl.experience.rollouts import run_multi_turn_rollout
-from nemo_rl.models.generation import configure_generation_config
-from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
MODEL_NAME = "meta-llama/Llama-3.2-1B"
@@ -36,43 +34,41 @@
"terminate_on_evaluation": True,
}
-# Define basic vLLM test config
-basic_vllm_test_config: VllmConfig = {
- "backend": "vllm",
- "model_name": MODEL_NAME,
- "tokenizer_name": None,
- "dtype": "bfloat16",
- "max_new_tokens": 100,
- "temperature": 1.0,
- "top_p": 1.0,
- "top_k": None,
- "stop_token_ids": None,
- "stop_strings": None,
- "vllm_cfg": {
- "async_engine": False,
- "precision": "bfloat16",
- "tensor_parallel_size": 1,
- "pipeline_parallel_size": 1,
- "expert_parallel_size": 1,
- "max_model_len": 1024,
- "disable_log_stats": True,
- "disable_log_requests": True,
- "gpu_memory_utilization": 0.6,
- "enforce_eager": "False",
- },
- "colocated": {
- "enabled": True,
- "resources": {
- "gpus_per_node": None,
- "num_nodes": None,
- },
- },
-}
+
+def _make_metadata(
+ working_dir: str,
+ *,
+ timeout_seconds: float | None = None,
+ memory_limit_bytes: int | None = None,
+) -> CodeEnvMetadata:
+ metadata: CodeEnvMetadata = {
+ "context": {},
+ "working_dir": working_dir,
+ }
+ if timeout_seconds is not None:
+ metadata["timeout_seconds"] = timeout_seconds
+ if memory_limit_bytes is not None:
+ metadata["memory_limit_bytes"] = memory_limit_bytes
+ return metadata
+
+
+def _step_code(
+ env_actor,
+ code: str,
+ metadata: CodeEnvMetadata,
+):
+ return ray.get(
+ env_actor.step.remote(
+ [[{"role": "user", "content": f"<code>{code}</code>"}]],
+ [metadata],
+ )
+ )
@pytest.fixture(scope="function")
def code_env():
"""Create a code environment for testing."""
+ env_actor = None
try:
env_actor = CodeEnvironment.remote(cfg)
yield env_actor
@@ -89,7 +85,9 @@ def tokenizer():
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print(
- f"Tokenizer loaded. Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})"
+ "Tokenizer loaded. "
+ f"Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), "
+ f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})"
)
return tokenizer
@@ -97,6 +95,8 @@ def tokenizer():
@pytest.fixture(scope="function")
def cluster():
"""Create a virtual cluster for testing."""
+ from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
+
cluster_instance = None
cluster_name = f"test-code-cluster-{id(cluster_instance)}"
print(f"\nCreating virtual cluster '{cluster_name}'...")
@@ -115,6 +115,40 @@ def cluster():
cluster_instance.shutdown()
+def test_resolve_execution_limits_uses_metadata_overrides():
+ metadata = CodeEnvMetadata(
+ context={},
+ working_dir="/tmp/code-env",
+ timeout_seconds=2.0,
+ memory_limit_bytes=2048,
+ )
+
+ limits = _resolve_execution_limits(
+ metadata,
+ default_timeout_seconds=1.0,
+ default_memory_limit_bytes=1024,
+ )
+
+ assert limits.timeout_seconds == 2.0
+ assert limits.memory_limit_bytes == 2048
+
+
+def test_resolve_execution_limits_uses_defaults_when_metadata_missing():
+ metadata = CodeEnvMetadata(
+ context={},
+ working_dir="/tmp/code-env",
+ )
+
+ limits = _resolve_execution_limits(
+ metadata,
+ default_timeout_seconds=1.5,
+ default_memory_limit_bytes=4096,
+ )
+
+ assert limits.timeout_seconds == 1.5
+ assert limits.memory_limit_bytes == 4096
+
+
def test_untrusted_code(code_env):
"""Test whether the code environment can block untrusted code."""
codes = [
@@ -129,46 +163,170 @@ def test_untrusted_code(code_env):
]
results = [
"</code>\n\n<output>\n'some content'\n</output>",
- "</code>\n\n<output>\nPermissionError('Access beyond the temporary working directory is blocked')\n</output>",
+ (
+ "</code>\n\n<output>\n"
+ "PermissionError("
+ "'Access beyond the temporary working directory is blocked'"
+ ")\n"
+ "</output>"
+ ),
"</code>\n\n<output>\n3\n</output>",
- "</code><output>PermissionError('Importing system and network modules is blocked')</output>",
+ (
+ "</code><output>"
+ "PermissionError('Importing system and network modules is blocked')"
+ "</output>"
+ ),
]
- message_log_batch = [
- [{"role": "user", "content": f"<code>{code}</code>"}] for code in codes
- ]
- temp_dirs = [TemporaryDirectory() for _ in codes]
- metadata_batch = [
- CodeEnvMetadata(
- context={},
- working_dir=temp_dir.name,
- )
- for temp_dir in temp_dirs
- ]
+ with ExitStack() as stack:
+ temp_dirs = [stack.enter_context(TemporaryDirectory()) for _ in codes]
+ metadata_batch = [
+ _make_metadata(temp_dir)
+ for temp_dir in temp_dirs
+ ]
- # Execute the code
- output = ray.get(code_env.step.remote(message_log_batch, metadata_batch))
- responses = [obs["content"] for obs in output.observations]
+ output = ray.get(
+ code_env.step.remote(
+ [
+ [{"role": "user", "content": f"<code>{code}</code>"}]
+ for code in codes
+ ],
+ metadata_batch,
+ )
+ )
+ responses = [obs["content"] for obs in output.observations]
assert responses == results, f"Got wrong output {responses}"
+def test_syntax_error_returns_observation_and_actor_stays_healthy(code_env):
+ with TemporaryDirectory() as temp_dir:
+ metadata = _make_metadata(temp_dir)
+ output = _step_code(code_env, "def broken(:\n pass", metadata)
+
+ assert "SyntaxError" in output.observations[0]["content"]
+ follow_up = _step_code(code_env, "1 + 1", output.metadata[0])
+ assert follow_up.observations[0]["content"] == "</code><output>2</output>"
+
+
+def test_timeout_returns_observation_and_actor_stays_healthy(code_env):
+ with TemporaryDirectory() as temp_dir:
+ metadata = _make_metadata(temp_dir, timeout_seconds=0.5)
+ output = _step_code(code_env, "while True:\n pass", metadata)
+
+ assert "TimeoutError" in output.observations[0]["content"]
+ follow_up = _step_code(code_env, "40 + 2", output.metadata[0])
+ assert follow_up.observations[0]["content"] == "</code><output>42</output>"
+
+
+@pytest.mark.skipif(
+ not _supports_memory_limit(),
+ reason="Memory limits require resource.RLIMIT_AS support",
+)
+def test_memory_limit_returns_observation_and_actor_stays_healthy(code_env):
+ with TemporaryDirectory() as temp_dir:
+ metadata = _make_metadata(
+ temp_dir,
+ memory_limit_bytes=256 * 1024 * 1024,
+ )
+ output = _step_code(
+ code_env,
+ "blob = bytearray(1024 * 1024 * 1024)\nlen(blob)",
+ metadata,
+ )
+
+ assert "MemoryError" in output.observations[0]["content"]
+ follow_up = _step_code(code_env, "6 * 7", output.metadata[0])
+ assert follow_up.observations[0]["content"] == "</code><output>42</output>"
+
+
+def test_default_timeout_applies_when_metadata_omits_limit():
+ env_actor = None
+ try:
+ env_actor = CodeEnvironment.remote(
+ CodeEnvConfig(
+ num_workers=1,
+ terminate_on_evaluation=True,
+ default_timeout_seconds=0.5,
+ )
+ )
+ with TemporaryDirectory() as temp_dir:
+ metadata = _make_metadata(temp_dir)
+ output = _step_code(env_actor, "while True:\n pass", metadata)
+
+ assert "TimeoutError" in output.observations[0]["content"]
+ finally:
+ if env_actor:
+ ray.kill(env_actor)
+
+
+def test_multiturn_context_survives_subprocess_boundary(code_env):
+ with TemporaryDirectory() as temp_dir:
+ metadata = _make_metadata(temp_dir)
+ first_output = _step_code(
+ code_env,
+ "def square(x):\n return x * x\nvalue = 7",
+ metadata,
+ )
+ second_output = _step_code(
+ code_env,
+ "square(value)",
+ first_output.metadata[0],
+ )
+
+ assert second_output.observations[0]["content"] == "</code><output>49</output>"
+
+
@pytest.mark.hf_gated
def test_vllm_execute_code(cluster, tokenizer, code_env):
"""Test that vLLM can call the code executor."""
- # Prepare test data
+ from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+ from nemo_rl.experience.rollouts import run_multi_turn_rollout
+ from nemo_rl.models.generation import configure_generation_config
+ from nemo_rl.models.generation.vllm import VllmGeneration
+
+ basic_vllm_test_config = {
+ "backend": "vllm",
+ "model_name": MODEL_NAME,
+ "tokenizer_name": None,
+ "dtype": "bfloat16",
+ "max_new_tokens": 100,
+ "temperature": 1.0,
+ "top_p": 1.0,
+ "top_k": None,
+ "stop_token_ids": None,
+ "stop_strings": None,
+ "vllm_cfg": {
+ "async_engine": False,
+ "precision": "bfloat16",
+ "tensor_parallel_size": 1,
+ "pipeline_parallel_size": 1,
+ "expert_parallel_size": 1,
+ "max_model_len": 1024,
+ "disable_log_stats": True,
+ "disable_log_requests": True,
+ "gpu_memory_utilization": 0.6,
+ "enforce_eager": "False",
+ },
+ "colocated": {
+ "enabled": True,
+ "resources": {
+ "gpus_per_node": None,
+ "num_nodes": None,
+ },
+ },
+ }
+
codes = [
"x = 3; y = 4\nThis is some regular text.\nx + y\n",
"\ndef f(x):\n return x * x\n\nf(2)\n\n",
]
results = ["7", "\n\n4\n"]
- # Create message logs
message_logs = []
metadata_batch = []
temp_dirs = []
for code in codes:
- # Tokenize the message content
prompt = code * 4
token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)[
"input_ids"
@@ -177,10 +335,9 @@ def test_vllm_execute_code(cluster, tokenizer, code_env):
message_logs.append(
[{"role": "user", "content": prompt, "token_ids": token_ids}]
)
- metadata_batch.append(CodeEnvMetadata(context={}, working_dir=temp_dir.name))
+ metadata_batch.append(_make_metadata(temp_dir.name))
temp_dirs.append(temp_dir)
- # Create initial batch
initial_batch = BatchedDataDict(
{
"message_log": message_logs,
@@ -190,30 +347,29 @@ def test_vllm_execute_code(cluster, tokenizer, code_env):
}
)
- # Create vLLM generation
vllm_config = basic_vllm_test_config.copy()
vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
vllm_generation = VllmGeneration(cluster, vllm_config)
- # Create code environment
task_to_env = {"code_execution": code_env}
- # Run rollout
- vllm_generation.prepare_for_generation()
- final_batch, _ = run_multi_turn_rollout(
- policy_generation=vllm_generation,
- input_batch=initial_batch,
- tokenizer=tokenizer,
- task_to_env=task_to_env,
- max_seq_len=256,
- max_rollout_turns=2,
- greedy=True,
- )
- vllm_generation.finish_generation()
+ try:
+ vllm_generation.prepare_for_generation()
+ final_batch, _ = run_multi_turn_rollout(
+ policy_generation=vllm_generation,
+ input_batch=initial_batch,
+ tokenizer=tokenizer,
+ task_to_env=task_to_env,
+ max_seq_len=256,
+ max_rollout_turns=2,
+ greedy=True,
+ )
+ finally:
+ vllm_generation.finish_generation()
+ for temp_dir in temp_dirs:
+ temp_dir.cleanup()
- # Check results
for i, msg_log in enumerate(final_batch["message_log"]):
- # Get the last message which should contain the result
last_msg = msg_log[-1]
assert last_msg["role"] == "environment"
assert last_msg["content"] == results[i], (