Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713
cspades wants to merge 14 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp
Conversation
Greptile Summary
This PR adds Torch DCP (Distributed Checkpoint) compatibility for FSDP2 × TP strided sharding across all TransformerEngine modules.
Confidence Score: 4/5
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
/te-ci L1 pytorch
/te-ci L1 pytorch
For some reason, after 2.3k training steps I start to get NaNs: https://wandb.ai/nvidia/bionemo-recipes/runs/nmzugu0a?nw=nwusercye_nv. Restarting from this checkpoint, around 500 steps later the same thing happens.
Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
…ess. Signed-off-by: Cory Ye <cye@nvidia.com>
… are still model parity tested. Signed-off-by: Cory Ye <cye@nvidia.com>
/te-ci L1 pytorch
```python
module.reset_parameters()

# Run a training step to initialize FSDP2 lazy state and update quantization
# scales before testing the allgather. Block-scaling formats (Float8BlockScaling,
```
I believe Float8BlockScaling allgather should work now, right?

I think so - this comment might be out of date. I believe I only xfail the NVFP4 test instance.
```python
input_data = torch.randn(inp_shape, device=device)
target = torch.randn(inp_shape, device=device)
nvfp4_ctx = (
    torch.autocast(device_type="cuda", dtype=torch.bfloat16)
```
Why a separate nvfp4 context? In general, adding multiple context managers adds CPU overhead in the training loop.

Copypasta from @pstjohn's FusedAdam PR - Peter, can you remind me what the need for this double-context was? Thanks!
```python
@dataclass
class AppState(Stateful):
```
This seems like a useful class, with some things like extra state specific to TE. Might make sense to move it to the TE distributed module. Thoughts @cspades?

Yup, also add a TODO for both of those two hacks:

- The empty param hack is mainly for biases, since TE uses an empty Tensor for some layers. In general, I believe it would be much cleaner to just set `self.bias` (or multiple biases) to `None`, but I think there must have been a reason for empty Tensors in the first place. (And... it's in a lot of people's checkpoints by now...)
- `_extra_state` is immensely confusing due to its serialization. I've seen it work and not work, or change shape if you load a brand new model vs. load into a pre-trained model (i.e. re-load after running a few training steps). I hope this is an acceptable throwaway!
```python
# TransformerEngine uses empty Tensors for dummy Parameters.
optimizer_state_dict["state"][fqn] = {}
if fqn.endswith("_extra_state"):
    # Evict `_extra_state` quantization data from model checkpoint.
```
If this is evicted, how do we make sure it is updated correctly after load from checkpoint?

We don't - there's a pretty nasty 10% loss difference for specific test cases, and it scales poorly with the model size: the bigger the model, the bigger the disparity.
I'm assuming it is caused by not checkpointing extra state, but as mentioned in the above comment, I need some help understanding how to support it. 🙏🏻
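For illustration, the eviction discussed above amounts to filtering the model state dict by key suffix before handing it to DCP. A minimal sketch (the dict contents are placeholders, not real TE tensors):

```python
def evict_extra_state(model_state_dict: dict) -> dict:
    """Drop `_extra_state` entries (serialized quantizer metadata) from a
    model state dict before checkpointing. Illustrative sketch only."""
    return {
        fqn: value
        for fqn, value in model_state_dict.items()
        if not fqn.endswith("_extra_state")
    }

state = {
    "layers.0.linear.weight": "weight-bytes",
    "layers.0.linear._extra_state": "quantizer-bytes",
}
filtered = evict_extra_state(state)
# Only the real weight survives; the quantizer metadata key is evicted.
```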
```python
(
    # FSDP
    [NUM_PROCS],
    # HSDP
```
Does this work OK if `NUM_PROCS < 4`? I.e., let's say `NUM_PROCS = 2` - the TP dimension will be 0. Curious what happens in that case?

Right, I have this filter below:

```python
parallel_size = math.prod(x for x in sharding_dims if x != 0)
if NUM_PROCS < parallel_size:
    pytest.skip(
        f"Insufficient devices ({NUM_PROCS}) to test sharding configuration: {sharding_dims}"
    )
```

If you want this test to run with 2 GPUs, maybe we can add a test case that tests pure TP with no FSDP? Maybe just have a separate test module for this! (It'll be easy to deal with without FSDP2 annoyances and order-of-operations.)
```python
if tp_mesh is not None or weight_mesh is not None:
    # Apply DeviceMesh and DTensor-related modifications.
    self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh)
```
In the `set_device_mesh` function, it says `weight_mesh` is not necessary, but we call it only if both `tp_mesh` and `weight_mesh` are not None. So it should not include the condition that `weight_mesh` is not None, right?

Actually, it's OR logic: `tp_mesh is not None or weight_mesh is not None`. Your intuition is correct - `set_device_mesh` handles both arguments completely separately, and if one or the other is None, it won't do anything for that one. It should be a switch/case function for DTensors.
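The "handled completely separately" behavior described here can be sketched as a standalone function (a hypothetical body that returns labels; the real method converts parameters to DTensors instead):

```python
def set_device_mesh(tp_mesh=None, weight_mesh=None):
    """Each mesh argument is applied independently; passing None for one
    simply skips that branch, so either mesh can be set on its own."""
    applied = []
    if tp_mesh is not None:
        applied.append("tp")      # would convert params to TP-placed DTensors
    if weight_mesh is not None:
        applied.append("weight")  # would register the weight-sharding mesh
    return applied

only_tp = set_device_mesh(tp_mesh="tp_mesh_1d")  # the weight branch is a no-op
```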
```python
else device_mesh.get_group()
)
quantizer.amax_reduction_group = amax_reduction_group
quantizer.amax_reduction_group = device_mesh.get_group()
```
Which group will it return in the case of multiple dimensions? For instance, if the weight is both FSDP- and TP-sharded, will this give the FSDP dim or the TP dim?

This will actually error out if your device mesh is not pure FSDP (or just 1-dimensional). This is the backwards-incompatibility I mentioned a while back. We could maybe hack in a WAR to support the old API (e.g. pick the first mesh dimension, but obviously that's a very bad hack), but the right way is to expose this publicly so the user can tell us what the weight-sharding mesh is!
```python
if isinstance(weight, DTensor):
    weight = weight.to_local()
```
Shouldn't we use `_extract_trainable_tensor_from_dtensor` here? Also applicable to a couple of other places in the ops folder.

Oh, this one just seems to work fine, but I think you are right. This should definitely use that utility - there is no reason not to.
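The utility's described purpose - localize the DTensor and re-inherit `requires_grad`, which the plain `to_local()` path loses for FP8 parameters - can be mimicked with stand-in objects (pure Python, no torch; all names here are illustrative):

```python
from types import SimpleNamespace

def extract_trainable_tensor(dtensor):
    """Sketch: localize the DTensor, then copy requires_grad from the
    DTensor parameter onto the local tensor."""
    local = dtensor.to_local()
    local.requires_grad = dtensor.requires_grad
    return local

# Stand-in "DTensor" whose local view has requires_grad dropped.
local_view = SimpleNamespace(requires_grad=False, data=[1.0, 2.0])
fake_dtensor = SimpleNamespace(requires_grad=True, to_local=lambda: local_view)
trainable = extract_trainable_tensor(fake_dtensor)
```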
```python
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
instance._default_storage = torch.UntypedStorage(1, device=torch.cuda.current_device())
```
I am leaning towards creating this `default_storage` on CPU instead, for a couple of reasons:

- Since the idea of `default_storage` here is to show unique identity, keeping it on CPU vs. GPU shouldn't matter.
- Calling `torch.cuda.current_device()` on every single Tensor creation has Python overheads.

Agreed. I also want us to confirm with @timmoon10 whether a nullptr or empty storage "works", or otherwise, whether a 1-byte default TE Tensor storage size is acceptable or breaks our users' logic. I believe Torch DCP checkpoint load/save does use `untyped_storage()`, but technically, if our row-wise and col-wise data is None, then maybe the `data_ptr() > 0` check correctly does not checkpoint the weight? Or do we want to checkpoint the weight anyway?

Wait, one reason I put it on the same device (GPU) is because I am not 100% sure we want DCP to write QuantizedTensors' data to different devices depending on the state of the QuantizedTensor. Maybe that works, since it looks like TE quantizes on-the-fly, on the right devices, so we can just assume that this storage always belongs on CPU.

I guess we can't get around the storage logic being ad hoc and delicate, but it seems putting the storage on CPU is too much. To reduce CPU overhead during initialization, how about we construct the storage lazily?

```python
def untyped_storage(self):
    ...
    # Return dummy storage if there is no local data
    if self._default_storage is None:
        self._default_storage = torch.UntypedStorage(1, device=self.device)
    return self._default_storage
```

```python
integration with Torch DCP checkpointing. This method should only be invoked when
using DTensor parameters, e.g. when using FSDP2 or DCP.

When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
```
Why is it that we haven't added `tp_mesh` and `weight_mesh` to the constructors of `RMSNorm` and `LayerNorm`, but have for every other layer?

Like the ops these modules call? Or that we don't TP-shard norms, so it's just `Replicate()` and there's no need to set the weight-sharding mesh. Now that I think about it, there may be a loophole where you make DTensors but the amax reduction group logic takes the `Replicate()` mesh - but I think it won't matter if you reduce on a larger group when this norm weight is replicated across GPUs. (It will affect non-max/non-mean reductions though, like sum. There should be an obvious way to do this in `TEBaseModule`.)
```python
param = getattr(self, bias)
placements = (Replicate(),)
if self.parallel_mode == "column":
    placements = (Shard(dim=0),)
```
I am wondering if we can make all the `set_device_mesh` functions share a helper `set_tp_mesh` defined in `base.py` that takes in a dictionary of `{parameter name: parallel_mode}` and a `tp_mesh`, converts the parameters to DTensors, and is used in the `set_device_mesh` of each module. Something like this:

```python
def set_tp_mesh(self, param_mode_dict: dict, tp_mesh: Optional[DeviceMesh]):
```

And put the big docstring that we have over there in `base.py` - the docstring seems to be repeated in all places.

Wait, every module's TP mesh logic is different. They all have different `Shard()` and `Replicate()` patterns - the placement implementations are not common to every `TEBaseModule`!

Isn't the logic of selecting Shard vs. Replicate and the dimension decidable based on `parallel_mode` and the type of parameter? It seemed like that to me.
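To make the reviewer's suggestion concrete, here is a hypothetical placement-selection rule keyed on `parallel_mode` (strings stand in for DTensor placements; whether the actual per-module patterns in TE reduce to a rule like this is exactly what is being debated):

```python
from typing import Optional

def select_placement(parallel_mode: Optional[str], is_bias: bool) -> str:
    """Hypothetical rule: column-parallel shards weight and bias on dim 0;
    row-parallel shards only the weight on dim 1 and replicates the bias;
    everything else is replicated."""
    if parallel_mode == "column":
        return "Shard(dim=0)"
    if parallel_mode == "row" and not is_bias:
        return "Shard(dim=1)"
    return "Replicate()"
```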
Generally LGTM @cspades. Let's get it merged after the comments are addressed.
cspades left a comment:
Answered all comments - they were all good questions.
## Summary

- Adds `(H/F)SDP2 x TP` strided sharding and `DTensor` FP8 parameters for Torch DCP checkpointing, across all `TransformerEngineBaseModule`(s).
- Not yet supported: `GroupedLinear`, pending FSDP2 standalone pipe-cleaning. All other modules under `transformer_engine.pytorch.modules` are supported. `FusibleOperation` support is also a WIP, except for `LayerNorm` or `RMSNorm`, which are TE modules.
- TE modules are compatible with `DTensor`-based TP when unified by Torch DCP! In the Llama3 recipe, we use `DTensor`-based TP on the `torch.nn.Embedding`, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the `torch.nn.Embedding`, which is why we do not need to call `set_device_mesh` for the LM head!

## Usage / Documentation
(`tp_mesh` and `weight_mesh` can also be passed in `TEModule.__init__`.)

## Details
### DTensor Lifecycle in TransformerEngine

1. Initialize the TE module via `__init__` on the `meta` device with the appropriate `tp_size` and TP sharding strategy, e.g. `parallel_mode` and `sequence_parallel`.
2. Call `TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)` to convert parameters to `DTensor` with the appropriate TP `placement`(s) based on the TP sharding strategy specified in `__init__`, using `transformer_engine.pytorch.distributed._convert_param_to_dtensor_param`.
   - `tp_mesh` is a 1-D `DeviceMesh` containing the TP `ProcessGroup` that will be registered with the TransformerEngine module.
   - `weight_mesh` is the 1-D `DeviceMesh` containing the `ProcessGroup` that shards TransformerEngine module weights, i.e. the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP, and is mainly required for per-tensor scaling recipes like `Float8CurrentScaling`.
   - This must be called before `fully_shard` (which responds to the TP placements) and prior to `reset_parameters(defer_init=False)`, which quantizes parameters.
   - Alternatively, pass `__init__(tp_mesh, weight_mesh)` for supported TransformerEngine modules.
3. `fully_shard` shards the TransformerEngine model with FSDP2.
   - When `fully_shard` encounters TP sharding on `dim=0`, it will use a `_StridedShard` for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the `DeviceMesh` and `DTensor.placements`. (See the Appendix for a visualization of this sharding strategy.)
4. `reset_parameters` is called if using meta-device initialization, after `fully_shard`.
   - (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as `FusedAdam` must be used to properly handle high-precision main weights.)
5. During the forward pass, the all-gathered `Tensor` is actually a TP-sharded `DTensor`, which deviates from the original FSDP2 paradigm where the all-gathered `Tensor` is fully un-sharded and the `DTensor` wrapping is discarded. To support these `DTensor` compute weights in TransformerEngine modules, we utilize `transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor` to localize the `DTensor` and also inherit the `requires_grad` attribute from the `DTensor` parameter, as the local `Tensor` has this un-set during `DTensor.from_local(Tensor)` for FP8 parameters specifically!
6. In the backward pass, the `Tensor` gradient is converted to a `DTensor` and attached to the `DTensor.grad` attribute. This is handled by DTensor <> Tensor autograd conversion functions, and in the case of `FusibleOperation`, casted during the backward implementation.
7. `QuantizedTensorStorage`: when the underlying data is `None`, we send `untyped_storage()` to a default 1-byte storage that unblocks DCP checkpoint-loading assertions that use this as a definition of "emptiness". This is because a storage of 0 bytes has `data_ptr() = nullptr` and breaks DCP. While `untyped_storage` is not used anywhere in TransformerEngine, it may break code that uses `Storage` to figure out whether a Tensor is empty, as `QuantizedTensor` storage will now always be a 1-byte storage even when both row and column data are not set. Those checks would instead need to compare the storage size against 1 byte instead of 0 bytes.
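The 1-byte dummy-storage convention described above can be summarized with a toy stand-in (pure Python, no torch; the class and method names are illustrative, not TE's):

```python
class QuantizedStorageSketch:
    """Toy model of the convention: an 'empty' quantized tensor still reports
    a 1-byte storage, so emptiness checks must compare against 1 byte rather
    than 0 bytes (a 0-byte storage would have data_ptr() == nullptr and break
    DCP's checkpoint-loading assertions)."""

    def __init__(self, rowwise_data=None, colwise_data=None):
        self.rowwise_data = rowwise_data
        self.colwise_data = colwise_data

    def untyped_storage_nbytes(self) -> int:
        local = self.rowwise_data if self.rowwise_data is not None else self.colwise_data
        return len(local) if local is not None else 1  # 1-byte dummy storage

empty = QuantizedStorageSketch()
full = QuantizedStorageSketch(rowwise_data=b"\x00" * 16)
```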
"shard"was the presumed weight sharding sub-mesh in theDTensor.device_mesh. Now, users can precisely specify their own custom weight-shardingDeviceMeshfor per-tensoramax_reduction_groupvia theset_device_mesh(weight_mesh)API.TransformerEngineBaseModule:self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}Testing
- Compared `main` vs. `cspades:cye/fsdp2-tp-dcp` with Megatron-LM `main` on PyTorch `25.11`.
- `DelayedScaling` has DCP save/load disparity issues, i.e. on the scale of `+/-1` to the `uint8` parameter checkpoint!

## Appendix
### `_StridedShard` - Using FSDP2 x TP Strided-Sharding

When `redistribute`'ing a global DTensor to `(_StridedShard(dim=0, sf=2), Shard(dim=0))`, `DTensor` will perform the following steps:

1. Pre-shard the data by the split factor, i.e. the product of the `Shard` placements to the right of the `_StridedShard`. (In the above example, since TP=2, the factor is 2.)
   - `[0 1 2 3 4 5 6 7] -> [0 1 2 3]` and `[4 5 6 7]`.
   - In `fully_shard`, this has already been done via initializing the TransformerEngine module with TP and calling `_convert_param_to_dtensor_param`!
2. Shard each pre-shard by the `_StridedShard`, and concatenate corresponding pieces.
   - `[0] [1] [2] [3]` and `[4] [5] [6] [7]`
   - Concatenated: `[0 4] [1 5] [2 6] [3 7]`, which are assigned to the `_StridedShard` ranks.
   - (Without striding, we would have gotten `[0 1] [2 3] [4 5] [6 7]`!)
3. Shard by the next `Shard` placement.
   - `[0] [4]` / `[1] [5]` / `[2] [6]` / `[3] [7]`, which are assigned to the `Shard` ranks.
   - (Without striding, we would have gotten `[0] [1]` / `[2] [3]` / `[4] [5]` / `[6] [7]`!)

PyTorch also supports the inverse / un-sharding of this `redistribute`, which is literally the inverse of these simple operations! (Though things get a bit more complicated with un-even shards from odd-numbered dimension sizes.)
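The strided-sharding walkthrough can be reproduced with plain Python lists (the helper name is illustrative; TP=2 and 4 DP ranks, as in the example):

```python
def strided_shard(data, ranks, split_factor):
    """_StridedShard(sf=split_factor) semantics on a flat list: pre-split by
    the split factor, shard each pre-split across `ranks`, then concatenate
    the corresponding pieces to form strided shards."""
    n = len(data) // split_factor
    pre_shards = [data[i * n:(i + 1) * n] for i in range(split_factor)]
    m = n // ranks
    return [
        [x for pre in pre_shards for x in pre[r * m:(r + 1) * m]]
        for r in range(ranks)
    ]

data = [0, 1, 2, 3, 4, 5, 6, 7]
fsdp_shards = strided_shard(data, ranks=4, split_factor=2)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
# Final Shard(dim=0) over TP=2 splits each strided shard in half.
tp_shards = [[s[:1], s[1:]] for s in fsdp_shards]           # [[[0], [4]], [[1], [5]], ...]
```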