Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,11 +832,6 @@ def _get_dtype(
return config, main_dtype


class PipelineParallel(Enum):
    """Tag for the two ends of a pipeline-parallel plan entry.

    NOTE(review): based on the nearby `_pp_plan` comment (a plan "specifying the
    layers which may not be present on all ranks when PP is enabled"), these
    members presumably distinguish a layer's inputs from its outputs — confirm
    against the consumers of `_pp_plan`.
    """

    # Marks the input side of a plan entry.
    inputs = 0
    # Marks the output side of a plan entry.
    outputs = 1


class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
Expand Down Expand Up @@ -1155,7 +1150,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level
# models, this attribute is currently defined in respective model code. For base models, it comes from
# `config.base_model_pp_plan` during `post_init`.
_pp_plan: dict[str, PipelineParallel] | None = None
_pp_plan: dict[str, tuple[str, str]] = None

# Advanced functionalities support
supports_gradient_checkpointing: bool = False
Expand Down Expand Up @@ -1380,7 +1375,13 @@ def tp_plan(self, plan: dict[str, str] | None):
self._tp_plan = plan

@pp_plan.setter
def pp_plan(self, plan: dict[str, tuple[str, str]]):
def pp_plan(self, plan: dict[str, tuple[str, str]] | None):
if plan is None:
self._pp_plan = {}
return
if not isinstance(plan, dict):
raise ValueError("Can only set a dictionary as `pp_plan`")

self._pp_plan = plan

def dequantize(self, dtype=None):
Expand Down Expand Up @@ -4385,12 +4386,14 @@ def supports_tp_plan(self):
"""
Returns whether the model has a tensor parallelism plan.
"""
if self._tp_plan is not None:
# Check if model has a TP plan
if self._tp_plan:
Comment on lines +4389 to +4390
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we check if it's empty?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was checking for truthiness because the default value is `None` if `post_init` is not called.

return True
# Check if base model has a TP plan
if getattr(self.base_model, "_tp_plan", None) is not None:
if self.base_model._tp_plan:
return True
if self.config.base_model_tp_plan is not None:
# Check if config has TP plan
if self.config.base_model_tp_plan:
return True
return False

Expand All @@ -4404,10 +4407,14 @@ def tp_size(self):

@property
def supports_pp_plan(self):
if self._pp_plan is not None:
# Check if model has a PP plan
if self._pp_plan:
Comment on lines +4410 to +4411
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's check if it's empty

return True
# Check if base model has PP plan
if getattr(self.base_model, "_pp_plan", None) is not None:
if self.base_model._pp_plan:
return True
# Check if config has PP plan
if self.config.base_model_pp_plan:
return True
return False

Expand Down
Loading