Add FSDP v2 multi-GPU parameter sharding for LoRA training #774
Open
luke9705 wants to merge 1 commit into ostris:main from
Conversation
Enable distributed training for LoRA fine-tuning using PyTorch FSDP v2. The frozen transformer is sharded across GPUs while LoRA parameters are trained with synchronized gradients, reducing per-GPU memory and enabling training of large models (e.g. 14B params) across multiple GPUs.

Core changes:
- FSDP v2 auto-sharding via accelerate with transformer block wrapping
- Collective save/load for checkpoints and optimizer state across ranks
- FSDP-aware sampling: all ranks run forward, rank 0 saves images
- Graceful stop coordination via `dist.broadcast`
- Text encoder forced offload under FSDP (the text encoder and the sharded transformer cannot coexist on GPU); `unload_text_encoder` now implies `cache_text_embeddings`
- Step metadata broadcast on resume so all ranks agree on the step number

UI changes:
- Multi-GPU toggle buttons replacing the single-GPU dropdown
- FSDP cascade: multi-GPU auto-enables `unload_text_encoder` and `cache_text_embeddings`
- GPU overlap detection in the queue for multi-GPU resource management
- Multi-GPU job grouping in the jobs table
Summary
Enable distributed LoRA fine-tuning using PyTorch FSDP v2 (Fully Sharded Data Parallel). The frozen transformer is sharded across GPUs while LoRA parameters are trained with synchronized gradients, reducing per-GPU VRAM and enabling training of large models across multiple GPUs.
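For orientation, FSDP v2 in accelerate is usually enabled through the accelerate config file. The fragment below is a sketch, assuming an accelerate version that supports `fsdp_version: 2`; the layer class name is a placeholder, not necessarily the one wrapped in this PR.

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
num_processes: 2
mixed_precision: bf16
fsdp_config:
  fsdp_version: 2                        # select FSDP2 (per-parameter sharding)
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: TransformerBlock  # placeholder class name
  fsdp_reshard_after_forward: true       # free sharded params after each block's forward
```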
Motivation
This implementation was born out of a practical need: we needed to train FLUX.2 and had two RTX 6000 Pro GPUs, but the model couldn't fit entirely on a single GPU without quantization. Dynamic quantization significantly slows down training due to the constant quant/dequant overhead, which pushed us to implement FSDP v2 support for ai-toolkit. We now use it daily to accelerate our training sessions.
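As a back-of-the-envelope check of the memory claim, sharding the frozen weights roughly divides their footprint by the GPU count. The sketch below ignores activations, optimizer state, and the LoRA weights themselves, and uses 1 GB = 1e9 bytes; the 14B figure comes from the summary above.

```python
def sharded_param_memory_gb(n_params: float, bytes_per_param: int, n_gpus: int) -> float:
    """Per-GPU memory for frozen weights once sharded across n_gpus (1 GB = 1e9 B)."""
    return n_params * bytes_per_param / 1e9 / n_gpus

# A ~14B-parameter transformer in bf16 (2 bytes per parameter):
full = sharded_param_memory_gb(14e9, 2, 1)  # 28.0 GB on one GPU
half = sharded_param_memory_gb(14e9, 2, 2)  # 14.0 GB per GPU across two GPUs
```

With two 96 GB-class cards, halving the frozen-weight footprint leaves correspondingly more headroom for activations and optimizer state than fitting the full model per GPU would.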
Changes
Core:
- FSDP v2 auto-sharding via `accelerate` with per-transformer-block wrapping
- Collective save/load for checkpoints and optimizer state across ranks
- FSDP-aware sampling: all ranks run forward, rank 0 saves images
- Graceful stop coordination via `dist.broadcast`
- Text encoder forced offload under FSDP (`unload_text_encoder` now implies `cache_text_embeddings`)
- Step metadata broadcast on resume so all ranks agree on the step number

UI:
- Multi-GPU toggle buttons replacing the single-GPU dropdown
- FSDP cascade: multi-GPU auto-enables `unload_text_encoder` and `cache_text_embeddings`
- GPU overlap detection in the queue for multi-GPU resource management
- Multi-GPU job grouping in the jobs table

Design Decisions
- LoRA gradients are synchronized with `all_reduce` after each optimizer step. This avoids the complexity of sharding hook-injected modules.
- `accelerate launch` for multi-GPU. The UI spawns distributed jobs via `accelerate launch --num_processes=N --gpu_ids=...` rather than `torchrun`, keeping compatibility with the existing accelerate-based training pipeline.

DISCLAIMER: This was largely tested on FLUX.2 DEV, so some bugs may still occur with other models.