Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion nemo_rl/models/megatron/community_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from typing import Any, Optional

Expand Down Expand Up @@ -107,7 +108,20 @@ def import_model_from_hf_name(
config.num_layers_in_last_pipeline_stage = orig_num_layers_in_last_pipeline_stage
config.pipeline_dtype = orig_pipeline_dtype

bridge.save_megatron_model(megatron_model, output_path)
# Disable save optimizations that deadlock when the distributed world includes
# non-training ranks (e.g., non-colocated vLLM workers):
# - fully_parallel_save=False: bypasses FullyParallelSaveStrategyWrapper which
# calls all_gather_object on DP sub-groups containing non-participating ranks
# - validate_access_integrity=False: skips determine_global_metadata which calls
# all_gather_object on the default PG where some ranks may not participate
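# Each keyword argument is passed only if the installed Megatron-Bridge's
# save_megatron_model signature accepts it, for backward compatibility with
# older Megatron-Bridge versions that lack these parameters.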
save_kwargs = {}
sig = inspect.signature(bridge.save_megatron_model)
if "fully_parallel_save" in sig.parameters:
save_kwargs["fully_parallel_save"] = False
if "validate_access_integrity" in sig.parameters:
save_kwargs["validate_access_integrity"] = False
bridge.save_megatron_model(megatron_model, output_path, **save_kwargs)

# resetting mcore state
import megatron.core.rerun_state_machine
Expand Down
Loading