diff --git a/config/examples/train_flex_redux.yaml b/config/examples/train_flex_redux.yaml index 918de8427..2e44c5b67 100644 --- a/config/examples/train_flex_redux.yaml +++ b/config/examples/train_flex_redux.yaml @@ -56,7 +56,8 @@ config: batch_size: 3 gradient_accumulation: 2 - # captions are not needed for this training, we cache a blank proompt and rely on the vision encoder + # captions are not needed for this training, we cache a blank prompt and rely on the vision encoder + # unload_text_encoder now automatically caches all text embeddings before unloading unload_text_encoder: true loss_type: "mse" diff --git a/config/examples/train_lora_wan21_14b_24gb.yaml b/config/examples/train_lora_wan21_14b_24gb.yaml index 32babd14c..f38556099 100644 --- a/config/examples/train_lora_wan21_14b_24gb.yaml +++ b/config/examples/train_lora_wan21_14b_24gb.yaml @@ -67,7 +67,7 @@ config: ema_decay: 0.99 dtype: bf16 # required for 24GB cards - # this will encode your trigger word and use those embeddings for every image in the dataset + # automatically caches all text embeddings before unloading the text encoder unload_text_encoder: true model: # huggingface model name or path diff --git a/config/examples/train_lora_wan22_14b_24gb.yaml b/config/examples/train_lora_wan22_14b_24gb.yaml index 966f184f4..f603fffd0 100644 --- a/config/examples/train_lora_wan22_14b_24gb.yaml +++ b/config/examples/train_lora_wan22_14b_24gb.yaml @@ -55,12 +55,10 @@ config: # IMPORTANT: this is for Wan 2.2 MOE. It will switch training one stage or the other every this many steps switch_boundary_every: 10 - # required for 24GB cards. You must do either unload_text_encoder or cache_text_embeddings but not both - - # this will encode your trigger word and use those embeddings for every image in the dataset, captions will be ignored + # required for 24GB cards + # unload_text_encoder automatically caches all text embeddings before unloading + # either option works — unload_text_encoder implies cache_text_embeddings # unload_text_encoder: true - - # this will cache all captions in your dataset. cache_text_embeddings: true model: diff --git a/extensions_built_in/sd_trainer/DiffusionTrainer.py b/extensions_built_in/sd_trainer/DiffusionTrainer.py index 252da0031..f8d889e96 100644 --- a/extensions_built_in/sd_trainer/DiffusionTrainer.py +++ b/extensions_built_in/sd_trainer/DiffusionTrainer.py @@ -8,6 +8,7 @@ import threading import time import signal +import torch AITK_Status = Literal["running", "stopped", "error", "completed"] @@ -30,22 +31,29 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): if self.is_ui_trainer: self.is_stopping = False + self.is_returning_to_queue = False # Create a thread pool for database operations self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) # Track all async tasks self._async_tasks = [] # Initialize the status self._run_async_operation(self._update_status("running", "Starting")) + self._last_speed_update = 0.0 self._stop_watcher_started = False - # self.start_stop_watcher(interval_sec=2.0) + self.start_stop_watcher(interval_sec=2.0) def start_stop_watcher(self, interval_sec: float = 5.0): """ Start a daemon thread that periodically checks should_stop() and terminates the process immediately when triggered. + In distributed mode, only rank 0 polls the DB; it kills the + entire process group so all ranks terminate. """ if not self.is_ui_trainer: return + # Only rank 0 should poll the DB to avoid SQLite locking issues + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return if getattr(self, "_stop_watcher_started", False): return self._stop_watcher_started = True @@ -58,26 +66,34 @@ def _stop_watcher_thread(self, interval_sec: float): while True: try: if self.should_stop(): - # Mark and update status (non-blocking; uses existing infra) self.is_stopping = True - self._run_async_operation( - self._update_status("stopped", "Job stopped (remote)") - ) - # Best-effort flush pending async ops - try: - asyncio.run(self.wait_for_all_async()) - except RuntimeError: - pass - # Try to stop DB thread pool quickly - try: - self.thread_pool.shutdown(wait=False, cancel_futures=True) - except TypeError: - self.thread_pool.shutdown(wait=False) print("") print("****************************************************") print(" Stop signal received; terminating process. ") print("****************************************************") - os.kill(os.getpid(), signal.SIGINT) + if self.accelerator.num_processes > 1: + # FSDP: don't kill — let end_step_hook broadcast the flag + # so all ranks save a temp checkpoint together. + # Don't shut down thread_pool here; on_error() still needs it. + # Force kill after 5 min as a last-resort fallback. + for _ in range(150): + time.sleep(2) + os.killpg(os.getpgid(os.getpid()), signal.SIGTERM) + else: + # Single GPU: update status, flush, then SIGINT. + self._run_async_operation( + self._update_status("stopped", "Job stopped") + ) + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + os.kill(os.getpid(), signal.SIGINT) + break # don't loop — avoid repeated signals interrupting save time.sleep(interval_sec) except Exception: time.sleep(interval_sec) @@ -91,6 +107,10 @@ def _run_async_operation(self, coro): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + # Prune completed tasks periodically to avoid unbounded growth + if len(self._async_tasks) > 100: + self._async_tasks = [t for t in self._async_tasks if not t.done()] + # Create a task and track it if loop.is_running(): task = asyncio.run_coroutine_threadsafe(coro, loop) @@ -101,40 +121,22 @@ def _run_async_operation(self, coro): loop.run_until_complete(task) async def _execute_db_operation(self, operation_func): - """Execute a database operation in a separate thread with retry on lock.""" + """Execute a database operation in a separate thread to avoid blocking.""" loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.thread_pool, lambda: self._retry_db_operation(operation_func) - ) + return await loop.run_in_executor(self.thread_pool, operation_func) def _db_connect(self): """Create a new connection for each operation to avoid locking.""" - conn = sqlite3.connect(self.sqlite_db_path, timeout=30.0) + conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0) conn.isolation_level = None # Enable autocommit mode return conn - def _retry_db_operation(self, operation_func, max_retries=3, base_delay=2.0): - """Retry a database operation with exponential backoff on lock errors.""" - last_error = None - for attempt in range(max_retries + 1): - try: - return operation_func() - except sqlite3.OperationalError as e: - if "database is locked" in str(e): - last_error = e - if attempt < max_retries: - delay = base_delay * (2 ** attempt) # 2s, 4s, 8s - print(f"[AITK] Database locked (attempt {attempt + 1}/{max_retries + 1}), retrying in {delay:.1f}s...") - time.sleep(delay) - else: - print(f"[AITK] Database locked after {max_retries + 1} attempts, giving up.") - else: - raise - raise last_error - def should_stop(self): if not self.is_ui_trainer: return False + # In distributed mode, only rank 0 polls the DB to avoid SQLite locking + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return False def _check_stop(): with self._db_connect() as conn: cursor = conn.cursor() @@ -143,11 +145,14 @@ def _check_stop(): stop = cursor.fetchone() return False if stop is None else stop[0] == 1 - return self._retry_db_operation(_check_stop) + return _check_stop() def should_return_to_queue(self): if not self.is_ui_trainer: return False + # In distributed mode, only rank 0 polls the DB to avoid SQLite locking + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return False def _check_return_to_queue(): with self._db_connect() as conn: cursor = conn.cursor() @@ -156,11 +161,18 @@ def _check_return_to_queue(): return_to_queue = cursor.fetchone() return False if return_to_queue is None else return_to_queue[0] == 1 - return self._retry_db_operation(_check_return_to_queue) + return _check_return_to_queue() def maybe_stop(self): if not self.is_ui_trainer: return + # In distributed mode, only rank 0 checks stop signals. + # Non-rank-0 processes are terminated via the stop watcher's os.killpg(). + # We CANNOT use dist.broadcast() here because maybe_stop() is called from + # rank-0-only code paths (inside sample(), save(), sample_step_hook()) where + # other ranks have already diverged. A broadcast would deadlock. + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return if self.should_stop(): self._run_async_operation( self._update_status("stopped", "Job stopped")) @@ -170,6 +182,7 @@ def maybe_stop(self): self._run_async_operation( self._update_status("queued", "Job queued")) self.is_stopping = True + self.is_returning_to_queue = True raise Exception("Job returning to queue") async def _update_key(self, key, value): @@ -251,28 +264,16 @@ async def wait_for_all_async(self): def on_error(self, e: Exception): super(DiffusionTrainer, self).on_error(e) if self.is_ui_trainer: - try: - if self.accelerator.is_main_process and not self.is_stopping: + if self.accelerator.is_main_process: + if self.is_returning_to_queue: + self.update_status("queued", "Job queued") + elif self.is_stopping: + self.update_status("stopped", "Job stopped") + else: self.update_status("error", str(e)) - self.update_db_key("step", self.last_save_step) - asyncio.run(self.wait_for_all_async()) - except Exception as db_err: - print(f"[AITK] Warning: failed to update DB during error handling: {db_err}") - finally: - self.thread_pool.shutdown(wait=True) - - def handle_timing_print_hook(self, timing_dict): - if "train_loop" not in timing_dict: - print("train_loop not found in timing_dict", timing_dict) - return - seconds_per_iter = timing_dict["train_loop"] - # determine iter/sec or sec/iter - if seconds_per_iter < 1: - iters_per_sec = 1 / seconds_per_iter - self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec") - else: - self.update_db_key( - "speed_string", f"{seconds_per_iter:.2f} sec/iter") + self.update_db_key("step", self.last_save_step) + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) def done_hook(self): super(DiffusionTrainer, self).done_hook() @@ -286,7 +287,29 @@ def end_step_hook(self): super(DiffusionTrainer, self).end_step_hook() if self.is_ui_trainer: self.update_step() - self.maybe_stop() + # Update speed_string every ~10 seconds using timer's rolling average + train_loop_timings = self.timer.timers.get('train_loop') + if train_loop_timings and (time.time() - self._last_speed_update) >= 10.0: + seconds_per_iter = sum(train_loop_timings) / len(train_loop_timings) + if seconds_per_iter <= 0: + pass + elif seconds_per_iter < 1: + self.update_db_key("speed_string", f"{1 / seconds_per_iter:.2f} iter/sec") + else: + self.update_db_key("speed_string", f"{seconds_per_iter:.2f} sec/iter") + self._last_speed_update = time.time() + # FSDP: broadcast stop flag so all ranks exit together for collective save. + # Only uses the watcher's is_stopping flag (no DB query in hot path). + if self.accelerator.num_processes > 1: + import torch.distributed as dist + flag = 1 if self.is_stopping else 0 + stop_tensor = torch.tensor([flag], device=self.accelerator.device) + dist.broadcast(stop_tensor, src=0) + if stop_tensor.item() == 1: + self.is_stopping = True + raise Exception("Job stopped") + else: + self.maybe_stop() def hook_before_model_load(self): super().hook_before_model_load() @@ -306,8 +329,6 @@ def hook_before_train_loop(self): self.maybe_stop() self.update_step() self.update_status("running", "Training") - self.timer.add_after_print_hook(self.handle_timing_print_hook) - def status_update_hook_func(self, string): self.update_status("running", string) @@ -332,9 +353,16 @@ def sample(self, step=None, is_first=False): self.maybe_stop() self.update_status("running", "Training") - def save(self, step=None): - self.maybe_stop() - self.update_status("running", "Saving model") - super().save(step) - self.maybe_stop() - self.update_status("running", "Training") + def save(self, step=None, is_temp=False): + if not is_temp: + # Under FSDP, skip pre-save maybe_stop() — it could throw on rank 0 + # before the collective save, deadlocking other ranks in full_tensor(). + # The post-save check and stop watcher handle stop signals safely. + if not self.use_fsdp: + self.maybe_stop() + self.update_status("running", "Saving model") + super().save(step, is_temp=is_temp) + if not is_temp: + if not self.use_fsdp: + self.maybe_stop() + self.update_status("running", "Training") diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 066a5eea0..c0f706be9 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -239,6 +239,12 @@ def before_dataset_load(self): self.taesd.requires_grad_(False) def hook_before_train_loop(self): + # Under FSDP, force TE unloading: pre-compute embeddings with TE on GPU, + # then unload TE before training. This prevents the TE and sharded + # transformer from coexisting on GPU during training steps. + if getattr(self, '_will_use_fsdp', False): + self.train_config.unload_text_encoder = True + super().hook_before_train_loop() if self.is_caching_text_embeddings: # make sure model is on cpu for this part so we don't oom. @@ -325,23 +331,39 @@ def hook_before_train_loop(self): self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) self.cache_sample_prompts() - + + # cache per-image text embeddings with TE on GPU + # only main process caches (files on shared disk), others wait + if self.is_caching_text_embeddings: + from toolkit.data_loader import get_dataloader_datasets + all_datasets = get_dataloader_datasets(self.data_loader) + if self.data_loader_reg is not None: + all_datasets += get_dataloader_datasets(self.data_loader_reg) + if self.accelerator.is_main_process: + for dataset in all_datasets: + if hasattr(dataset, 'cache_text_embeddings'): + dataset.cache_text_embeddings() + # sync so non-main ranks wait for caching to finish + if self.accelerator.num_processes > 1: + self.accelerator.wait_for_everyone() + # mark all file items as cached on non-main ranks + if not self.accelerator.is_main_process: + for dataset in all_datasets: + if hasattr(dataset, 'file_list'): + for file_item in dataset.file_list: + file_item.is_text_embedding_cached = True + print_acc("\n***** UNLOADING TEXT ENCODER *****") if self.is_caching_text_embeddings: - print_acc("Embeddings cached to disk. We dont need the text encoder anymore") + print_acc("Text embeddings cached to disk. Unloading text encoder.") else: - print_acc("This will train only with a blank prompt or trigger word, if set") - print_acc("If this is not what you want, remove the unload_text_encoder flag") + print_acc("WARNING: Text embedding caching is not enabled.") + print_acc("Training will use only blank prompt or trigger word.") print_acc("***********************************") print_acc("") # unload the text encoder - if self.is_caching_text_embeddings: - unload_text_encoder(self.sd) - else: - # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing. - # keep legacy usage for now. - self.sd.text_encoder_to("cpu") + unload_text_encoder(self.sd) flush() if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None: @@ -2090,6 +2112,10 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD if not self.is_grad_accumulation_step: + # Sync LoRA gradients across ranks before optimizer step. + # Must be outside accumulate() context and only on optimizer-step boundaries. + self.sync_network_gradients() + # fix this for multi params if self.train_config.optimizer != 'adafactor': if isinstance(self.params[0], dict): diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index 8b5fa796c..2b15e4368 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -8,6 +8,7 @@ import threading import time import signal +import torch AITK_Status = Literal["running", "stopped", "error", "completed"] @@ -26,20 +27,27 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): if self.job_id is None: raise Exception("AITK_JOB_ID not set") self.is_stopping = False + self.is_returning_to_queue = False # Create a thread pool for database operations self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) # Track all async tasks self._async_tasks = [] # Initialize the status self._run_async_operation(self._update_status("running", "Starting")) + self._last_speed_update = 0.0 self._stop_watcher_started = False - # self.start_stop_watcher(interval_sec=2.0) + self.start_stop_watcher(interval_sec=2.0) def start_stop_watcher(self, interval_sec: float = 5.0): """ Start a daemon thread that periodically checks should_stop() and terminates the process immediately when triggered. + In distributed mode, only rank 0 polls the DB; it kills the + entire process group so all ranks terminate. """ + # Only rank 0 should poll the DB to avoid SQLite locking issues + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return if getattr(self, "_stop_watcher_started", False): return self._stop_watcher_started = True @@ -52,26 +60,34 @@ def _stop_watcher_thread(self, interval_sec: float): while True: try: if self.should_stop(): - # Mark and update status (non-blocking; uses existing infra) self.is_stopping = True - self._run_async_operation( - self._update_status("stopped", "Job stopped (remote)") - ) - # Best-effort flush pending async ops - try: - asyncio.run(self.wait_for_all_async()) - except RuntimeError: - pass - # Try to stop DB thread pool quickly - try: - self.thread_pool.shutdown(wait=False, cancel_futures=True) - except TypeError: - self.thread_pool.shutdown(wait=False) print("") print("****************************************************") print(" Stop signal received; terminating process. ") print("****************************************************") - os.kill(os.getpid(), signal.SIGINT) + if self.accelerator.num_processes > 1: + # FSDP: don't kill — let end_step_hook broadcast the flag + # so all ranks save a temp checkpoint together. + # Don't shut down thread_pool here; on_error() still needs it. + # Force kill after 5 min as a last-resort fallback. + for _ in range(150): + time.sleep(2) + os.killpg(os.getpgid(os.getpid()), signal.SIGTERM) + else: + # Single GPU: update status, flush, then SIGINT. + self._run_async_operation( + self._update_status("stopped", "Job stopped") + ) + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + os.kill(os.getpid(), signal.SIGINT) + break # don't loop — avoid repeated signals interrupting save time.sleep(interval_sec) except Exception: time.sleep(interval_sec) @@ -85,6 +101,10 @@ def _run_async_operation(self, coro): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + # Prune completed tasks periodically to avoid unbounded growth + if len(self._async_tasks) > 100: + self._async_tasks = [t for t in self._async_tasks if not t.done()] + # Create a task and track it if loop.is_running(): task = asyncio.run_coroutine_threadsafe(coro, loop) @@ -106,6 +126,9 @@ def _db_connect(self): return conn def should_stop(self): + # In distributed mode, only rank 0 polls the DB to avoid SQLite locking + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return False def _check_stop(): with self._db_connect() as conn: cursor = conn.cursor() @@ -115,8 +138,11 @@ def _check_stop(): return False if stop is None else stop[0] == 1 return _check_stop() - + def should_return_to_queue(self): + # In distributed mode, only rank 0 polls the DB to avoid SQLite locking + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return False def _check_return_to_queue(): with self._db_connect() as conn: cursor = conn.cursor() @@ -128,6 +154,13 @@ def _check_return_to_queue(): return _check_return_to_queue() def maybe_stop(self): + # In distributed mode, only rank 0 checks stop signals. + # Non-rank-0 processes are terminated via the stop watcher's os.killpg(). + # We CANNOT use dist.broadcast() here because maybe_stop() is called from + # rank-0-only code paths (inside sample(), save(), sample_step_hook()) where + # other ranks have already diverged. A broadcast would deadlock. + if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process: + return if self.should_stop(): self._run_async_operation( self._update_status("stopped", "Job stopped")) @@ -137,6 +170,7 @@ def maybe_stop(self): self._run_async_operation( self._update_status("queued", "Job queued")) self.is_stopping = True + self.is_returning_to_queue = True raise Exception("Job returning to queue") async def _update_key(self, key, value): @@ -217,25 +251,20 @@ async def wait_for_all_async(self): def on_error(self, e: Exception): super(UITrainer, self).on_error(e) - if self.accelerator.is_main_process and not self.is_stopping: - self.update_status("error", str(e)) + if self.accelerator.is_main_process: + if self.is_returning_to_queue: + self.update_status("queued", "Job queued") + elif self.is_stopping: + self.update_status("stopped", "Job stopped") + else: + self.update_status("error", str(e)) self.update_db_key("step", self.last_save_step) - asyncio.run(self.wait_for_all_async()) + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass self.thread_pool.shutdown(wait=True) - def handle_timing_print_hook(self, timing_dict): - if "train_loop" not in timing_dict: - print("train_loop not found in timing_dict", timing_dict) - return - seconds_per_iter = timing_dict["train_loop"] - # determine iter/sec or sec/iter - if seconds_per_iter < 1: - iters_per_sec = 1 / seconds_per_iter - self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec") - else: - self.update_db_key( - "speed_string", f"{seconds_per_iter:.2f} sec/iter") - def done_hook(self): super(UITrainer, self).done_hook() self.update_status("completed", "Training completed") @@ -246,7 +275,29 @@ def done_hook(self): def end_step_hook(self): super(UITrainer, self).end_step_hook() self.update_step() - self.maybe_stop() + # Update speed_string every ~10 seconds using timer's rolling average + train_loop_timings = self.timer.timers.get('train_loop') + if train_loop_timings and (time.time() - self._last_speed_update) >= 10.0: + seconds_per_iter = sum(train_loop_timings) / len(train_loop_timings) + if seconds_per_iter <= 0: + pass + elif seconds_per_iter < 1: + self.update_db_key("speed_string", f"{1 / seconds_per_iter:.2f} iter/sec") + else: + self.update_db_key("speed_string", f"{seconds_per_iter:.2f} sec/iter") + self._last_speed_update = time.time() + # FSDP: broadcast stop flag so all ranks exit together for collective save. + # Only uses the watcher's is_stopping flag (no DB query in hot path). + if self.accelerator.num_processes > 1: + import torch.distributed as dist + flag = 1 if self.is_stopping else 0 + stop_tensor = torch.tensor([flag], device=self.accelerator.device) + dist.broadcast(stop_tensor, src=0) + if stop_tensor.item() == 1: + self.is_stopping = True + raise Exception("Job stopped") + else: + self.maybe_stop() def hook_before_model_load(self): super().hook_before_model_load() @@ -263,8 +314,6 @@ def hook_before_train_loop(self): self.maybe_stop() self.update_step() self.update_status("running", "Training") - self.timer.add_after_print_hook(self.handle_timing_print_hook) - def status_update_hook_func(self, string): self.update_status("running", string) @@ -287,9 +336,15 @@ def sample(self, step=None, is_first=False): self.maybe_stop() self.update_status("running", "Training") - def save(self, step=None): - self.maybe_stop() - self.update_status("running", "Saving model") - super().save(step) - self.maybe_stop() - self.update_status("running", "Training") + def save(self, step=None, is_temp=False): + if not is_temp: + # Under FSDP, skip pre-save maybe_stop() — it could throw on rank 0 + # before the collective save, deadlocking other ranks in full_tensor(). + if not self.use_fsdp: + self.maybe_stop() + self.update_status("running", "Saving model") + super().save(step, is_temp=is_temp) + if not is_temp: + if not self.use_fsdp: + self.maybe_stop() + self.update_status("running", "Training") diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0ef7077c2..d8eedf27e 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -106,6 +106,8 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No self.network_config = NetworkConfig(**network_config) else: self.network_config = None + # Detect FSDP early — needed for device state presets before model loading + self._will_use_fsdp = (self.accelerator.num_processes > 1 and self.network_config is not None) self.train_config = TrainConfig(**self.get_conf('train', {})) model_config = self.get_conf('model', {}) self.modules_being_trained: List[torch.nn.Module] = [] @@ -146,8 +148,10 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No self.dataset_configs: List[DatasetConfig] = [] self.params = [] - # add dataset text embedding cache to their config - if self.train_config.cache_text_embeddings: + # unload_text_encoder always implies cache_text_embeddings + if self.train_config.unload_text_encoder: + self.train_config.cache_text_embeddings = True + if self.train_config.cache_text_embeddings and raw_datasets is not None: for raw_dataset in raw_datasets: raw_dataset['cache_text_embeddings'] = True @@ -209,6 +213,9 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No # 'ratio', 0.25) # get the device state preset based on what we are training + # Under FSDP, unload TE during training to save GPU memory for + # the transformer forward/backward pass. TE is loaded on-demand for encoding. + fsdp_unload_te = getattr(self, '_will_use_fsdp', False) self.train_device_state_preset = get_train_sd_device_state_preset( device=self.device_torch, train_unet=self.train_config.train_unet, @@ -219,10 +226,10 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No train_embedding=self.embed_config is not None, train_decorator=self.decorator_config is not None, train_refiner=self.train_config.train_refiner, - unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings, + unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings or fsdp_unload_te, require_grads=False # we ensure them later ) - + self.get_params_device_state_preset = get_train_sd_device_state_preset( device=self.device_torch, train_unet=self.train_config.train_unet, @@ -233,7 +240,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No train_embedding=self.embed_config is not None, train_decorator=self.decorator_config is not None, train_refiner=self.train_config.train_refiner, - unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings, + unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings or fsdp_unload_te, require_grads=True # We check for grads when getting params ) @@ -267,7 +274,10 @@ def post_process_generate_image_config_list(self, generate_image_config_list: Li return generate_image_config_list def sample(self, step=None, is_first=False): - if not self.accelerator.is_main_process: + # Under FSDP, all ranks must participate in the forward pass because + # parameters are sharded. All ranks run the sampling pipeline together + # with identical inputs; only rank 0 saves images. + if not self.use_fsdp and not self.accelerator.is_main_process: return flush() sample_folder = os.path.join(self.save_root, 'samples') @@ -361,8 +371,13 @@ def sample(self, step=None, is_first=False): if self.adapter is not None and isinstance(self.adapter, CustomAdapter): self.adapter.is_sampling = True - # send to be generated - self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) + # send to be generated — under FSDP all ranks run forward, only rank 0 saves + self.sd.generate_images( + gen_img_config_list, + sampler=sample_config.sampler, + use_fsdp=self.use_fsdp, + is_main_process=self.accelerator.is_main_process, + ) if self.adapter is not None and isinstance(self.adapter, CustomAdapter): @@ -488,7 +503,116 @@ def done_hook(self): def end_step_hook(self): pass - def save(self, step=None): + def _save_fsdp(self, step=None, is_temp=False): + """FSDP-aware save: all ranks gather state dict, only rank 0 writes files. + + get_state_dict() calls full_tensor() on DTensors, which is a collective + op requiring all ranks. After gathering, only rank 0 does file I/O. + All ranks must wait for file I/O to complete before the next training + step — FSDP's forward pass (AllGather) is collective. + """ + # All ranks: gather the LoRA state dict (collective operation) + save_dict = None + if self.network is not None: + prev_multiplier = self.network.multiplier + self.network.multiplier = 1.0 + try: + embedding_dict = self.embedding.state_dict() if self.embedding else None + save_dict = self.network.get_state_dict( + extra_state_dict=embedding_dict, + dtype=get_torch_dtype(self.save_config.dtype), + ) + finally: + self.network.multiplier = prev_multiplier + + self.accelerator.wait_for_everyone() + + # All ranks: gather full optimizer state (collective op on DTensors). + # Must happen outside the rank 0 block so all ranks participate. + optim_state = None + if self.optimizer is not None: + if self.sd.network is not None: + # LoRA params live in the separate network module, not the + # FSDP-wrapped unet, so plain state_dict works. + optim_state = self.optimizer.state_dict() + else: + # get_optimizer_state_dict is a collective op — all ranks must call it. + # Do NOT wrap in try/except: asymmetric failure would deadlock. + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + optim_state = get_optimizer_state_dict(self.sd.unet, self.optimizer) + + # Only rank 0: write files (other ranks wait at the barrier below). + # try/finally ensures the final barrier is always reached even if + # file I/O fails — otherwise non-rank-0 would deadlock at the barrier. + try: + if self.accelerator.is_main_process: + flush() + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + save_meta = copy.deepcopy(self.meta) + save_meta = get_meta_for_safetensors(save_meta, self.job.name) + + file_path = None + if self.network is not None and save_dict is not None: + lora_name = self.job.name + if self.named_lora: + lora_name += '_LoRA' + filename = f'{lora_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + + if os.path.splitext(file_path)[1] == ".safetensors": + from safetensors.torch import save_file as sf_save + from toolkit.metadata import add_model_hash_to_meta + metadata = OrderedDict() + metadata = add_model_hash_to_meta(save_dict, metadata) + metadata.update(save_meta) + sf_save(save_dict, file_path, metadata) + else: + torch.save(save_dict, file_path) + + print_acc(f"Saved checkpoint to {file_path}") + + # Save standalone embedding file (even if also included in LoRA dict) + if self.embedding is not None: + emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors' + emb_file_path = os.path.join(self.save_root, emb_filename) + self.embedding.step = self.step_num + if self.embed_config.save_format == "pt": + emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt" + self.embedding.save(emb_file_path) + + # Save gathered optimizer state + if optim_state is not None: + try: + opt_file_path = os.path.join(self.save_root, 'optimizer.pt') + torch.save(optim_state, opt_file_path) + print_acc(f"Saved optimizer to {opt_file_path}") + except Exception as e: + print_acc(e) + print_acc("Could not save optimizer") + + self.clean_up_saves() + self.post_save_hook(file_path or self.save_root) + finally: + # All ranks: update step and wait for rank 0 to finish file I/O + if step is not None: + self.last_save_step = step + self.accelerator.wait_for_everyone() + + def save(self, step=None, is_temp=False): + # Under FSDP, get_state_dict() contains full_tensor() calls which are + # collective ops requiring all ranks. All ranks must enter save(), but + # only rank 0 does file I/O. Non-rank-0 processes participate in the + # collective gather inside network.save_weights() then return. + if self.use_fsdp: + self._save_fsdp(step, is_temp=is_temp) + return if not self.accelerator.is_main_process: return flush() @@ -698,6 +822,55 @@ def save(self, step=None): self.ema.train() flush() + def save_on_interrupt(self): + """Save a temporary checkpoint when training is interrupted so no steps are lost.""" + if not self.use_fsdp and not self.accelerator.is_main_process: + return + if self.use_fsdp: + # All ranks must agree on whether to save — collective ops require + # all-or-nothing participation. Rank 0 decides, broadcasts to all. + # Use a timeout: if ranks can't synchronize (e.g., one rank crashed + # or is stuck in a different collective), skip the save rather than + # deadlocking forever. + import torch.distributed as dist + try: + # Use a barrier with timeout rather than wait_for_everyone() to avoid + # indefinite hangs when only a subset of ranks enter this code path. + if dist.is_initialized(): + dist.barrier(device_ids=[torch.cuda.current_device()] if torch.cuda.is_available() else None) + should_save = 0 + if self.accelerator.is_main_process: + if self.step_num > self.last_save_step and self.step_num > self.start_step: + should_save = 1 + save_tensor = torch.tensor([should_save], device=self.accelerator.device) + dist.broadcast(save_tensor, src=0) + if save_tensor.item() == 0: + return + except Exception as e: + if self.accelerator.is_main_process: + print_acc(f"Warning: FSDP interrupt save coordination failed: {e}") + print_acc("Skipping interrupt save — ranks could not synchronize.") + return + else: + if self.step_num <= self.last_save_step: + return + if self.step_num <= self.start_step: + return + if self.accelerator.is_main_process: + print_acc(f"\nSaving interrupt checkpoint at step {self.step_num}") + try: + self.save(self.step_num) + except Exception as e: + if self.accelerator.is_main_process: + print_acc(f"Warning: Failed to save interrupt checkpoint: {e}") + + def on_error(self, e: Exception): + try: + self.save_on_interrupt() + except Exception as save_err: + print_acc(f"Warning: Failed to save on error: {save_err}") + super().on_error(e) + # Called before the model is loaded def hook_before_model_load(self): # override in subclass @@ -711,23 +884,116 @@ def hook_add_extra_train_params(self, params): # override in subclass return params + @property + def use_fsdp(self): + """Whether FSDP v2 parameter sharding is active for this training run. + Auto-enabled for multi-GPU + LoRA training when block classes are detected.""" + return getattr(self, '_fsdp_active', False) + def hook_before_train_loop(self): if self.accelerator.is_main_process: self.logger.start() + + # For multi-GPU LoRA, recreate accelerator with FSDP v2 plugin + # before prepare_accelerator() wraps models. + if (self.accelerator.num_processes > 1 and self.network_config is not None): + self._setup_fsdp_accelerator() + self.prepare_accelerator() - + + def _setup_fsdp_accelerator(self): + """Recreate the accelerator with FSDP v2 for parameter sharding.""" + from toolkit.fsdp_utils import create_fsdp_plugin, get_block_class_names + from toolkit.accelerator import reset_accelerator + + transformer = self.sd.unet + block_class_names = get_block_class_names(transformer, model=self.sd) + + if not block_class_names: + print_acc("WARNING: Could not detect transformer block classes for FSDP wrapping. " + "Falling back to standard DDP.") + # Clear the intent flag so device state presets and model loading + # don't behave as if FSDP is active. + self._will_use_fsdp = False + # Model was loaded to CPU for FSDP — move transformer to GPU for DDP. + if self.sd.unet is not None: + self.sd.unet.to(self.device_torch) + return + + print_acc(f"FSDP v2: sharding transformer across {self.accelerator.num_processes} GPUs") + print_acc(f" Block classes to wrap: {block_class_names}") + + plugin = create_fsdp_plugin(block_class_names) + self.accelerator = reset_accelerator(fsdp_plugin=plugin) + self.device = str(self.accelerator.device) + self.device_torch = self.accelerator.device + # Update stale accelerator references captured at init time + if hasattr(self, 'sd') and self.sd is not None: + self.sd.accelerator = self.accelerator + self._fsdp_active = True + def sample_step_hook(self, img_num, total_imgs): pass def prepare_accelerator(self): - # set some config - self.accelerator.even_batches=False - - # # prepare all the models stuff for accelerator (hopefully we dont miss any) + # Validate incompatible features with distributed training + if self.accelerator.num_processes > 1: + if self.model_config.split_model_over_gpus: + raise ValueError( + "split_model_over_gpus (model parallelism) cannot be combined with " + "multi-GPU distributed training. Use one or the other." + ) + if self.train_config.do_paramiter_swapping: + raise ValueError( + "Parameter swapping is not compatible with distributed training." + ) + if self.use_fsdp: + # Quantized transformer can't be FSDP-sharded (QTensors don't survive DTensor) + quant_flags = ['quantize', 'load_in_4bit', 'load_in_8bit'] + has_quant = any(getattr(self.model_config, f, False) for f in quant_flags) + if has_quant: + raise ValueError( + "Quantization of the transformer cannot be combined with FSDP v2. " + "FSDP already reduces per-GPU memory by sharding parameters." + ) + # quantize_te is OK — quantized TEs skip FSDP wrapping and go to GPU directly + if self.use_fsdp and self.train_config.train_text_encoder: + raise ValueError( + "Training text encoders is not yet supported with FSDP v2 multi-GPU. " + "Only LoRA on the transformer is supported." + ) + if self.use_fsdp: + has_offload = ( + getattr(self.model_config, 'layer_offloading', False) or + (self.network_config is not None and getattr(self.network_config, 'layer_offloading', False)) + ) + if has_offload: + raise ValueError( + "Layer offloading cannot be combined with FSDP v2. " + "FSDP shards parameters across GPUs — offloading them to CPU breaks the sharding contract." + ) + # Warn about bucket batching in distributed mode for non-LoRA training. + if self.network_config is None and self.datasets is not None: + has_buckets = any( + ds.get('buckets', False) if isinstance(ds, dict) else getattr(ds, 'buckets', False) + for ds in self.datasets + ) + if has_buckets: + print_acc("WARNING: Bucket batching with distributed training may cause shape " + "mismatches across ranks for non-LoRA training. This is safe for LoRA.") + + self.accelerator.even_batches = False + + if self.use_fsdp: + self._prepare_accelerator_fsdp() + else: + self._prepare_accelerator_standard() + + def _prepare_accelerator_standard(self): + """Standard DDP preparation — wraps all models with accelerator.prepare().""" self.sd.vae = self.accelerator.prepare(self.sd.vae) if self.sd.unet is not None: self.sd.unet = self.accelerator.prepare(self.sd.unet) - # todo always tdo it? self.modules_being_trained.append(self.sd.unet) if self.sd.text_encoder is not None and self.train_config.train_text_encoder: if isinstance(self.sd.text_encoder, list): @@ -739,22 +1005,91 @@ def prepare_accelerator(self): if self.sd.refiner_unet is not None and self.train_config.train_refiner: self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet) self.modules_being_trained.append(self.sd.refiner_unet) - # todo, do we need to do the network or will "unet" get it? if self.sd.network is not None: self.sd.network = self.accelerator.prepare(self.sd.network) self.modules_being_trained.append(self.sd.network) if self.adapter is not None and self.adapter_config.train: - # todo adapters may not be a module. need to check self.adapter = self.accelerator.prepare(self.adapter) self.modules_being_trained.append(self.adapter) - - # prepare other things + self.optimizer = self.accelerator.prepare(self.optimizer) if self.lr_scheduler is not None: self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler) - # self.data_loader = self.accelerator.prepare(self.data_loader) - # if self.data_loader_reg is not None: - # self.data_loader_reg = self.accelerator.prepare(self.data_loader_reg) + + def _prepare_accelerator_fsdp(self): + """FSDP v2 preparation — transformer and text encoders are FSDP-wrapped. + + Both are sharded across GPUs via FSDP, reducing per-GPU memory. + VAE is excluded (small, no gradients). + """ + # FSDP2 requires model and optimizer in the SAME prepare() call so + # Accelerate can rebind optimizer param_groups to sharded DTensors. + # Models are already on CPU (loaded there when _will_use_fsdp=True), + # so Accelerate's fsdp2_prepare_model() can: save state_dict on CPU, + # move to meta device, fully_shard(), then distribute sharded weights. + if self.sd.unet is not None: + self.sd.unet, self.optimizer = self.accelerator.prepare( + self.sd.unet, self.optimizer + ) + self.modules_being_trained.append(self.sd.unet) + + # Text encoders stay on CPU under FSDP — unloaded during training, + # loaded to GPU on-demand for encoding via set_device_state(). + # This keeps GPU memory free for the transformer forward/backward pass. + + # Update model's device references — they were set to CPU during + # construction but now the accelerator device is GPU after FSDP reset. + self.sd.device_torch = self.device_torch + self.sd.device = str(self.device_torch) + self.sd.vae_device_torch = self.device_torch + self.sd.te_device_torch = self.device_torch + + # Move VAE to GPU (small enough to fit without sharding) + if self.sd.vae is not None: + self.sd.vae = self.sd.vae.to(self.device_torch) + + # Network (LoRA) is part of the transformer, no separate prepare needed. + if self.sd.network is not None: + self.modules_being_trained.append(self.sd.network) + + # Optimizer already prepared above with model. + if self.lr_scheduler is not None: + self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler) + + + def sync_network_gradients(self): + """Manually sync LoRA network gradients across ranks in distributed training. + + Required because LoRA's forward-hook injection (monkey-patching org_module.forward) + bypasses DDP's bucket-based gradient reduction. The forward pass goes through the + parent model's layers, not through the DDP wrapper's forward(), so DDP's autograd + hooks don't fire correctly for LoRA parameters. + See: kohya-ss/sd-scripts PR #989 + + IMPORTANT: Only call on optimizer-step boundaries, NOT every accumulation step. + The accumulate() context uses no_sync() internally; calling all_reduce on every + step would defeat gradient accumulation. + + NOTE: If an OOM occurs on only some ranks, the ranks that succeeded will call + this while the failed ranks won't (they exit via exception). This would cause + a hang. In practice, all ranks process the same batch size on identical GPUs, + so OOM is almost always all-or-nothing. A full collective OOM broadcast would + be the robust fix but is deferred for now. + """ + if self.accelerator.num_processes <= 1: + return + if self.use_fsdp: + # FSDP handles gradient reduce-scatter automatically + return + if self.sd.network is None: + return + import torch.distributed as dist + network = unwrap_model(self.sd.network) + world_size = self.accelerator.num_processes + for param in network.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad /= world_size def ensure_params_requires_grad(self, force=False): @@ -840,8 +1175,26 @@ def get_latest_save_path(self, name=None, post=''): return latest_path def load_training_state_from_metadata(self, path): + # Under FSDP, all ranks must agree on step_num/start_step because + # save and sample triggers are collective operations. Rank 0 reads + # metadata from disk and broadcasts to other ranks. + if self.accelerator.num_processes > 1: + import torch.distributed as dist + if self.accelerator.is_main_process: + self._load_training_state_from_metadata_impl(path) + # broadcast step_num and epoch_num from rank 0 + step_tensor = torch.tensor([self.step_num, self.epoch_num], device=self.accelerator.device) + dist.broadcast(step_tensor, src=0) + if not self.accelerator.is_main_process: + self.step_num = int(step_tensor[0].item()) + self.epoch_num = int(step_tensor[1].item()) + self.start_step = self.step_num + return if not self.accelerator.is_main_process: return + self._load_training_state_from_metadata_impl(path) + + def _load_training_state_from_metadata_impl(self, path): if path is not None and self.network_config is not None and path == self.network_config.pretrained_lora_path: # dont load metadata from pretrained lora return @@ -1601,10 +1954,15 @@ def run(self): model_config_to_load.refiner_name_or_path = previous_refiner_save self.load_training_state_from_metadata(previous_refiner_save) + # When FSDP is planned, load model to CPU so accelerator.prepare() + # can shard without a GPU memory spike. FSDP will distribute sharded + # weights to GPUs during prepare(). + model_device = torch.device("cpu") if self._will_use_fsdp else self.accelerator.device + if self._will_use_fsdp: + print_acc(f"FSDP: Loading model to CPU (device={model_device})") + self.sd = ModelClass( - # todo handle single gpu and multi gpu here - # device=self.device, - device=self.accelerator.device, + device=model_device, model_config=model_config_to_load, dtype=self.train_config.dtype, custom_pipeline=self.custom_pipeline, @@ -1706,7 +2064,8 @@ def run(self): else: text_encoder.requires_grad_(False) text_encoder.eval() - unet.to(self.device_torch, dtype=dtype) + if not self._will_use_fsdp: + unet.to(self.device_torch, dtype=dtype) unet.requires_grad_(False) unet.eval() vae = vae.to(torch.device('cpu'), dtype=dtype) @@ -1980,44 +2339,20 @@ def run(self): # only works for adafactor, but it should have thrown an error prior to this otherwise self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor) - # check if it exists - optimizer_state_filename = f'optimizer.pt' - optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename) + # Resolve optimizer state file path. Actual loading is deferred until + # after FSDP wrapping so state buffers match sharded params. + optimizer_state_file_path = os.path.join(self.save_root, 'optimizer.pt') if os.path.exists(optimizer_state_file_path): - # try to load - # previous param groups - # previous_params = copy.deepcopy(optimizer.param_groups) - previous_lrs = [] - for group in optimizer.param_groups: - previous_lrs.append(group['lr']) - - load_optimizer = True - if self.network is not None: - if self.network.did_change_weights: - # do not load optimizer if the network changed, it will result in - # a double state that will oom. - load_optimizer = False - - if load_optimizer: - try: - print_acc(f"Loading optimizer state from {optimizer_state_file_path}") - optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) - optimizer.load_state_dict(optimizer_state_dict) - del optimizer_state_dict - flush() - except Exception as e: - print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") - print_acc(e) - - # update the optimizer LR from the params - print_acc(f"Updating optimizer LR from params") - if len(previous_lrs) > 0: - for i, group in enumerate(optimizer.param_groups): - group['lr'] = previous_lrs[i] - group['initial_lr'] = previous_lrs[i] + self._optimizer_state_file_path = optimizer_state_file_path + else: + self._optimizer_state_file_path = None - # Update the learning rates if they changed - # optimizer.param_groups = previous_params + if self._optimizer_state_file_path is not None: + self._load_optimizer_on_resume = True + if self.network is not None and self.network.did_change_weights: + # do not load optimizer if the network changed, it will result in + # a double state that will oom. + self._load_optimizer_on_resume = False lr_scheduler_params = self.train_config.lr_scheduler_params @@ -2036,16 +2371,47 @@ def run(self): self.before_dataset_load() # load datasets if passed in the root process if self.datasets is not None: - self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd) + self.data_loader = get_dataloader_from_datasets( + self.datasets, self.train_config.batch_size, self.sd, accelerator=self.accelerator) if self.datasets_reg is not None: - self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, - self.sd) + self.data_loader_reg = get_dataloader_from_datasets( + self.datasets_reg, self.train_config.batch_size, self.sd, accelerator=self.accelerator) flush() self.last_save_step = self.step_num ### HOOK ### self.hook_before_train_loop() + # Load optimizer state AFTER FSDP wrapping so state buffers match sharded params. + # For FSDP, use set_optimizer_state_dict (inverse of get_optimizer_state_dict used + # during save) to correctly scatter the gathered state to sharded DTensors. + # For non-FSDP, use plain load_state_dict as before. + # Re-bind local after prepare() which may return a new wrapped optimizer + optimizer = self.optimizer + + if getattr(self, '_optimizer_state_file_path', None) is not None and getattr(self, '_load_optimizer_on_resume', False): + previous_lrs = [group['lr'] for group in optimizer.param_groups] + try: + print_acc(f"Loading optimizer state from {self._optimizer_state_file_path}") + optimizer_state_dict = torch.load(self._optimizer_state_file_path, weights_only=True) + if self.use_fsdp and self.sd.network is None: + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + set_optimizer_state_dict(self.sd.unet, optimizer, optim_state_dict=optimizer_state_dict) + else: + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + flush() + except Exception as e: + print_acc(f"Failed to load optimizer state from {self._optimizer_state_file_path}") + print_acc(e) + + # Restore LR from config (user may have changed it between runs) + print_acc(f"Updating optimizer LR from params") + for i, group in enumerate(optimizer.param_groups): + if i < len(previous_lrs): + group['lr'] = previous_lrs[i] + group['initial_lr'] = previous_lrs[i] + # compile the model if needed (must be after LoRA/adapter injection AND accelerator.prepare) if self.model_config.compile: try: @@ -2057,6 +2423,21 @@ def run(self): print_acc(f"Failed to compile model: {e}") print_acc("Continuing without compilation") + # Log distributed training info (after hook_before_train_loop which sets up FSDP) + if self.accelerator.num_processes > 1 and self.accelerator.is_main_process: + grad_accum = max(self.train_config.gradient_accumulation_steps, self.train_config.gradient_accumulation) + effective_batch = self.train_config.batch_size * self.accelerator.num_processes * max(1, grad_accum) + print_acc(f"") + print_acc(f"========================================") + print_acc(f"Distributed training enabled") + print_acc(f" Mode: {'FSDP v2 (parameter sharding)' if self.use_fsdp else 'DDP (data parallel)'}") + print_acc(f" GPUs: {self.accelerator.num_processes}") + print_acc(f" Batch size per GPU: {self.train_config.batch_size}") + print_acc(f" Gradient accumulation: {max(1, grad_accum)}") + print_acc(f" Effective batch size: {effective_batch}") + print_acc(f"========================================") + print_acc(f"") + if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling: print_acc("Generating first sample from first sample config") self.sample(0, is_first=True) @@ -2160,7 +2541,7 @@ def run(self): if self.progress_bar is not None: self.progress_bar.pause() dataloader_iterator_reg = iter(dataloader_reg) - trigger_dataloader_setup_epoch(dataloader_reg) + trigger_dataloader_setup_epoch(dataloader_reg, self.epoch_num) with self.timer('get_batch:reg'): batch = next(dataloader_iterator_reg) @@ -2177,8 +2558,8 @@ def run(self): if self.progress_bar is not None: self.progress_bar.pause() dataloader_iterator = iter(dataloader) - trigger_dataloader_setup_epoch(dataloader) self.epoch_num += 1 + trigger_dataloader_setup_epoch(dataloader, self.epoch_num) if self.train_config.gradient_accumulation_steps == -1: # if we are accumulating for an entire epoch, trigger a step self.is_grad_accumulation_step = False @@ -2223,6 +2604,12 @@ def run(self): else: raise # not an OOM; surface real errors if did_oom: + # Under FSDP, OOM on one rank causes divergent execution paths — + # FSDP's internal reduce-scatter needs all ranks to participate. + # Re-raise immediately so all ranks fail together via os.killpg(). + if self.use_fsdp: + raise RuntimeError("OOM during FSDP training step. Cannot recover " + "from asymmetric OOM under FSDP — all ranks must fail together.") self.num_consecutive_oom += 1 if self.num_consecutive_oom > 3: raise RuntimeError("OOM during training step 3 times in a row, aborting training") @@ -2288,7 +2675,6 @@ def run(self): self.accelerator.wait_for_everyone() if is_save_step: - self.accelerator # print above the progress bar if self.progress_bar is not None: self.progress_bar.pause() @@ -2395,8 +2781,10 @@ def run(self): self.sample(self.step_num) self.logger.commit(step=self.step_num) print_acc("") + # save() must be called by all ranks under FSDP so all participate + # in the collective gather. The rank-0 guard is inside save() itself. + self.save() if self.accelerator.is_main_process: - self.save() self.logger.finish() self.accelerator.end_training() diff --git a/toolkit/accelerator.py b/toolkit/accelerator.py index 0736f0167..dad6d654a 100644 --- a/toolkit/accelerator.py +++ b/toolkit/accelerator.py @@ -10,6 +10,25 @@ def get_accelerator() -> Accelerator: global_accelerator = Accelerator() return global_accelerator + +def reset_accelerator(fsdp_plugin=None, **kwargs) -> Accelerator: + """Recreate the global accelerator with new configuration. + + Must be called before any accelerator.prepare() calls. Used to switch + from the default bare Accelerator to one configured with FSDP. + + The previous accelerator must not have been used for prepare() calls. + Only read-only operations (is_main_process, device) are safe before reset. + """ + global global_accelerator + if global_accelerator is not None: + # Release the old accelerator's resources. The process group is shared + # and will be reused by the new Accelerator instance. + del global_accelerator + global_accelerator = Accelerator(fsdp_plugin=fsdp_plugin, **kwargs) + return global_accelerator + + def unwrap_model(model): try: accelerator = get_accelerator() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 8e69d1712..559e5e65f 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1366,7 +1366,9 @@ def validate_configs( # check if they are doing differential output preservation if train_config.diff_output_preservation: - raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.") + raise ValueError("Cannot use differential output preservation with caching text embeddings. " + "Note: unload_text_encoder automatically enables text embedding caching. " + "Please set diff_output_preservation to False.") # make sure they are all cached for dataset in dataset_configs: diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 2d96f1ee0..36bf00fa0 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -12,7 +12,8 @@ from PIL import Image from PIL.ImageOps import exif_transpose from torchvision import transforms -from torch.utils.data import Dataset, DataLoader, ConcatDataset +from torch.utils.data import Dataset, DataLoader, ConcatDataset, WeightedRandomSampler +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm import albumentations as A @@ -592,8 +593,8 @@ def setup_epoch(self): self.cache_latents_all_latents() if self.is_caching_clip_vision_to_disk: self.cache_clip_vision_to_disk() - if self.is_caching_text_embeddings: - self.cache_text_embeddings() + # text embedding caching is deferred to hook_before_train_loop + # where the text encoder is guaranteed to be on GPU if self.is_generating_controls: # always do this last self.setup_controls() @@ -630,14 +631,89 @@ def __getitem__(self, item): return self._get_single_item(item) +def _build_weighted_sampler(datasets, concatenated_dataset): + """Build WeightedRandomSampler if any dataset has sampling_weight set.""" + has_weights = any(ds.dataset_config.sampling_weight is not None for ds in datasets) + if not has_weights: + return None + + weights = [ds.dataset_config.sampling_weight if ds.dataset_config.sampling_weight is not None else 1.0 for ds in datasets] + if any(w < 0 for w in weights): + raise ValueError("sampling_weight values must be non-negative") + total_weight = sum(weights) + if total_weight <= 0: + raise ValueError("sampling_weight values must sum to a positive number") + + total_items = len(concatenated_dataset) + + sample_weights = [] + for ds, w in zip(datasets, weights): + n = len(ds) + per_item_w = (w / total_weight) / (n / total_items) if n > 0 else 0.0 + sample_weights.extend([per_item_w] * n) + + for ds, w in zip(datasets, weights): + print_acc(f"Dataset '{ds.dataset_config.folder_path}': sampling_weight={w} ({w/total_weight*100:.1f}%)") + + return WeightedRandomSampler(sample_weights, num_samples=total_items, replacement=True) + + +def _build_oversampled_concat_dataset(datasets): + """Build a ConcatDataset with oversampling baked in to preserve weighted ratios. + + When using DistributedSampler (which shards uniformly), we can't use + WeightedRandomSampler. Instead, we repeat dataset entries according to + their sampling_weight so the desired ratio is baked into the dataset itself. + DistributedSampler then shards the pre-weighted data, preserving ratios. + """ + has_weights = any(ds.dataset_config.sampling_weight is not None for ds in datasets) + if not has_weights: + return ConcatDataset(datasets), False + + weights = [ds.dataset_config.sampling_weight if ds.dataset_config.sampling_weight is not None else 1.0 for ds in datasets] + if any(w < 0 for w in weights): + raise ValueError("sampling_weight values must be non-negative") + # Filter out datasets with zero weight + filtered = [(ds, w) for ds, w in zip(datasets, weights) if w > 0] + if not filtered: + raise ValueError("sampling_weight values must sum to a positive number") + datasets, weights = zip(*filtered) + datasets, weights = list(datasets), list(weights) + total_weight = sum(weights) + + # Compute repeat factors to match desired sampling ratios. + # For each dataset: desired_fraction / current_fraction tells us how much + # to oversample. Normalize so the minimum repeat is 1 (no undersampling). + normalized = [w / total_weight for w in weights] + lens = [len(ds) for ds in datasets] + total_items = sum(lens) + + # ratio_i = (desired_fraction_i) / (natural_fraction_i) + # = (w_i / total_weight) / (len_i / total_items) + ratios = [(n * total_items) / (l if l > 0 else 1) for n, l in zip(normalized, lens)] + min_ratio = min(r for r in ratios if r > 0) if any(r > 0 for r in ratios) else 1.0 + repeat_factors = [max(1, round(r / min_ratio)) for r in ratios] + + oversampled_datasets = [] + for ds, repeat, w in zip(datasets, repeat_factors, weights): + for _ in range(repeat): + oversampled_datasets.append(ds) + print_acc(f"Dataset '{ds.dataset_config.folder_path}': sampling_weight={w} ({w/total_weight*100:.1f}%), repeat={repeat}x for distributed") + + return ConcatDataset(oversampled_datasets), True + + def get_dataloader_from_datasets( dataset_options, batch_size=1, sd: 'StableDiffusion' = None, + accelerator=None, ) -> DataLoader: if dataset_options is None or len(dataset_options) == 0: return None + is_distributed = accelerator is not None and accelerator.num_processes > 1 + datasets = [] has_buckets = False is_caching_latents = False @@ -665,7 +741,21 @@ def get_dataloader_from_datasets( else: raise ValueError(f"invalid dataset type: {config.type}") - concatenated_dataset = ConcatDataset(datasets) + # In distributed mode, bake weighted sampling into dataset via oversampling + # (can't use WeightedRandomSampler with DistributedSampler) + if is_distributed: + concatenated_dataset, used_oversampling = _build_oversampled_concat_dataset(datasets) + sampler = DistributedSampler( + concatenated_dataset, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + print_acc(f"Distributed training: rank {accelerator.process_index}/{accelerator.num_processes}, " + f"dataset size per rank: ~{len(concatenated_dataset) // accelerator.num_processes}") + else: + concatenated_dataset = ConcatDataset(datasets) + sampler = _build_weighted_sampler(datasets, concatenated_dataset) # todo build scheduler that can get buckets from all datasets that match # todo and evenly distribute reg images @@ -696,7 +786,8 @@ def dto_collation(batch: List['FileItemDTO']): concatenated_dataset, batch_size=None, # we batch in the datasets for now drop_last=False, - shuffle=True, + shuffle=(sampler is None), + sampler=sampler, collate_fn=dto_collation, # Use the custom collate function **dataloader_kwargs ) @@ -704,24 +795,35 @@ def dto_collation(batch: List['FileItemDTO']): data_loader = DataLoader( concatenated_dataset, batch_size=batch_size, - shuffle=True, + shuffle=(sampler is None), + sampler=sampler, collate_fn=dto_collation, **dataloader_kwargs ) return data_loader -def trigger_dataloader_setup_epoch(dataloader: DataLoader): +def trigger_dataloader_setup_epoch(dataloader: DataLoader, epoch_num: int = 0): # hacky but needed because of different types of datasets and dataloaders dataloader.len = None + + # Update DistributedSampler epoch for proper shuffling across ranks + if hasattr(dataloader, 'sampler') and isinstance(dataloader.sampler, DistributedSampler): + dataloader.sampler.set_epoch(epoch_num) + + # Use a seen set to avoid calling setup_epoch multiple times on the same + # dataset object (happens when oversampling repeats dataset references). + seen = set() if isinstance(dataloader.dataset, list): for dataset in dataloader.dataset: if hasattr(dataset, 'datasets'): for sub_dataset in dataset.datasets: - if hasattr(sub_dataset, 'setup_epoch'): + if id(sub_dataset) not in seen and hasattr(sub_dataset, 'setup_epoch'): + seen.add(id(sub_dataset)) sub_dataset.setup_epoch() sub_dataset.len = None - elif hasattr(dataset, 'setup_epoch'): + elif id(dataset) not in seen and hasattr(dataset, 'setup_epoch'): + seen.add(id(dataset)) dataset.setup_epoch() dataset.len = None elif hasattr(dataloader.dataset, 'setup_epoch'): @@ -730,7 +832,8 @@ def trigger_dataloader_setup_epoch(dataloader: DataLoader): elif hasattr(dataloader.dataset, 'datasets'): dataloader.dataset.len = None for sub_dataset in dataloader.dataset.datasets: - if hasattr(sub_dataset, 'setup_epoch'): + if id(sub_dataset) not in seen and hasattr(sub_dataset, 'setup_epoch'): + seen.add(id(sub_dataset)) sub_dataset.setup_epoch() sub_dataset.len = None diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 55bbff88f..a3c60e1eb 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -40,7 +40,9 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO from toolkit.stable_diffusion_model import StableDiffusion -accelerator = get_accelerator() +def _get_acc(): + """Lazy accessor to avoid capturing a stale accelerator reference.""" + return get_accelerator() # def get_associated_caption_from_img_path(img_path): # https://demo.albumentations.ai/ @@ -1852,7 +1854,7 @@ def __init__(self: 'AiToolkitDataset', **kwargs): self.latent_cache = {} def cache_latents_all_latents(self: 'AiToolkitDataset'): - with accelerator.main_process_first(): + with _get_acc().main_process_first(): print_acc(f"Caching latents for {self.dataset_path}") # cache all latents to disk to_disk = self.is_caching_latents_to_disk @@ -2020,7 +2022,7 @@ def __init__(self: 'AiToolkitDataset', **kwargs): self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings def cache_text_embeddings(self: 'AiToolkitDataset'): - with accelerator.main_process_first(): + with _get_acc().main_process_first(): print_acc(f"Caching text_embeddings for {self.dataset_path}") print_acc(" - Saving text embeddings to disk") diff --git a/toolkit/fsdp_utils.py b/toolkit/fsdp_utils.py new file mode 100644 index 000000000..509d4a0ce --- /dev/null +++ b/toolkit/fsdp_utils.py @@ -0,0 +1,98 @@ +"""FSDP v2 utilities for multi-GPU parameter sharding. + +When multi-GPU + LoRA training is detected, FSDP v2 is used to shard the frozen +base model parameters across GPUs, reducing per-GPU memory. Only the transformer +is FSDP-wrapped; VAE and text encoders are excluded. +""" + +import torch.nn as nn +from typing import List + +# Common transformer block attribute names across diffusion model architectures. +# Used as a fallback when the model doesn't implement get_transformer_block_names(). +# Follows the same pattern as HuggingFace finetrainers. +KNOWN_BLOCK_ATTR_NAMES = [ + "transformer_blocks", + "single_transformer_blocks", + "double_stream_blocks", + "single_stream_blocks", + "double_blocks", + "single_blocks", + "temporal_transformer_blocks", + "blocks", + "layers", +] + + +def get_block_class_names( + transformer: nn.Module, + model=None, +) -> List[str]: + """Introspect the transformer to find its block class names for FSDP wrapping. + + First tries model.get_transformer_block_names() if available, then falls back + to scanning known attribute names on the transformer module. + + Args: + transformer: The transformer/unet module to introspect. + model: The parent model object (e.g. StableDiffusion) that may have + get_transformer_block_names(). + + Returns: + List of unique class name strings for FSDP transformer-based wrapping. + """ + block_attr_names = None + + # Try the model's declared block names first + if model is not None and hasattr(model, "get_transformer_block_names"): + block_attr_names = model.get_transformer_block_names() + + # Fallback: scan known attribute names on the transformer + if not block_attr_names: + block_attr_names = [] + for attr_name in KNOWN_BLOCK_ATTR_NAMES: + blocks = getattr(transformer, attr_name, None) + if blocks is not None and isinstance(blocks, nn.ModuleList) and len(blocks) > 0: + block_attr_names.append(attr_name) + + # Extract unique class names from the discovered block lists + class_names = set() + for attr_name in block_attr_names: + blocks = getattr(transformer, attr_name, None) + if blocks is None: + continue + if isinstance(blocks, nn.ModuleList) and len(blocks) > 0: + class_names.add(type(blocks[0]).__name__) + elif isinstance(blocks, nn.Module): + # Some models have sub-modules that contain blocks rather than + # being ModuleLists directly. Check for nested ModuleLists. + for child_name, child in blocks.named_children(): + if isinstance(child, nn.ModuleList) and len(child) > 0: + class_names.add(type(child[0]).__name__) + break + + return list(class_names) + + +def create_fsdp_plugin(transformer_block_class_names: List[str]): + """Create an Accelerate FSDP v2 plugin for parameter sharding. + + Args: + transformer_block_class_names: Class names of transformer blocks to wrap + as individual FSDP units (e.g. ["FluxTransformerBlock", "FluxSingleTransformerBlock"]). + + Returns: + FullyShardedDataParallelPlugin configured for FSDP v2 with FULL_SHARD. + """ + from accelerate import FullyShardedDataParallelPlugin + + plugin = FullyShardedDataParallelPlugin( + fsdp_version=2, + auto_wrap_policy="transformer_based_wrap", + transformer_cls_names_to_wrap=transformer_block_class_names, + reshard_after_forward=True, # FULL_SHARD: shard params after forward for max memory savings + activation_checkpointing=False, # toolkit handles this separately + cpu_ram_efficient_loading=True, + state_dict_type="FULL_STATE_DICT", # needed for LoRA weight extraction + ) + return plugin diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 9e4e8d1af..ccc41b560 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -367,6 +367,8 @@ def generate_images( sampler=None, pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, + use_fsdp=False, + is_main_process=True, ): network = self.network merge_multiplier = 1.0 @@ -391,11 +393,12 @@ def generate_images( if network is not None: network = unwrap_model(self.network) network.eval() - # check if we have the same network weight for all samples. If we do, we can merge in th - # the network to drastically speed up inference + # check if we have the same network weight for all samples. If we do, we can merge in + # the network to drastically speed up inference. + # Under FSDP, skip merge_in — it directly mutates DTensor shards. unique_network_weights = set( [x.network_multiplier for x in image_configs]) - if len(unique_network_weights) == 1 and network.can_merge_in: + if len(unique_network_weights) == 1 and network.can_merge_in and not use_fsdp: can_merge_in = True merge_multiplier = unique_network_weights.pop() network.merge_in(merge_weight=merge_multiplier) @@ -403,7 +406,8 @@ def generate_images( network = BlankNetwork() self.save_device_state() - self.set_device_state_preset('generate') + if not use_fsdp: + self.set_device_state_preset('generate') # save current seed state for training rng_state = torch.get_rng_state() @@ -415,6 +419,18 @@ def generate_images( pipeline.set_progress_bar_config(disable=True) except: pass + if use_fsdp: + # Under FSDP: restore the FSDP-wrapped model so all-gather + # hooks fire during forward. Also null out TEs/tokenizers since + # we always use pre-cached prompt_embeds. + for attr in ['transformer', 'unet']: + if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None: + setattr(pipeline, attr, self.unet) + break + for attr in ['text_encoder', 'text_encoder_2', 'text_encoder_3', + 'tokenizer', 'tokenizer_2', 'tokenizer_3']: + if hasattr(pipeline, attr): + setattr(pipeline, attr, None) start_multiplier = 1.0 if network is not None: @@ -427,7 +443,7 @@ def generate_images( if network is not None: assert network.is_active - for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False, disable=not is_main_process): gen_config = image_configs[i] extra = {} @@ -659,9 +675,10 @@ def generate_images( extra, ) - gen_config.save_image(img, i) - gen_config.log_image(img, i) - self._after_sample_image(i, len(image_configs)) + if is_main_process: + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) flush() if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): @@ -676,12 +693,14 @@ def generate_images( if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - self.restore_device_state() + if not use_fsdp: + self.restore_device_state() if network is not None: network.train() network.multiplier = start_multiplier - self.unet.to(self.device_torch, dtype=self.torch_dtype) + if not use_fsdp: + self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) @@ -1566,12 +1585,15 @@ def set_device_state_preset(self, device_state_preset: DeviceStatePreset): self.set_device_state(state) def text_encoder_to(self, *args, **kwargs): + if self.text_encoder is None: + return if isinstance(self.text_encoder, list): for encoder in self.text_encoder: encoder.to(*args, **kwargs) else: self.text_encoder.to(*args, **kwargs) - + + def convert_lora_weights_before_save(self, state_dict): # can be overridden in child classes to convert weights before saving return state_dict diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py index bf5a88b89..c15167963 100644 --- a/toolkit/models/wan21/wan21_i2v.py +++ b/toolkit/models/wan21/wan21_i2v.py @@ -347,6 +347,8 @@ def generate_images( image_configs, sampler=None, pipeline=None, + use_fsdp=False, + is_main_process=True, ): # will oom on 24gb vram if we dont unload vision encoder first if self.model_config.low_vram: @@ -359,6 +361,8 @@ def generate_images( image_configs, sampler=sampler, pipeline=pipeline, + use_fsdp=use_fsdp, + is_main_process=is_main_process, ) def set_device_state_preset(self, *args, **kwargs): diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 4546e573e..931359ee6 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -536,6 +536,12 @@ def get_state_dict(self: Network, extra_state_dict=None, dtype=torch.float16): for key in list(state_dict.keys()): v = state_dict[key] + # Under FSDP v2, parameters are DTensors sharded across ranks. + # full_tensor() is a collective op (all ranks must call it) that + # gathers the full parameter. The caller must ensure all ranks + # enter get_state_dict() together. + if hasattr(v, 'full_tensor'): + v = v.full_tensor() v = v.detach().clone().to("cpu").to(dtype) save_key = save_keymap[key] if key in save_keymap else key save_dict[save_key] = v diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 2f1030c20..19ff23500 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1131,6 +1131,8 @@ def generate_images( image_configs: List[GenerateImageConfig], sampler=None, pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, + use_fsdp=False, + is_main_process=True, ): network = unwrap_model(self.network) merge_multiplier = 1.0 @@ -1153,11 +1155,13 @@ def generate_images( if network is not None: network.eval() - # check if we have the same network weight for all samples. If we do, we can merge in th - # the network to drastically speed up inference + # check if we have the same network weight for all samples. If we do, we can merge in + # the network to drastically speed up inference. + # Under FSDP, skip merge_in — it directly mutates org_module weights which are + # DTensor shards; the LoRA hooks still apply correctly via forward hooks. unique_network_weights = set([x.network_multiplier for x in image_configs]) - if len(unique_network_weights) == 1 and network.can_merge_in: - # make sure it is on device before merging. + if len(unique_network_weights) == 1 and network.can_merge_in and not use_fsdp: + # make sure it is on device before merging. self.unet.to(self.device_torch) can_merge_in = True merge_multiplier = unique_network_weights.pop() @@ -1166,7 +1170,9 @@ def generate_images( network = BlankNetwork() self.save_device_state() - self.set_device_state_preset('generate') + if not use_fsdp: + # Under FSDP, skip — 'generate' preset moves TEs to GPU but we use cached embeddings + self.set_device_state_preset('generate') # save current seed state for training rng_state = torch.get_rng_state() @@ -1238,23 +1244,34 @@ def generate_images( pipeline = Pipe( vae=self.vae, unet=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], + text_encoder=None if use_fsdp else self.text_encoder[0], + text_encoder_2=None if use_fsdp else self.text_encoder[1], + tokenizer=None if use_fsdp else self.tokenizer[0], + tokenizer_2=None if use_fsdp else self.tokenizer[1], scheduler=noise_scheduler, **extra_args - ).to(self.device_torch) + ) + if not use_fsdp: + pipeline = pipeline.to(self.device_torch) pipeline.watermark = None elif self.is_flux: + # Under FSDP, pass the wrapped transformer directly (don't unwrap — it + # would strip FSDP sharding). TEs are not on GPU under FSDP; pass None + # since we always use pre-cached prompt_embeds. + transformer = self.unet if use_fsdp else unwrap_model(self.unet) + te0 = None if use_fsdp else unwrap_model(self.text_encoder[0]) + te1 = None if use_fsdp else unwrap_model(self.text_encoder[1]) + tok0 = None if use_fsdp else self.tokenizer[0] + tok1 = None if use_fsdp else self.tokenizer[1] + if self.model_config.use_flux_cfg: pipeline = FluxWithCFGPipeline( vae=self.vae, - transformer=unwrap_model(self.unet), - text_encoder=unwrap_model(self.text_encoder[0]), - text_encoder_2=unwrap_model(self.text_encoder[1]), - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], + transformer=transformer, + text_encoder=te0, + text_encoder_2=te1, + tokenizer=tok0, + tokenizer_2=tok1, scheduler=noise_scheduler, **extra_args ) @@ -1267,25 +1284,25 @@ def generate_images( Pipe = FluxAdvancedControlPipeline extra_args['do_inpainting'] = self.adapter.config.has_inpainting_input extra_args['num_controls'] = self.adapter.config.num_control_images - + pipeline = Pipe( vae=self.vae, - transformer=unwrap_model(self.unet), - text_encoder=unwrap_model(self.text_encoder[0]), - text_encoder_2=unwrap_model(self.text_encoder[1]), - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], + transformer=transformer, + text_encoder=te0, + text_encoder_2=te1, + tokenizer=tok0, + tokenizer_2=tok1, scheduler=noise_scheduler, **extra_args ) - + pipeline.watermark = None elif self.is_lumina2: pipeline = Lumina2Pipeline( vae=self.vae, transformer=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, + text_encoder=None if use_fsdp else self.text_encoder, + tokenizer=None if use_fsdp else self.tokenizer, scheduler=noise_scheduler, **extra_args ) @@ -1293,12 +1310,12 @@ def generate_images( pipeline = Pipe( vae=self.vae, transformer=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], - text_encoder_3=self.text_encoder[2], - tokenizer=self.tokenizer[0], - tokenizer_2=self.tokenizer[1], - tokenizer_3=self.tokenizer[2], + text_encoder=None if use_fsdp else self.text_encoder[0], + text_encoder_2=None if use_fsdp else self.text_encoder[1], + text_encoder_3=None if use_fsdp else self.text_encoder[2], + tokenizer=None if use_fsdp else self.tokenizer[0], + tokenizer_2=None if use_fsdp else self.tokenizer[1], + tokenizer_3=None if use_fsdp else self.tokenizer[2], scheduler=noise_scheduler, **extra_args ) @@ -1306,8 +1323,8 @@ def generate_images( pipeline = PixArtSigmaPipeline( vae=self.vae, transformer=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, + text_encoder=None if use_fsdp else self.text_encoder, + tokenizer=None if use_fsdp else self.tokenizer, scheduler=noise_scheduler, **extra_args ) @@ -1316,8 +1333,8 @@ def generate_images( pipeline = AuraFlowPipeline( vae=self.vae, transformer=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, + text_encoder=None if use_fsdp else self.text_encoder, + tokenizer=None if use_fsdp else self.tokenizer, scheduler=noise_scheduler, **extra_args ) @@ -1326,8 +1343,8 @@ def generate_images( pipeline = Pipe( vae=self.vae, unet=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, + text_encoder=None if use_fsdp else self.text_encoder, + tokenizer=None if use_fsdp else self.tokenizer, scheduler=noise_scheduler, safety_checker=None, feature_extractor=None, @@ -1342,8 +1359,8 @@ def generate_images( pipeline.set_scheduler(sampler) refiner_pipeline = None - if self.refiner_unet: - # build refiner pipeline + if self.refiner_unet and not use_fsdp: + # build refiner pipeline (refiner not supported under FSDP) refiner_pipeline = StableDiffusionXLImg2ImgPipeline( vae=pipeline.vae, unet=self.refiner_unet, @@ -1371,7 +1388,7 @@ def generate_images( if network is not None: assert network.is_active - for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False, disable=not is_main_process): gen_config = image_configs[i] extra = {} @@ -1454,7 +1471,13 @@ def generate_images( if self.sample_prompts_cache is not None: conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) - else: + elif use_fsdp: + raise RuntimeError( + "FSDP sampling requires pre-cached sample prompt embeddings, " + "but sample_prompts_cache is None. Ensure sampling prompts are " + "configured and text embeddings were cached at startup." + ) + else: # encode the prompt ourselves so we can do fun stuff with embeddings if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False @@ -1708,9 +1731,10 @@ def callback_on_step_end(pipe, i, t, callback_kwargs): generator=generator, ).images[0] - gen_config.save_image(img, i) - gen_config.log_image(img, i) - self._after_sample_image(i, len(image_configs)) + if is_main_process: + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) flush() if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): @@ -1727,12 +1751,15 @@ def callback_on_step_end(pipe, i, t, callback_kwargs): if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - self.restore_device_state() + if not use_fsdp: + self.restore_device_state() if network is not None: network.train() network.multiplier = start_multiplier - self.unet.to(self.device_torch, dtype=self.torch_dtype) + if not use_fsdp: + # Under FSDP, unet is already on device and managed by FSDP — don't manually move + self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) @@ -3109,12 +3136,15 @@ def set_device_state_preset(self, device_state_preset: DeviceStatePreset): self.set_device_state(state) def text_encoder_to(self, *args, **kwargs): + if self.text_encoder is None: + return if isinstance(self.text_encoder, list): for encoder in self.text_encoder: encoder.to(*args, **kwargs) else: self.text_encoder.to(*args, **kwargs) - + + def convert_lora_weights_before_save(self, state_dict): # can be overridden in child classes to convert weights before saving return state_dict diff --git a/ui/cron/actions/processQueue.ts b/ui/cron/actions/processQueue.ts index 175f613ed..2e1796e13 100644 --- a/ui/cron/actions/processQueue.ts +++ b/ui/cron/actions/processQueue.ts @@ -3,6 +3,44 @@ import prisma from '../prisma'; import { Job, Queue } from '@prisma/client'; import startJob from './startJob'; +/** + * Parse gpu_ids string into a set of individual GPU IDs. + * e.g., "0,1" -> Set{"0", "1"}, "0" -> Set{"0"} + */ +function parseGpuIds(gpuIds: string): Set { + return new Set( + gpuIds + .split(',') + .map(s => s.trim()) + .filter(s => s.length > 0), + ); +} + +/** + * Check if two gpu_ids strings have any overlapping GPUs. + */ +function gpuIdsOverlap(a: string, b: string): boolean { + const setA = parseGpuIds(a); + for (const gpu of parseGpuIds(b)) { + if (setA.has(gpu)) return true; + } + return false; +} + +/** + * Check if two gpu_ids strings represent the same set of GPUs + * (order-independent equality). + */ +function gpuIdsEqual(a: string, b: string): boolean { + const setA = parseGpuIds(a); + const setB = parseGpuIds(b); + if (setA.size !== setB.size) return false; + for (const gpu of setA) { + if (!setB.has(gpu)) return false; + } + return true; +} + export default async function processQueue() { const queues: Queue[] = await prisma.queue.findMany({ orderBy: { @@ -10,17 +48,37 @@ export default async function processQueue() { }, }); + // Build a set of all occupied GPU IDs from currently running/stopping jobs. + // This prevents multi-GPU jobs from conflicting with single-GPU jobs and vice versa. + const allRunningJobs: Job[] = await prisma.job.findMany({ + where: { + status: { in: ['running', 'stopping'] }, + }, + }); + + const occupiedGpus = new Set(); + for (const rj of allRunningJobs) { + for (const gpu of parseGpuIds(rj.gpu_ids)) { + occupiedGpus.add(gpu); + } + } + for (const queue of queues) { if (!queue.is_running) { - // stop any running jobs first - const runningJobs: Job[] = await prisma.job.findMany({ - where: { - status: 'running', - gpu_ids: queue.gpu_ids, - }, - }); + // stop any running jobs whose GPUs overlap with this queue + const runningJobs: Job[] = allRunningJobs.filter( + j => gpuIdsOverlap(j.gpu_ids, queue.gpu_ids) && j.status === 'running', + ); for (const job of runningJobs) { + // Don't stop a job if it belongs to a different queue that IS running. + // e.g., a stopped queue for GPU "0" must not kill a multi-GPU job on "0,1" + // that is managed by an active queue for "0,1". + const belongsToActiveQueue = queues.some( + q => q.id !== queue.id && q.is_running && gpuIdsOverlap(q.gpu_ids, job.gpu_ids), + ); + if (belongsToActiveQueue) continue; + console.log(`Stopping job ${job.id} on GPU(s) ${job.gpu_ids}`); await prisma.job.update({ where: { id: job.id }, @@ -32,38 +90,53 @@ export default async function processQueue() { } } if (queue.is_running) { - // first see if one is already running, status of running or stopping - const runningJob: Job | null = await prisma.job.findFirst({ - where: { - status: { in: ['running', 'stopping'] }, - gpu_ids: queue.gpu_ids, - }, - }); + // Check if any running job uses GPUs that overlap with this queue + const hasOverlappingRunningJob = allRunningJobs.some( + rj => (rj.status === 'running' || rj.status === 'stopping') && gpuIdsOverlap(rj.gpu_ids, queue.gpu_ids), + ); - if (runningJob) { - // already running, nothing to do - continue; // skip to next queue + if (hasOverlappingRunningJob) { + // GPUs are busy, skip to next queue + continue; } else { - // find the next job in the queue - const nextJob: Job | null = await prisma.job.findFirst({ + // find the next job in the queue whose GPUs exactly match this queue + const queuedJobs: Job[] = await prisma.job.findMany({ where: { status: 'queued', - gpu_ids: queue.gpu_ids, }, orderBy: { queue_position: 'asc', }, }); + const nextJob: Job | null = queuedJobs.find(j => gpuIdsEqual(j.gpu_ids, queue.gpu_ids)) ?? null; if (nextJob) { - console.log(`Starting job ${nextJob.id} on GPU(s) ${nextJob.gpu_ids}`); - await startJob(nextJob.id); + // Verify all GPUs needed by this job are currently free + const jobGpus = parseGpuIds(nextJob.gpu_ids); + const allFree = [...jobGpus].every(gpu => !occupiedGpus.has(gpu)); + + if (allFree) { + console.log(`Starting job ${nextJob.id} on GPU(s) ${nextJob.gpu_ids}`); + await startJob(nextJob.id); + // Mark these GPUs as occupied for remaining queue iterations + for (const gpu of jobGpus) { + occupiedGpus.add(gpu); + } + } else { + console.log(`Job ${nextJob.id} needs GPU(s) ${nextJob.gpu_ids} but some are occupied, skipping`); + } } else { - // no more jobs, stop the queue - console.log(`No more jobs in queue for GPU(s) ${queue.gpu_ids}, stopping queue`); - await prisma.queue.update({ - where: { id: queue.id }, - data: { is_running: false }, - }); + // No queued jobs for this queue. Only auto-stop if there are also no + // running/stopping jobs that belong to THIS queue — otherwise the queue + // must stay active so that stopped single-GPU queues don't kill its job + // (the belongsToActiveQueue check relies on the queue being is_running). + const hasActiveJob = allRunningJobs.some(rj => gpuIdsEqual(rj.gpu_ids, queue.gpu_ids)); + if (!hasActiveJob) { + console.log(`No more jobs in queue for GPU(s) ${queue.gpu_ids}, stopping queue`); + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: false }, + }); + } } } } diff --git a/ui/cron/actions/startJob.ts b/ui/cron/actions/startJob.ts index a6e3983bb..8778e53dc 100644 --- a/ui/cron/actions/startJob.ts +++ b/ui/cron/actions/startJob.ts @@ -1,11 +1,53 @@ import prisma from '../prisma'; import { Job } from '@prisma/client'; -import { spawn } from 'child_process'; +import { spawn, execSync } from 'child_process'; import path from 'path'; import fs from 'fs'; import { TOOLKIT_ROOT, getTrainingFolder, getHFToken } from '../paths'; const isWindows = process.platform === 'win32'; +/** + * Find a free port by probing with Python's socket module. + * Falls back to a hash-based port if the probe fails. + */ +function findFreePort(pythonPath: string, fallbackSeed: string): number { + try { + const port = execSync( + `"${pythonPath.replace(/"/g, '\\"')}" -c "import socket; s=socket.socket(); s.bind(('',0)); print(s.getsockname()[1]); s.close()"`, + { timeout: 5000, encoding: 'utf-8' }, + ).trim(); + const parsed = parseInt(port, 10); + if (isNaN(parsed)) throw new Error(`Invalid port: ${port}`); + return parsed; + } catch { + // Fallback: hash-based port in range 29500-39999 + let hash = 0; + for (let i = 0; i < fallbackSeed.length; i++) { + hash = ((hash << 5) - hash + fallbackSeed.charCodeAt(i)) | 0; + } + return 29500 + (Math.abs(hash) % 10500); + } +} + +/** + * Find the accelerate binary in the venv. + */ +function findAcceleratePath(): string | null { + const venvDirs = ['.venv', 'venv']; + for (const venv of venvDirs) { + const venvPath = path.join(TOOLKIT_ROOT, venv); + if (fs.existsSync(venvPath)) { + const accelPath = isWindows + ? path.join(venvPath, 'Scripts', 'accelerate.exe') + : path.join(venvPath, 'bin', 'accelerate'); + if (fs.existsSync(accelPath)) { + return accelPath; + } + } + } + return null; +} + const startAndWatchJob = (job: Job) => { // starts and watches the job asynchronously return new Promise(async (resolve, reject) => { @@ -78,31 +120,90 @@ const startAndWatchJob = (job: Job) => { info: `Error launching job: run.py not found`, }, }); + resolve(); return; } + // Determine if this is a multi-GPU distributed job + const gpuIdList = job.gpu_ids + .split(',') + .map(s => s.trim()) + .filter(s => s.length > 0); + const isMultiGPU = gpuIdList.length > 1; + const additionalEnv: any = { AITK_JOB_ID: jobID, CUDA_DEVICE_ORDER: 'PCI_BUS_ID', - CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, IS_AI_TOOLKIT_UI: '1', + HF_HOME: process.env.HF_HOME || path.join(process.env.HOME || '/root', '.cache', 'huggingface'), + HF_HUB_ENABLE_HF_TRANSFER: '0', + HF_HUB_DISABLE_XET: '1', + HF_HUB_DOWNLOAD_TIMEOUT: '300', }; + // For multi-GPU on Linux, accelerate launch --gpu_ids handles device assignment. + // Setting CUDA_VISIBLE_DEVICES alongside --gpu_ids causes conflicts (Accelerate #1848). + // On Windows, accelerate launch is not supported, so always set CUDA_VISIBLE_DEVICES. + if (!isMultiGPU || isWindows) { + additionalEnv.CUDA_VISIBLE_DEVICES = `${job.gpu_ids}`; + } + // HF_TOKEN const hfToken = await getHFToken(); if (hfToken && hfToken.trim() !== '') { additionalEnv.HF_TOKEN = hfToken; } - // Add the --log argument to the command - const args = [runFilePath, configPath, '--log', logPath]; - try { - let subprocess; + let childProcess; - if (isWindows) { + if (isMultiGPU && !isWindows) { + // Multi-GPU distributed training via accelerate launch + const acceleratePath = findAcceleratePath(); + if (!acceleratePath) { + console.error('accelerate binary not found in venv'); + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: 'Error launching distributed job: accelerate binary not found in venv', + }, + }); + resolve(); + return; + } + + const masterPort = findFreePort(pythonPath, jobID); + const numProcesses = gpuIdList.length; + + const launchArgs = [ + 'launch', + `--num_processes=${numProcesses}`, + `--gpu_ids=${gpuIdList.join(',')}`, + `--main_process_port=${masterPort}`, + '--mixed_precision=no', // precision handled by toolkit config + runFilePath, + configPath, + '--log', + logPath, + ]; + + console.log(`Distributed launch: ${acceleratePath} ${launchArgs.join(' ')}`); + console.log(` GPUs: ${gpuIdList.join(',')}, port: ${masterPort}`); + + childProcess = spawn(acceleratePath, launchArgs, { + detached: true, + stdio: 'ignore', + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + }); + } else if (isWindows) { // Spawn Python directly on Windows so the process can survive parent exit - subprocess = spawn(pythonPath, args, { + const args = [runFilePath, configPath, '--log', logPath]; + childProcess = spawn(pythonPath, args, { env: { ...process.env, ...additionalEnv, @@ -113,8 +214,9 @@ const startAndWatchJob = (job: Job) => { stdio: 'ignore', // don't tie stdio to parent }); } else { - // For non-Windows platforms, fully detach and ignore stdio so it survives daemon-like - subprocess = spawn(pythonPath, args, { + // Single-GPU: existing path, spawn python directly + const args = [runFilePath, configPath, '--log', logPath]; + childProcess = spawn(pythonPath, args, { detached: true, stdio: 'ignore', env: { @@ -125,13 +227,23 @@ const startAndWatchJob = (job: Job) => { }); } - // Save the PID to the database and a file for future management (stop/inspect) - const pid = subprocess.pid ?? null; + // Important: let the child run independently of this Node process. + if (childProcess.unref) { + childProcess.unref(); + } + + // Write pid to database and file for future management (stop/inspect). + // For distributed jobs, this is the launcher PID (process group leader). + const pid = childProcess.pid ?? null; if (pid != null) { - await prisma.job.update({ - where: { id: jobID }, - data: { pid }, - }); + try { + await prisma.job.update({ + where: { id: jobID }, + data: { pid }, + }); + } catch (e) { + console.error('Error updating pid in database:', e); + } } try { fs.writeFileSync(path.join(trainingFolder, 'pid.txt'), String(pid ?? ''), { flag: 'w' }); @@ -139,11 +251,6 @@ const startAndWatchJob = (job: Job) => { console.error('Error writing pid file:', e); } - // Important: let the child run independently of this Node process. - if (subprocess.unref) { - subprocess.unref(); - } - // (No stdout/stderr listeners — logging should go to --log handled by your Python) // (No monitoring loop — the whole point is to let it live past this worker) } catch (error: any) { @@ -157,6 +264,7 @@ const startAndWatchJob = (job: Job) => { info: `Error launching job: ${error?.message || 'Unknown error'}`, }, }); + resolve(); return; } // Resolve the promise immediately after starting the process diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 4abd5f4ca..65617a8b3 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -19,7 +19,6 @@ import SampleControlImage from '@/components/SampleControlImage'; import { FlipHorizontal2, FlipVertical2 } from 'lucide-react'; import { handleModelArchChange } from './utils'; import { IoFlaskSharp } from 'react-icons/io5'; -import { isMac } from '@/helpers/basic'; type Props = { jobConfig: JobConfig; @@ -31,7 +30,6 @@ type Props = { setGpuIDs: (value: string | null) => void; gpuList: any; datasetOptions: any; - isLoading?: boolean; }; const isDev = process.env.NODE_ENV === 'development'; @@ -46,7 +44,6 @@ export default function SimpleJob({ setGpuIDs, gpuList, datasetOptions, - isLoading, }: Props) { const modelArch = useMemo(() => { return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; @@ -147,22 +144,9 @@ export default function SimpleJob({ return newQuantizationOptions; }, [modelArch]); - const showGPUSelect = !isMac(); - return ( <> -
- {isLoading && ( -
-
-
- Loading... -
-
- )} +
- {showGPUSelect && ( - setGpuIDs(value)} - options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} - /> - )} +
+ +
+ {gpuList.map((gpu: any) => { + const gpuId = `${gpu.index}`; + const selectedIds = (gpuIDs || '0').split(',').map((s: string) => s.trim()); + const isSelected = selectedIds.includes(gpuId); + return ( + + ); + })} +
+ {(gpuIDs || '').includes(',') && ( +

+ FSDP v2: model sharded across {(gpuIDs || '').split(',').length} GPUs +

+ )} +
{disableSections.includes('trigger_word') ? null : ( setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')} /> )} - {modelArch?.additionalSections?.includes('model.layer_offloading') && !isMac() && ( + {modelArch?.additionalSections?.includes('model.layer_offloading') && ( <> - {!disableSections.includes('train.unload_text_encoder') && ( - { - setJobConfig(value, 'config.process[0].train.unload_text_encoder'); - if (value) { - setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); - } - }} - /> - )} - { - setJobConfig(value, 'config.process[0].train.cache_text_embeddings'); - if (value) { - setJobConfig(false, 'config.process[0].train.unload_text_encoder'); - } - }} - /> + {(() => { + const isFSDP = (gpuIDs || '').includes(','); + const unloadTE = isFSDP || (jobConfig.config.process[0].train.unload_text_encoder || false); + const cacheEmbeds = unloadTE || (jobConfig.config.process[0].train.cache_text_embeddings || false); + return ( + <> + {!disableSections.includes('train.unload_text_encoder') && ( + { + setJobConfig(value, 'config.process[0].train.unload_text_encoder'); + if (value) { + setJobConfig(true, 'config.process[0].train.cache_text_embeddings'); + } + }} + /> + )} + { + setJobConfig(value, 'config.process[0].train.cache_text_embeddings'); + }} + /> + + ); + })()}
diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index c106a4b0b..a35358a9e 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -152,12 +152,22 @@ export default function TrainingForm() { if (status === 'saving') return; setStatus('saving'); + // Enforce FSDP → unload_text_encoder → cache_text_embeddings cascade + const configToSave = objectCopy(jobConfig); + const isFSDP = (gpuIDs || '').includes(','); + if (isFSDP) { + configToSave.config.process[0].train.unload_text_encoder = true; + } + if (configToSave.config.process[0].train.unload_text_encoder) { + configToSave.config.process[0].train.cache_text_embeddings = true; + } + apiClient .post('/api/jobs', { id: runId, - name: jobConfig.config.name, + name: configToSave.config.name, gpu_ids: gpuIDs, - job_config: jobConfig, + job_config: configToSave, }) .then(res => { setStatus('success'); @@ -201,12 +211,35 @@ export default function TrainingForm() {
{showAdvancedView && ( <> -
- setGpuIDs(value)} - options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} - /> +
+ {gpuList.map((gpu: any) => { + const gpuId = `${gpu.index}`; + const selectedIds = (gpuIDs || '0').split(',').map((s: string) => s.trim()); + const isSelected = selectedIds.includes(gpuId); + return ( + + ); + })}
@@ -311,7 +344,6 @@ export default function TrainingForm() { setGpuIDs={setGpuIDs} gpuList={gpuList} datasetOptions={datasetOptions} - isLoading={!isSettingsLoaded || !isGPUInfoLoaded || datasetFetchStatus !== 'success'} /> diff --git a/ui/src/components/JobsTable.tsx b/ui/src/components/JobsTable.tsx index 7548ce8c8..907fdb196 100644 --- a/ui/src/components/JobsTable.tsx +++ b/ui/src/components/JobsTable.tsx @@ -71,7 +71,7 @@ export default function JobsTable({ onlyActive = false }: JobsTableProps) { render: row => { let statusClass = 'text-gray-400'; if (row.status === 'completed') statusClass = 'text-green-400'; - if (row.status === 'failed') statusClass = 'text-red-400'; + if (row.status === 'error') statusClass = 'text-red-400'; if (row.status === 'running') statusClass = 'text-blue-400'; return {row.status}; @@ -96,33 +96,45 @@ export default function JobsTable({ onlyActive = false }: JobsTableProps) { if (!isGPUInfoLoaded) return {}; if (jobs.length === 0) return {}; let jd: { [key: string]: { name: string; jobs: Job[] } } = {}; + // Create entries for each individual GPU gpuList.forEach(gpu => { jd[`${gpu.index}`] = { name: `${gpu.name}`, jobs: [] }; }); jd['Idle'] = { name: 'Idle', jobs: [] }; jobs.forEach(job => { - const gpu = gpuList.find(gpu => job.gpu_ids?.split(',').includes(gpu.index.toString())) as GpuInfo; - const key = `${gpu?.index || '0'}`; - if (['queued', 'running', 'stopping'].includes(job.status) && key in jd) { + if (!['queued', 'running', 'error'].includes(job.status)) { + jd['Idle'].jobs.push(job); + return; + } + const gpuIds = (job.gpu_ids || '0') + .split(',') + .map((s: string) => s.trim()) + .filter((s: string) => s.length > 0); + if (gpuIds.length > 1) { + // Multi-GPU job: group under the combined gpu_ids key (numerically sorted for consistency) + const key = gpuIds.sort((a: string, b: string) => parseInt(a) - parseInt(b)).join(','); + if (!(key in jd)) { + const gpuName = gpuList.find(g => g.index.toString() === gpuIds[0])?.name || 'GPU'; + jd[key] = { name: `${gpuName} (x${gpuIds.length})`, jobs: [] }; + } jd[key].jobs.push(job); } else { - jd['Idle'].jobs.push(job); + const key = gpuIds[0]; + if (key in jd) { + jd[key].jobs.push(job); + } else { + jd['Idle'].jobs.push(job); + } } }); // sort the queued/running jobs by queue position Object.keys(jd).forEach(key => { - if (key === 'Idle') { - jd[key].jobs.sort((a, b) => { - // sort by updated_at, newest first - return new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(); - }); - } else { - jd[key].jobs.sort((a, b) => { - if (a.queue_position === null) return 1; - if (b.queue_position === null) return -1; - return a.queue_position - b.queue_position; - }); - } + if (key === 'Idle') return; + jd[key].jobs.sort((a, b) => { + if (a.queue_position === null) return 1; + if (b.queue_position === null) return -1; + return a.queue_position - b.queue_position; + }); }); return jd; }, [jobs, queues, isGPUInfoLoaded]); @@ -138,7 +150,22 @@ export default function JobsTable({ onlyActive = false }: JobsTableProps) { .sort() .filter(key => key !== 'Idle') .map(gpuKey => { - const queue = queues.find(q => `${q.gpu_ids}` === gpuKey) as Queue; + // Use set-based comparison: gpu_ids may be stored in different order than the sorted key + const gpuKeySet = new Set( + gpuKey + .split(',') + .map(s => s.trim()) + .filter(s => s.length > 0), + ); + const queue = queues.find(q => { + const qSet = new Set( + `${q.gpu_ids}` + .split(',') + .map(s => s.trim()) + .filter(s => s.length > 0), + ); + return qSet.size === gpuKeySet.size && [...gpuKeySet].every(id => qSet.has(id)); + }) ?? null; return (

{jobsDict[gpuKey].name}

- # {queue?.gpu_ids} + + # {queue?.gpu_ids ?? gpuKey} +
{queue?.is_running ? ( diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index 36c2ae5f3..f02220eb3 100644 --- a/ui/src/components/formInputs.tsx +++ b/ui/src/components/formInputs.tsx @@ -26,12 +26,24 @@ export interface InputProps { export interface TextInputProps extends InputProps { value: string; onChange: (value: string) => void; + onBlur?: () => void; type?: 'text' | 'password'; disabled?: boolean; } export const TextInput = forwardRef((props: TextInputProps, ref) => { - const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props; + const { + label, + value, + onChange, + onBlur, + placeholder, + required, + disabled, + type = 'text', + className, + docKey = null, + } = props; let { doc } = props; if (!doc && docKey) { doc = getDoc(docKey); @@ -55,6 +67,7 @@ export const TextInput = forwardRef((props: Te onChange={e => { if (!disabled) onChange(e.target.value); }} + onBlur={onBlur} className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`} placeholder={placeholder} required={required} diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index 6a1c0b5ad..9f72286ca 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -17,8 +17,8 @@ const docs: { [key: string]: ConfigDoc } = { title: 'GPU ID', description: ( <> - This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently. - However, you can start multiple jobs in parallel, each using a different GPU. + Select one or more GPUs for training. Multiple GPUs use FSDP v2 to shard the model across devices. + You can also run multiple single-GPU jobs in parallel, each using a different GPU. ), }, @@ -157,8 +157,9 @@ const docs: { [key: string]: ConfigDoc } = { title: 'Unload Text Encoder', description: ( <> - Unloading text encoder will cache the trigger word and the sample prompts and unload the text encoder from the - GPU. Captions in for the dataset will be ignored + Caches the trigger word and sample prompts, then unloads the text encoder from the GPU. + Implies Cache Text Embeddings. Captions in the dataset will be ignored. + Automatically enabled when using FSDP. ), }, @@ -166,11 +167,10 @@ const docs: { [key: string]: ConfigDoc } = { title: 'Cache Text Embeddings', description: ( <> - (experimental) -
- Caching text embeddings will process and cache all the text embeddings from the text encoder to the disk. The - text encoder will be unloaded from the GPU. This does not work with things that dynamically change the prompt - such as trigger words, caption dropout, etc. + Caches all text embeddings from the text encoder to disk. On subsequent runs, the text encoder + is skipped entirely (never loaded or quantized), saving significant VRAM and startup time. + Does not work with features that dynamically change the prompt such as caption dropout. + Automatically enabled by FSDP and Unload Text Encoder. ), },