178 changes: 177 additions & 1 deletion slime/backends/megatron_utils/data.py
@@ -17,11 +17,181 @@
from slime.utils.types import RolloutBatch

from ...utils import logging_utils
from .cp_utils import get_sum_of_sample_mean, slice_with_cp
from .cp_utils import all_gather_with_cp, get_sum_of_sample_mean, slice_log_prob_with_cp, slice_with_cp

logger = logging.getLogger(__name__)


def _num_image_tokens_from_grid(grid_thw: torch.Tensor, merge_h: int = 2, merge_w: int = 2) -> int:
_, h, w = grid_thw.tolist()
# tpool_patch_merger averages over the temporal dimension T, so the
# actual number of tokens per image depends only on the spatial grid.
return (h // merge_h) * (w // merge_w)
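
# Worked example (hypothetical grid): grid_thw = (t, 28, 28) with the default
# 2x2 merge yields (28 // 2) * (28 // 2) = 196 tokens per image, independent
# of the temporal extent t.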


def _expand_image_tokens_for_sample(
tokens: torch.Tensor,
loss_mask: torch.Tensor,
grid_thws: torch.Tensor,
media_token_id: int = 163605,
) -> tuple[torch.Tensor, torch.Tensor]:
if grid_thws is None or len(grid_thws) == 0:
return tokens, loss_mask

placeholder_positions = (tokens == media_token_id).nonzero(as_tuple=True)[0]
if len(placeholder_positions) == 0:
return tokens, loss_mask

num_placeholders = len(placeholder_positions)
num_grids = len(grid_thws)
expected_total_image_tokens = sum(_num_image_tokens_from_grid(grid_thw) for grid_thw in grid_thws)
if num_placeholders == expected_total_image_tokens:
# Already pre-expanded. Keep this helper idempotent because the same
# rollout batch may pass through multiple normalization paths.
return tokens, loss_mask
if num_placeholders != num_grids:
logger.warning(
"K25 multimodal token mismatch before training: placeholders=%s, grids=%s",
num_placeholders,
num_grids,
)

merge_h, merge_w = 2, 2
prompt_len = len(tokens) - len(loss_mask)

expanded_tokens = tokens.clone()
expanded_mask = loss_mask.clone()

for i, pos in enumerate(reversed(placeholder_positions)):
pos = pos.item()
grid_idx = num_placeholders - 1 - i
if grid_idx >= num_grids:
continue

_, h, w = grid_thws[grid_idx].tolist()
num_image_tokens = (h // merge_h) * (w // merge_w)

expanded_placeholder = torch.full(
(num_image_tokens,), media_token_id, dtype=expanded_tokens.dtype, device=expanded_tokens.device
)
expanded_tokens = torch.cat([expanded_tokens[:pos], expanded_placeholder, expanded_tokens[pos + 1 :]])

if pos >= prompt_len:
mask_pos = pos - prompt_len
expanded_mask_tokens = torch.zeros(
num_image_tokens, dtype=expanded_mask.dtype, device=expanded_mask.device
)
expanded_mask = torch.cat([expanded_mask[:mask_pos], expanded_mask_tokens, expanded_mask[mask_pos + 1 :]])

return expanded_tokens, expanded_mask
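
# Expansion sketch (hypothetical values): given tokens = [a, b, MEDIA, c] and
# loss_mask = [1, 1] (so prompt_len = 4 - 2 = 2), a grid that expands to 196
# tokens turns the placeholder at position 2 into 196 copies of media_token_id,
# and the mask entry at position 2 - prompt_len = 0 into 196 zeros, keeping the
# image tokens out of the loss.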


def _collect_multimodal_grid_inputs(
multimodal_train_inputs: Sequence[dict[str, torch.Tensor] | None] | None,
) -> list[dict[str, torch.Tensor] | None]:
if multimodal_train_inputs is None:
return []

mm_inputs_list = []
for mm_dict in multimodal_train_inputs:
if mm_dict is not None and "grid_thws" in mm_dict:
mm_inputs_list.append(mm_dict)
else:
mm_inputs_list.append(None)
return mm_inputs_list


def _batch_has_media_placeholders(
tokens: Sequence[torch.Tensor],
media_token_id: int = 163605,
) -> bool:
return any((token_tensor == media_token_id).any().item() for token_tensor in tokens)


def expand_multimodal_rollout_data_in_place(
rollout_data: RolloutBatch,
media_token_id: int = 163605,
qkv_format: str = "thd",
) -> None:
multimodal_train_inputs = rollout_data.get("multimodal_train_inputs", None)
mm_inputs_list = _collect_multimodal_grid_inputs(multimodal_train_inputs)
if not mm_inputs_list or not any(mm is not None for mm in mm_inputs_list):
return

tokens = rollout_data["tokens"]
if not _batch_has_media_placeholders(tokens, media_token_id=media_token_id):
return

loss_masks = rollout_data["loss_masks"]
old_total_lengths = list(rollout_data["total_lengths"])
old_response_lengths = list(rollout_data["response_lengths"])

token_or_mask_changed = False
expanded_tokens = []
expanded_loss_masks = []
expanded_total_lengths = []
expanded_response_lengths = []

for i, (token_tensor, loss_mask_tensor) in enumerate(zip(tokens, loss_masks, strict=False)):
if mm_inputs_list[i] is not None:
new_tokens, new_loss_mask = _expand_image_tokens_for_sample(
token_tensor,
loss_mask_tensor,
mm_inputs_list[i]["grid_thws"],
media_token_id=media_token_id,
)
token_or_mask_changed = token_or_mask_changed or (
(new_tokens.size(0) != token_tensor.size(0)) or (new_loss_mask.size(0) != loss_mask_tensor.size(0))
)
expanded_tokens.append(new_tokens)
expanded_loss_masks.append(new_loss_mask)
expanded_total_lengths.append(new_tokens.size(0))
expanded_response_lengths.append(new_loss_mask.size(0))
else:
expanded_tokens.append(token_tensor)
expanded_loss_masks.append(loss_mask_tensor)
expanded_total_lengths.append(old_total_lengths[i])
expanded_response_lengths.append(old_response_lengths[i])

rollout_data["tokens"] = expanded_tokens
rollout_data["loss_masks"] = expanded_loss_masks
rollout_data["total_lengths"] = expanded_total_lengths
rollout_data["response_lengths"] = expanded_response_lengths

metadata_changed = (expanded_total_lengths != old_total_lengths) or (
expanded_response_lengths != old_response_lengths
)
if metadata_changed:
cp_size = mpu.get_context_parallel_world_size()
if cp_size > 1 and qkv_format == "thd":
for key in ("rollout_log_probs", "teacher_log_probs"):
values = rollout_data.get(key)
if not values:
continue
rollout_data[key] = [
slice_log_prob_with_cp(
all_gather_with_cp(value, old_total_length, old_response_length),
new_total_length,
new_response_length,
qkv_format=qkv_format,
)
for value, old_total_length, old_response_length, new_total_length, new_response_length in zip(
values,
old_total_lengths,
old_response_lengths,
expanded_total_lengths,
expanded_response_lengths,
strict=False,
)
]
logger.info(
"Adjusted multimodal rollout metadata for Kimi VL: "
f"token_or_mask_changed={token_or_mask_changed}, "
f"total_lengths_changed={expanded_total_lengths != old_total_lengths}, "
f"response_lengths_changed={expanded_response_lengths != old_response_lengths}"
)
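
# Sketch of the context-parallel fix-up above (assuming the thd layout): each
# CP rank holds only its slice of the per-token log probs, so once expansion
# changes the sequence lengths, the code first all-gathers each full log-prob
# sequence using the *old* lengths and then re-slices it with the *new*
# lengths, leaving every rank with a slice consistent with the expanded tokens.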


def get_batch(
data_iterator: "DataIterator",
keys: Sequence[str],
@@ -56,6 +226,10 @@ def get_batch(
if "dynamic_global_batch_size" in data_iterator.rollout_data:
batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"]

# Keep a local normalization path here as a no-op safety net in case
# batches reach get_batch without the rollout-level preprocessing step.
expand_multimodal_rollout_data_in_place(batch, qkv_format=qkv_format)

tokens = batch["tokens"]
    # Using 0 as the pad token id should be fine.
pad_token_id = 0
@@ -310,6 +484,8 @@ def get_data_iterator(
- `data_iterators`: list of `DataIterator`, one per VPP stage (size 1 if VPP disabled)
- `num_microbatches`: list[int], one per local step in the rollout (length = steps)
"""
expand_multimodal_rollout_data_in_place(rollout_data, qkv_format=args.qkv_format)

dp_size = mpu.get_data_parallel_world_size(with_context_parallel=False)
dp_group = mpu.get_data_parallel_group()
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
9 changes: 6 additions & 3 deletions slime/rollout/sglang_rollout.py
@@ -21,7 +21,7 @@
from slime.utils.http_utils import get, post
from slime.utils.misc import SingletonMeta, load_function
from slime.utils.processing_utils import (
build_processor_kwargs,
call_processor,
encode_image_for_rollout_engine,
load_processor,
load_tokenizer,
@@ -120,9 +120,12 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]):
), f"Sample status is {sample.status}"

if state.processor and sample.multimodal_inputs and any(v is not None for v in sample.multimodal_inputs.values()):
processor_kwargs = build_processor_kwargs(sample.multimodal_inputs)
processor_output = state.processor(text=sample.prompt, **processor_kwargs)
processor_output = call_processor(state.processor, sample.prompt, sample.multimodal_inputs)
prompt_ids = processor_output["input_ids"][0]
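        # Depending on the processor, input_ids may come back as a
        # torch.Tensor or a plain list; normalize to a Python list either way.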
if hasattr(prompt_ids, "tolist"):
prompt_ids = prompt_ids.tolist()
else:
prompt_ids = list(prompt_ids)
sample.multimodal_train_inputs = {
k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]
} or None
3 changes: 2 additions & 1 deletion slime/utils/data.py
@@ -13,6 +13,7 @@
except ImportError:
pq = None

from slime.utils.processing_utils import call_processor
from slime.utils.types import MultimodalTypes, Sample

from .timer import Timer
@@ -109,7 +110,7 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_length):

for sample in multimodal:
multimodal_inputs = process_vision_info(sample.prompt, processor)
processor_output = processor(text=sample.prompt, **multimodal_inputs)
processor_output = call_processor(processor, sample.prompt, multimodal_inputs)
input_ids = processor_output["input_ids"][0]
if len(input_ids) <= max_length:
filtered_samples.append(sample)
25 changes: 25 additions & 0 deletions slime/utils/processing_utils.py
@@ -1,4 +1,5 @@
import base64
import inspect
import io
import json
import logging
@@ -36,6 +37,30 @@ def build_processor_kwargs(multimodal_inputs: dict | None = None) -> dict:
return result


def processor_requires_medias(processor) -> bool:
try:
params = inspect.signature(processor.__call__).parameters
return "medias" in params and "text" in params
except (TypeError, ValueError):
return hasattr(processor, "media_processor")


def call_processor(processor, text, multimodal_inputs: dict | None = None):
multimodal_inputs = multimodal_inputs or {}

    # Kimi-VL / Kimi-2.5 style processors take a single `medias` list instead
    # of separate image/video kwargs.
if processor_requires_medias(processor):
medias = []
if images := multimodal_inputs.get("images"):
medias.extend({"type": "image", "image": image} for image in images)
if videos := multimodal_inputs.get("videos"):
medias.extend({"type": "video", "video": video} for video in videos)
return processor(text=text, medias=medias)

kwargs = build_processor_kwargs(multimodal_inputs)
return processor(text=text, **kwargs)
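
# Usage sketch (hypothetical prompt text and image object): both dispatch
# paths return an HF-style mapping, so callers stay agnostic about the
# processor flavor:
#
#   output = call_processor(processor, "<image> describe this", {"images": [img]})
#   prompt_ids = output["input_ids"][0]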


def _try_load_glm4v_processor(name_or_path: str, **kwargs):
"""Fallback: manually construct a Glm4vProcessor for GLM-4.6V / GLM-4.5V models.
