38 commits
81d2595
fix typing
hleehlee-amazon Feb 5, 2026
dedc550
Add automatic upstream sync for main and upstream_push branches
hleehlee-amazon Feb 5, 2026
58fa410
Add fork sync runbook
hleehlee-amazon Feb 5, 2026
a986f82
Merge branch 'THUDM:main' into main
nanjiangwill Feb 10, 2026
2cc03fb
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 13, 2026
d38d574
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 14, 2026
865ad35
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 15, 2026
3425a66
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 20, 2026
3f2cdc9
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 21, 2026
197288a
add workflow write permission
hleehlee-amazon Feb 24, 2026
6ed5945
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 24, 2026
9098406
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 26, 2026
08ae012
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 27, 2026
bb1c117
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Feb 28, 2026
d1aa446
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 1, 2026
56c0da3
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 3, 2026
06330cd
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 4, 2026
3048207
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 5, 2026
d984817
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 8, 2026
93a3351
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 9, 2026
2fe25f9
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 10, 2026
a79bfb8
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 12, 2026
4a07071
Merge branch 'THUDM:main' into main
nanjiangwill Mar 12, 2026
1462a26
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 14, 2026
1fa22b4
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 18, 2026
d8c2047
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 20, 2026
8292c76
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 21, 2026
4e264a3
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 22, 2026
9e593d5
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 23, 2026
f2d7479
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 24, 2026
e1d56d9
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 26, 2026
2084c09
Merge remote-tracking branch 'upstream/main'
github-actions[bot] Mar 27, 2026
5eff986
[slime] feat: add lightweight hook system for observability
andrija-s Mar 27, 2026
54e9214
chore: remove fork-specific files from hook proposal
andrija-s Mar 27, 2026
5e3511d
docs: add HOOKS.md explaining hook system and OTel example
andrija-s Mar 27, 2026
d538529
[slime] feat: add NODE_INIT hook for per-node metrics collection
andrija-s Mar 27, 2026
6973270
Revert "[slime] feat: add NODE_INIT hook for per-node metrics collect…
andrija-s Mar 27, 2026
94e8bd4
Merge branch 'main' into hook-proposal
andrija-s Mar 27, 2026
154 changes: 154 additions & 0 deletions HOOKS.md
@@ -0,0 +1,154 @@
# Training Loop Hooks

Lightweight hook system for observability and extensibility of the SLIME training loop.

**Issue:** https://github.com/THUDM/slime/issues/1728

## Problem

Downstream consumers of SLIME (custom telemetry, logging, profiling) need to wrap
key training loop operations without modifying internal function signatures or
monkey-patching. Changes to `train.py` and `train_async.py` should be minimal and
conflict-resistant when syncing with upstream.

## Design

`slime/hooks.py` provides:

- **`Op` enum** -- names every hookable point in the training loop
- **`hook(op, rollout_id)`** -- context manager that fires pre/post callbacks
- **`on_pre(op, fn)` / `on_post(op, fn)`** -- callback registration

When no callbacks are registered, `hook()` is a near-zero-cost no-op.

### Hooked operations

```
for rollout_id in range(...):
    ITERATION
    |-- EVAL                    # pre-train eval (rollout_id == 0)
    |-- GENERATE                # ray.get(rollout_manager.generate.remote())
    |-- OFFLOAD_ROLLOUT         # ray.get(rollout_manager.offload.remote())
    |-- TRAIN                   # ray.get(actor_model.async_train())
    |-- SAVE_MODEL              # actor_model.save_model()
    |-- OFFLOAD_TRAIN           # actor_model.offload() / clear_memory()
    |-- ONLOAD_ROLLOUT_WEIGHTS
    |-- UPDATE_WEIGHTS          # actor_model.update_weights()
    |-- ONLOAD_ROLLOUT_KV
    |-- EVAL                    # periodic eval
    +-- ASYNC_ROLLOUT_SYNC      # train_async.py only
```

Every `hook()` call takes `rollout_id` as its only positional argument after `op`.
Callbacks receive it as the keyword argument `rollout_id`, along with any additional
keyword arguments passed to `hook()` for specialized hooks.
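
As a sketch of the callback side of this contract: callbacks are invoked with
keyword arguments only, so a callback that names `rollout_id` explicitly and
absorbs everything else in `**kwargs` stays compatible if a call site later
forwards extra attributes. The `rank` keyword and the `describe_call` and
`log_generate` helpers below are hypothetical, chosen only for illustration.

```python
def describe_call(rollout_id, **kwargs):
    """Format the keyword arguments a callback was invoked with."""
    extras = ", ".join(f"{k}={v!r}" for k, v in sorted(kwargs.items()))
    return f"rollout_id={rollout_id}" + (f", {extras}" if extras else "")


def log_generate(rollout_id, **kwargs):
    """Pre callback: print what the hook forwarded (returns None, injects nothing)."""
    print("generate:", describe_call(rollout_id, **kwargs))


# A call site such as hook(Op.GENERATE, 3, rank=0) would invoke a registered
# pre callback roughly like this (keyword arguments only):
log_generate(rollout_id=3, rank=0)
```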

### Call-site example (train.py)

```python
from slime.hooks import Op, hook

with hook(Op.GENERATE, rollout_id):
    rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
```

## Example: OpenTelemetry tracing

This is how we use hooks in our fork to add OTel span tracing to training runs.
No SLIME source code is modified beyond what's in this PR -- all OTel logic lives
in our downstream `autonomy/` package.

### Registering callbacks

```python
# autonomy/telemetry/instrumentation.py

from opentelemetry import context as context_api
from opentelemetry import trace

from slime.hooks import Op, on_pre, on_post


def register_otel_hooks():
    """Register OTel span callbacks for all hook operations."""
    for op in Op:
        span_name = op.value

        # Bind span_name as a default so each callback keeps its own op name;
        # a plain closure would see only the last value assigned in the loop.
        def on_start(_span_name=span_name, **kwargs):
            tracer = trace.get_tracer(__name__)
            attrs = {}
            rollout_id = kwargs.get("rollout_id")
            if rollout_id is not None:
                attrs["rollout.id"] = rollout_id

            span = tracer.start_span(_span_name, attributes=attrs)
            token = context_api.attach(trace.set_span_in_context(span))
            return {"_otel_span": span, "_otel_token": token}

        def on_end(error=None, **kwargs):
            span = kwargs.get("_otel_span")
            token = kwargs.get("_otel_token")
            if span is not None:
                if error is not None:
                    span.set_status(trace.StatusCode.ERROR, str(error))
                    span.record_exception(error)
                span.end()
            if token is not None:
                context_api.detach(token)

        on_pre(op, on_start)
        on_post(op, on_end)
```

### Activating in the entry point

```python
# Called once before training starts (e.g. in your launcher entry point)
register_otel_hooks()
```

### What this produces

With `SLIME_OTEL_ENABLED=1`, each training iteration produces a trace like:

```
iteration (rollout_id=0) 102.90s
|-- eval 24.82s
|-- generate 15.64s
|-- train 53.49s
|-- update_weights 2.20s
|-- offload_train 0.01s
+-- onload_rollout_weights 0.10s
```

Spans nest automatically: `generate` and `train` are children of `iteration`
because `on_start` attaches each new span to the current OTel context, and spans
started later in the same execution context pick it up as their parent.

### Pre-callback state injection

Pre callbacks can return a dict to inject state into post callbacks. This is how
the OTel example passes the span handle from `on_start` to `on_end`:

```python
def on_start(**kwargs):
    span = tracer.start_span(...)
    return {"_otel_span": span}  # injected into on_end's **kwargs

def on_end(**kwargs):
    span = kwargs["_otel_span"]  # received from on_start
    span.end()
```
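
The same mechanism works for any per-call state, not only span handles. A
minimal sketch, assuming a wall-clock timer is all that is needed: the
`durations` dict, the `_t0` key, and the helper names are hypothetical, and the
import of `slime.hooks` is deferred so the callbacks can be exercised without
SLIME installed.

```python
import time

durations = {}  # (op name, rollout_id) -> seconds; hypothetical in-memory sink


def start_timer(**kwargs):
    """Pre callback: stash the start time for the matching post callback."""
    return {"_t0": time.perf_counter()}


def make_stop_timer(name):
    """Build a post callback that records elapsed time under ``name``."""

    def stop_timer(rollout_id, _t0=None, **kwargs):
        if _t0 is not None:  # missing if the pre callback failed or never ran
            durations[(name, rollout_id)] = time.perf_counter() - _t0

    return stop_timer


def register_timing_hooks():
    """Attach the timers to every Op (requires slime to be importable)."""
    from slime.hooks import Op, on_pre, on_post

    for op in Op:
        on_pre(op, start_timer)
        on_post(op, make_stop_timer(op.value))
```

Building the post callback in a factory gives each operation its own `name`,
which avoids sharing one loop variable across all registered callbacks.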

### Error handling

Post callbacks always fire, even if the wrapped operation raises. The exception
is passed as `error`:

```python
def on_end(error=None, **kwargs):
    span = kwargs["_otel_span"]  # injected by the pre callback
    if error is not None:
        # operation failed
        span.set_status(trace.StatusCode.ERROR, str(error))
    span.end()
```

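Callback exceptions are caught and logged with a traceback rather than propagated,
so a failing hook cannot break training. Only exceptions outside the `Exception`
hierarchy, such as `KeyboardInterrupt`, pass through.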
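
As a small illustration of the `error` argument outside of tracing, a post
callback can tally failures by exception type. This is a sketch only; the
`failures` counter and `count_failures` name are hypothetical.

```python
from collections import Counter

failures = Counter()  # hypothetical in-memory tally, keyed by exception type name


def count_failures(error=None, **kwargs):
    """Post callback: tally the exception type when the wrapped body raised."""
    if error is not None:
        failures[type(error).__name__] += 1


# hook() would call a registered post callback with keyword arguments like these:
count_failures(rollout_id=0, error=None)                 # success: nothing recorded
count_failures(rollout_id=1, error=RuntimeError("OOM"))  # failure: one RuntimeError
```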
184 changes: 184 additions & 0 deletions slime/hooks.py
@@ -0,0 +1,184 @@
"""Lightweight hook system for observability and extensibility.

Provides a context-manager ``hook(op, rollout_id)`` that fires registered
pre/post callbacks around any wrapped operation. The ``Op`` enum names
every hookable point; callbacks are registered with ``on_pre`` / ``on_post``.

Pre callbacks may return a ``dict`` whose entries are merged into
subsequent pre and post callbacks as extra ``**kwargs``. The context
manager yields the merged injected dict so that the call-site body can
read values set by pre callbacks.

When no callbacks are registered for an ``Op``, the context manager is
a near-zero-cost no-op.

See: https://github.com/THUDM/slime/issues/1728

Hooked operations in the training loop (train.py / train_async.py)::

    for rollout_id in range(...):
        ITERATION
        ├── EVAL                    # pre-train eval (rollout_id == 0)
        ├── GENERATE                # ray.get(rollout_manager.generate.remote())
        ├── OFFLOAD_ROLLOUT         # ray.get(rollout_manager.offload.remote())
        ├── TRAIN                   # ray.get(actor_model.async_train())
        ├── SAVE_MODEL              # actor_model.save_model() + critic_model.save_model()
        ├── OFFLOAD_TRAIN           # actor_model.offload() / clear_memory()
        ├── ONLOAD_ROLLOUT_WEIGHTS
        ├── UPDATE_WEIGHTS          # actor_model.update_weights()
        ├── ONLOAD_ROLLOUT_KV
        ├── EVAL                    # periodic eval
        └── ASYNC_ROLLOUT_SYNC      # train_async.py only: sync before weight update

Example — call-site in train.py::

    from slime.hooks import Op, hook

    with hook(Op.GENERATE, rollout_id):
        rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
"""

from __future__ import annotations

import logging
from collections.abc import Callable, Generator
from contextlib import contextmanager
from enum import Enum
from typing import Any

logger = logging.getLogger(__name__)


class Op(Enum):
    """Hookable operations.

    Each member names a point in the training pipeline where pre/post
    callbacks can be registered. The string value is used as the span name
    in OTel traces and in log messages.
    """

    # head-node operations (train.py / train_async.py)
    ITERATION = "iteration"
    GENERATE = "generate"
    TRAIN = "train"
    UPDATE_WEIGHTS = "update_weights"
    SAVE_MODEL = "save_model"
    EVAL = "eval"
    OFFLOAD_TRAIN = "offload_train"
    OFFLOAD_ROLLOUT = "offload_rollout"
    ONLOAD_ROLLOUT_WEIGHTS = "onload_rollout_weights"
    ONLOAD_ROLLOUT_KV = "onload_rollout_kv"
    ASYNC_ROLLOUT_SYNC = "async_rollout_sync"



_pre: dict[Op, list[Callable[..., Any]]] = {}
_post: dict[Op, list[Callable[..., Any]]] = {}


def on_pre(op: Op, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register a pre-operation callback.

    The callback receives ``rollout_id`` plus any additional kwargs passed
    to ``hook()``, plus state injected by earlier pre callbacks as
    ``**kwargs``. May return a ``dict`` to inject state into subsequent
    pre and post callbacks.

    Args:
        op: The operation to attach to.
        fn: Callback function.

    Returns:
        ``fn`` unchanged (can be used as a decorator).
    """
    _pre.setdefault(op, []).append(fn)
    logger.debug("Registered pre %s -> %s", op.value, getattr(fn, "__name__", repr(fn)))
    return fn


def on_post(op: Op, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Register a post-operation callback.

    The callback receives ``rollout_id`` plus any additional kwargs, plus
    all injected state, plus ``error: BaseException | None`` (``None`` on
    success). Called after the operation completes, even on exception.

    Args:
        op: The operation to attach to.
        fn: Callback function.

    Returns:
        ``fn`` unchanged (can be used as a decorator).
    """
    _post.setdefault(op, []).append(fn)
    logger.debug("Registered post %s -> %s", op.value, getattr(fn, "__name__", repr(fn)))
    return fn


@contextmanager
def hook(op: Op, rollout_id: int, **kwargs: Any) -> Generator[dict[str, Any], None, None]:
    """Wrap an operation with registered pre/post callbacks.

    Yields a dict of state injected by pre callbacks. Post callbacks fire
    in the ``finally`` block, so they run even if the body raises.

    When no callbacks are registered for ``op``, this is a near-zero-cost
    no-op that yields an empty dict.

    Args:
        op: The operation being performed.
        rollout_id: Current training iteration index.
        **kwargs: Additional attributes forwarded to callbacks (e.g. rank, gpu).

    Yields:
        Dict of state injected by pre callbacks (empty if none registered).
    """
    pre_fns = _pre.get(op, ())
    post_fns = _post.get(op, ())
    if not pre_fns and not post_fns:
        yield {}
        return

    call_kwargs: dict[str, Any] = {"rollout_id": rollout_id, **kwargs}
    injected: dict[str, Any] = {}
    for fn in list(pre_fns):
        try:
            result = fn(**call_kwargs, **injected)
            if isinstance(result, dict):
                injected.update(result)
        except Exception:
            logger.warning("Pre %s callback %s raised", op.value, getattr(fn, "__name__", repr(fn)), exc_info=True)
    try:
        yield injected
    except BaseException as exc:
        injected["error"] = exc
        raise
    finally:
        injected.setdefault("error", None)
        for fn in list(post_fns):
            try:
                fn(**call_kwargs, **injected)
            except Exception:
                logger.warning(
                    "Post %s callback %s raised", op.value, getattr(fn, "__name__", repr(fn)), exc_info=True
                )


def has_callbacks(op: Op) -> bool:
    """True if any pre or post callbacks are registered for ``op``."""
    return bool(_pre.get(op)) or bool(_post.get(op))


def clear(op: Op | None = None) -> None:
    """Remove callbacks.

    Args:
        op: If given, clear only this operation's callbacks.
            If ``None``, clear all callbacks for all operations.
    """
    if op is None:
        _pre.clear()
        _post.clear()
    else:
        _pre.pop(op, None)
        _post.pop(op, None)