
Add FSDP v2 multi-GPU parameter sharding for LoRA training#774

Open
luke9705 wants to merge 1 commit into ostris:main from luke9705:feature/fsdp-multi-gpu

Conversation


luke9705 commented Apr 2, 2026

Summary

Enable distributed LoRA fine-tuning using PyTorch FSDP v2 (Fully Sharded Data Parallel). The frozen transformer is sharded across GPUs while LoRA parameters are trained with synchronized gradients, reducing per-GPU VRAM and enabling training of large models across multiple GPUs.

Motivation

This implementation grew out of a practical need: we wanted to train FLUX.2 on two RTX 6000 Pro GPUs, but the model could not fit on a single GPU without quantization. Dynamic quantization significantly slows training because of the constant quant/dequant overhead, which pushed us to add FSDP v2 support to ai-toolkit. We now use it daily to accelerate our training sessions.

Changes

Core:

  • FSDP v2 auto-sharding via accelerate with per-transformer-block wrapping
  • Collective save/load for LoRA checkpoints and optimizer state across ranks
  • FSDP-aware sampling: all ranks run the forward pass, rank 0 saves images
  • Graceful distributed stop coordination via dist.broadcast
  • Text encoder forced offload under FSDP (unload_text_encoder now implies cache_text_embeddings)
  • Step metadata broadcast on resume so all ranks agree on training position
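The graceful-stop coordination above can be sketched roughly as follows. This is an illustration of the `dist.broadcast` pattern, not the PR's actual code; the function name and the convention that rank 0 owns the stop decision are assumptions.

```python
import os
import torch
import torch.distributed as dist

def coordinate_stop(rank0_wants_stop: bool) -> bool:
    """Rank 0 decides whether to stop; every rank receives the same flag.

    All ranks must call this at the same point in the training loop,
    otherwise the collective broadcast deadlocks.
    """
    flag = torch.tensor([int(rank0_wants_stop)])
    dist.broadcast(flag, src=0)  # after this, all ranks hold rank 0's value
    return bool(flag.item())
```

Because every rank returns the same flag, all processes exit the training loop on the same step and the collective save that follows stays in sync.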

UI:

  • Multi-GPU toggle buttons replacing single-GPU dropdown
  • FSDP cascade: selecting multiple GPUs auto-enables unload_text_encoder and cache_text_embeddings
  • GPU overlap detection in job queue for multi-GPU resource management
  • Multi-GPU job grouping in the jobs table
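The overlap check behind the queue's multi-GPU resource management reduces to a set intersection. The actual UI logic lives in the web frontend; this is a minimal Python sketch of the idea, with an illustrative function name.

```python
def gpus_overlap(gpus_a, gpus_b) -> bool:
    """True if two jobs request any of the same GPU indices.

    A queued job can only start when its GPU set is disjoint from
    the GPU sets of all currently running jobs.
    """
    return not set(gpus_a).isdisjoint(gpus_b)
```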

Design Decisions

  • LoRA params are replicated, not sharded. LoRA parameters are small relative to the frozen transformer. They are kept as regular tensors on each rank with gradients synchronized via all_reduce after each optimizer step. This avoids the complexity of sharding hook-injected modules.
  • Text encoder is always offloaded under FSDP. The TE and the sharded transformer cannot coexist on GPU during training, so text embeddings are pre-cached before the text encoder is unloaded. This also frees as much VRAM as possible for the sharded transformer, which matters most when the model would not otherwise fit on a single GPU.
  • accelerate launch for multi-GPU. The UI spawns distributed jobs via accelerate launch --num_processes=N --gpu_ids=... rather than torchrun, keeping compatibility with the existing accelerate-based training pipeline.
  • Rank-0 I/O pattern. Only rank 0 performs file writes (checkpoints, samples, logs). Other ranks participate in collective gather operations but do not write to disk.
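The replicated-LoRA design in the first bullet amounts to a manual gradient average across ranks. A minimal sketch (assuming gradients are synced after backward and before `optimizer.step()`; the helper name is illustrative, not the PR's code):

```python
import torch
import torch.distributed as dist

def sync_lora_grads(lora_params) -> None:
    """Average LoRA gradients across ranks.

    LoRA params are plain replicated tensors (not FSDP-sharded), so they
    get no automatic gradient sync; a manual all_reduce keeps the
    replicas identical after each optimizer update.
    """
    world_size = dist.get_world_size()
    for p in lora_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)  # SUM then divide == mean over ranks
```

Since every rank starts from the same initialization and steps with identical averaged gradients, the replicas never diverge, which is what makes sharding the tiny LoRA weights unnecessary.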

DISCLAIMER: This was tested primarily on FLUX.2 dev, so some bugs may still occur with other models.
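The rank-0 I/O pattern from the design decisions can be sketched as follows. This is a generic illustration, not the PR's checkpoint code; the function name is an assumption, and real FSDP checkpoints additionally need a collective gather of sharded state before the write.

```python
import torch
import torch.distributed as dist

def save_rank0(obj, path: str) -> None:
    """Every rank reaches this call; only rank 0 touches the filesystem."""
    if dist.get_rank() == 0:
        torch.save(obj, path)
    dist.barrier()  # keep ranks aligned so nobody races past the save
```

The trailing barrier matters: without it, non-zero ranks could delete or reuse files that rank 0 is still writing.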

