diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e31af9847811..b71a1ea959ed 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -832,11 +832,6 @@ def _get_dtype( return config, main_dtype -class PipelineParallel(Enum): - inputs = 0 - outputs = 1 - - class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -1155,7 +1150,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH # A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level # models, this attribute is currently defined in respective model code. For base models, it comes from # `config.base_model_pp_plan` during `post_init`. - _pp_plan: dict[str, PipelineParallel] | None = None + _pp_plan: dict[str, tuple[str, str]] = None # Advanced functionalities support supports_gradient_checkpointing: bool = False @@ -1380,7 +1375,13 @@ def tp_plan(self, plan: dict[str, str] | None): self._tp_plan = plan @pp_plan.setter - def pp_plan(self, plan: dict[str, tuple[str, str]]): + def pp_plan(self, plan: dict[str, tuple[str, str]] | None): + if plan is None: + self._pp_plan = {} + return + if not isinstance(plan, dict): + raise ValueError("Can only set a dictionary as `pp_plan`") + self._pp_plan = plan def dequantize(self, dtype=None): @@ -4385,12 +4386,14 @@ def supports_tp_plan(self): """ Returns whether the model has a tensor parallelism plan. """ - if self._tp_plan is not None: + # Check if model has a TP plan + if self._tp_plan: return True # Check if base model has a TP plan - if getattr(self.base_model, "_tp_plan", None) is not None: + if self.base_model._tp_plan: return True - if self.config.base_model_tp_plan is not None: + # Check if config has TP plan + if self.config.base_model_tp_plan: return True return False @@ -4404,10 +4407,14 @@ def tp_size(self): @property def supports_pp_plan(self): - if self._pp_plan is not None: + # Check if model has a PP plan + if self._pp_plan: return True # Check if base model has PP plan - if getattr(self.base_model, "_pp_plan", None) is not None: + if self.base_model._pp_plan: + return True + # Check if config has PP plan + if self.config.base_model_pp_plan: return True return False