Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions roll/datasets/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
# model_inputs for hf/deepspeed: input_id, attention_mask, pixel_values, image_grid_thw
padded_features = defaultdict(list)
un_padded_features = defaultdict(list)
mm_token_type_id_features = []
mm_feature_keys = set()
for feature in features:
# cannot process as batch directly though processor output as batch
Expand Down Expand Up @@ -165,6 +166,8 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
model_inputs.pop(key)
for key in filter(lambda k: k in model_inputs, self.padded_keys):
padded_features[key].append(model_inputs.pop(key)[0])
if "mm_token_type_ids" in model_inputs:
mm_token_type_id_features.append(torch.as_tensor(model_inputs.pop("mm_token_type_ids")[0]))
# mm feature fileds can be different because of mixed data
mm_feature_keys = mm_feature_keys.union(model_inputs.keys())
# to tensors except padded_keys which would be converted after padding
Expand Down Expand Up @@ -208,6 +211,22 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
return_tensors=self.return_tensors,
)
batch.update(un_padded_features)
if mm_token_type_id_features:
target_len = batch["input_ids"].shape[-1]
padded_mm_token_type_ids = []
for token_type_ids in mm_token_type_id_features:
pad_len = target_len - token_type_ids.shape[-1]
if pad_len < 0:
raise ValueError(
f"mm_token_type_ids length {token_type_ids.shape[-1]} exceeds padded input length {target_len}"
)
pad = torch.zeros(pad_len, dtype=token_type_ids.dtype, device=token_type_ids.device)
if self.tokenizer.padding_side == "left":
token_type_ids = torch.cat([pad, token_type_ids], dim=-1)
else:
token_type_ids = torch.cat([token_type_ids, pad], dim=-1)
padded_mm_token_type_ids.append(token_type_ids)
batch["mm_token_type_ids"] = torch.stack(padded_mm_token_type_ids, dim=0)

# other custom data fields: mainly for specific position_ids currently
# position_ids for qwen2-vl is optional and make sure it is a 3D tensor
Expand All @@ -226,6 +245,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
kwargs[key] = fun_params[key].default
extra_data = self.extra_data_provider(**kwargs)
batch.update(extra_data)
batch.pop("mm_token_type_ids", None)

# each field should be a tensor or np.array(val=list_data, dtype=object)
# to be stored in DataProto
Expand Down
33 changes: 20 additions & 13 deletions roll/models/model_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ def load_model(
freeze_model(model, model_args)
else:
model = setup_lora_training(config, model, model_args, is_trainable)
if not model_args.disable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()

if add_valuehead:
from trl import AutoModelForCausalLMWithValueHead
Expand Down Expand Up @@ -710,8 +712,6 @@ def get_extra_data_provider(model_name_or_path: str, processor=None):
if isinstance(model_type, str) and (("qwen2" in model_type) or (model_type in ("qwen3_vl", "qwen3_vl_moe"))):
import types

from transformers import BatchFeature # help define a object to accesss attr

def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
sig = inspect.signature(fn)
params = sig.parameters
Expand Down Expand Up @@ -745,17 +745,13 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
"<|vision_start|>"
)

dummy_self = BatchFeature(
{
"config": BatchFeature(
{
"vision_config": BatchFeature(vc),
"image_token_id": image_token_id,
"video_token_id": video_token_id,
"vision_start_token_id": vision_start_token_id,
}
)
}
dummy_self = types.SimpleNamespace(
config=types.SimpleNamespace(
vision_config=types.SimpleNamespace(**vc),
image_token_id=image_token_id,
video_token_id=video_token_id,
vision_start_token_id=vision_start_token_id,
)
)

is_tf_ge_4_52 = is_transformers_version_greater_than("4.52.0")
Expand All @@ -771,6 +767,9 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
elif model_type in ("qwen3_vl", "qwen3_vl_moe"):
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel

dummy_self.get_vision_position_ids = types.MethodType(
Qwen3VLModel.get_vision_position_ids, dummy_self
)
get_rope_index = types.MethodType(Qwen3VLModel.get_rope_index, dummy_self)
else:
if is_tf_ge_4_52:
Expand All @@ -787,8 +786,15 @@ def extra_data_provider(
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
mm_token_type_ids: Optional[torch.Tensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
):
if model_type in ("qwen3_vl", "qwen3_vl_moe") and mm_token_type_ids is None:
mm_token_type_ids = torch.zeros_like(input_ids)
if image_token_id is not None:
mm_token_type_ids = torch.where(input_ids == image_token_id, 1, mm_token_type_ids)
if video_token_id is not None:
mm_token_type_ids = torch.where(input_ids == video_token_id, 2, mm_token_type_ids)
# Keep kwargs to be resilient to HF signature changes between versions/models.
out = _call_get_rope_index(
get_rope_index,
Expand All @@ -797,6 +803,7 @@ def extra_data_provider(
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask,
mm_token_type_ids=mm_token_type_ids,
)
rope_index = out[0]
# PumpkinComment:
Expand Down