Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/examples/train_flex_redux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ config:
batch_size: 3
gradient_accumulation: 2

# captions are not needed for this training, we cache a blank proompt and rely on the vision encoder
# captions are not needed for this training, we cache a blank prompt and rely on the vision encoder
# unload_text_encoder now automatically caches all text embeddings before unloading
unload_text_encoder: true

loss_type: "mse"
Expand Down
2 changes: 1 addition & 1 deletion config/examples/train_lora_wan21_14b_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ config:
ema_decay: 0.99
dtype: bf16
# required for 24GB cards
# this will encode your trigger word and use those embeddings for every image in the dataset
# automatically caches all text embeddings before unloading the text encoder
unload_text_encoder: true
model:
# huggingface model name or path
Expand Down
8 changes: 3 additions & 5 deletions config/examples/train_lora_wan22_14b_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ config:
# IMPORTANT: this is for Wan 2.2 MOE. It will switch training one stage or the other every this many steps
switch_boundary_every: 10

# required for 24GB cards. You must do either unload_text_encoder or cache_text_embeddings but not both

# this will encode your trigger word and use those embeddings for every image in the dataset, captions will be ignored
# required for 24GB cards
# unload_text_encoder automatically caches all text embeddings before unloading
# either option works — unload_text_encoder implies cache_text_embeddings
# unload_text_encoder: true

# this will cache all captions in your dataset.
cache_text_embeddings: true

model:
Expand Down
172 changes: 100 additions & 72 deletions extensions_built_in/sd_trainer/DiffusionTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import time
import signal
import torch

AITK_Status = Literal["running", "stopped", "error", "completed"]

Expand All @@ -30,22 +31,29 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):

if self.is_ui_trainer:
self.is_stopping = False
self.is_returning_to_queue = False
# Create a thread pool for database operations
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Track all async tasks
self._async_tasks = []
# Initialize the status
self._run_async_operation(self._update_status("running", "Starting"))
self._last_speed_update = 0.0
self._stop_watcher_started = False
# self.start_stop_watcher(interval_sec=2.0)
self.start_stop_watcher(interval_sec=2.0)

def start_stop_watcher(self, interval_sec: float = 5.0):
"""
Start a daemon thread that periodically checks should_stop()
and terminates the process immediately when triggered.
In distributed mode, only rank 0 polls the DB; it kills the
entire process group so all ranks terminate.
"""
if not self.is_ui_trainer:
return
# Only rank 0 should poll the DB to avoid SQLite locking issues
if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process:
return
if getattr(self, "_stop_watcher_started", False):
return
self._stop_watcher_started = True
Expand All @@ -58,26 +66,34 @@ def _stop_watcher_thread(self, interval_sec: float):
while True:
try:
if self.should_stop():
# Mark and update status (non-blocking; uses existing infra)
self.is_stopping = True
self._run_async_operation(
self._update_status("stopped", "Job stopped (remote)")
)
# Best-effort flush pending async ops
try:
asyncio.run(self.wait_for_all_async())
except RuntimeError:
pass
# Try to stop DB thread pool quickly
try:
self.thread_pool.shutdown(wait=False, cancel_futures=True)
except TypeError:
self.thread_pool.shutdown(wait=False)
print("")
print("****************************************************")
print(" Stop signal received; terminating process. ")
print("****************************************************")
os.kill(os.getpid(), signal.SIGINT)
if self.accelerator.num_processes > 1:
# FSDP: don't kill — let end_step_hook broadcast the flag
# so all ranks save a temp checkpoint together.
# Don't shut down thread_pool here; on_error() still needs it.
# Force kill after 5 min as a last-resort fallback.
for _ in range(150):
time.sleep(2)
os.killpg(os.getpgid(os.getpid()), signal.SIGTERM)
else:
# Single GPU: update status, flush, then SIGINT.
self._run_async_operation(
self._update_status("stopped", "Job stopped")
)
try:
asyncio.run(self.wait_for_all_async())
except RuntimeError:
pass
try:
self.thread_pool.shutdown(wait=False, cancel_futures=True)
except TypeError:
self.thread_pool.shutdown(wait=False)
os.kill(os.getpid(), signal.SIGINT)
break # don't loop — avoid repeated signals interrupting save
time.sleep(interval_sec)
except Exception:
time.sleep(interval_sec)
Expand All @@ -91,6 +107,10 @@ def _run_async_operation(self, coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# Prune completed tasks periodically to avoid unbounded growth
if len(self._async_tasks) > 100:
self._async_tasks = [t for t in self._async_tasks if not t.done()]

# Create a task and track it
if loop.is_running():
task = asyncio.run_coroutine_threadsafe(coro, loop)
Expand All @@ -101,40 +121,22 @@ def _run_async_operation(self, coro):
loop.run_until_complete(task)

async def _execute_db_operation(self, operation_func):
"""Execute a database operation in a separate thread with retry on lock."""
"""Execute a database operation in a separate thread to avoid blocking."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.thread_pool, lambda: self._retry_db_operation(operation_func)
)
return await loop.run_in_executor(self.thread_pool, operation_func)

def _db_connect(self):
"""Create a new connection for each operation to avoid locking."""
conn = sqlite3.connect(self.sqlite_db_path, timeout=30.0)
conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
conn.isolation_level = None # Enable autocommit mode
return conn

def _retry_db_operation(self, operation_func, max_retries=3, base_delay=2.0):
"""Retry a database operation with exponential backoff on lock errors."""
last_error = None
for attempt in range(max_retries + 1):
try:
return operation_func()
except sqlite3.OperationalError as e:
if "database is locked" in str(e):
last_error = e
if attempt < max_retries:
delay = base_delay * (2 ** attempt) # 2s, 4s, 8s
print(f"[AITK] Database locked (attempt {attempt + 1}/{max_retries + 1}), retrying in {delay:.1f}s...")
time.sleep(delay)
else:
print(f"[AITK] Database locked after {max_retries + 1} attempts, giving up.")
else:
raise
raise last_error

def should_stop(self):
if not self.is_ui_trainer:
return False
# In distributed mode, only rank 0 polls the DB to avoid SQLite locking
if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process:
return False
def _check_stop():
with self._db_connect() as conn:
cursor = conn.cursor()
Expand All @@ -143,11 +145,14 @@ def _check_stop():
stop = cursor.fetchone()
return False if stop is None else stop[0] == 1

return self._retry_db_operation(_check_stop)
return _check_stop()

def should_return_to_queue(self):
if not self.is_ui_trainer:
return False
# In distributed mode, only rank 0 polls the DB to avoid SQLite locking
if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process:
return False
def _check_return_to_queue():
with self._db_connect() as conn:
cursor = conn.cursor()
Expand All @@ -156,11 +161,18 @@ def _check_return_to_queue():
return_to_queue = cursor.fetchone()
return False if return_to_queue is None else return_to_queue[0] == 1

return self._retry_db_operation(_check_return_to_queue)
return _check_return_to_queue()

def maybe_stop(self):
if not self.is_ui_trainer:
return
# In distributed mode, only rank 0 checks stop signals.
# Non-rank-0 processes are terminated via the stop watcher's os.killpg().
# We CANNOT use dist.broadcast() here because maybe_stop() is called from
# rank-0-only code paths (inside sample(), save(), sample_step_hook()) where
# other ranks have already diverged. A broadcast would deadlock.
if self.accelerator.num_processes > 1 and not self.accelerator.is_main_process:
return
if self.should_stop():
self._run_async_operation(
self._update_status("stopped", "Job stopped"))
Expand All @@ -170,6 +182,7 @@ def maybe_stop(self):
self._run_async_operation(
self._update_status("queued", "Job queued"))
self.is_stopping = True
self.is_returning_to_queue = True
raise Exception("Job returning to queue")

async def _update_key(self, key, value):
Expand Down Expand Up @@ -251,28 +264,16 @@ async def wait_for_all_async(self):
def on_error(self, e: Exception):
super(DiffusionTrainer, self).on_error(e)
if self.is_ui_trainer:
try:
if self.accelerator.is_main_process and not self.is_stopping:
if self.accelerator.is_main_process:
if self.is_returning_to_queue:
self.update_status("queued", "Job queued")
elif self.is_stopping:
self.update_status("stopped", "Job stopped")
else:
self.update_status("error", str(e))
self.update_db_key("step", self.last_save_step)
asyncio.run(self.wait_for_all_async())
except Exception as db_err:
print(f"[AITK] Warning: failed to update DB during error handling: {db_err}")
finally:
self.thread_pool.shutdown(wait=True)

def handle_timing_print_hook(self, timing_dict):
if "train_loop" not in timing_dict:
print("train_loop not found in timing_dict", timing_dict)
return
seconds_per_iter = timing_dict["train_loop"]
# determine iter/sec or sec/iter
if seconds_per_iter < 1:
iters_per_sec = 1 / seconds_per_iter
self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
else:
self.update_db_key(
"speed_string", f"{seconds_per_iter:.2f} sec/iter")
self.update_db_key("step", self.last_save_step)
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)

def done_hook(self):
super(DiffusionTrainer, self).done_hook()
Expand All @@ -286,7 +287,29 @@ def end_step_hook(self):
super(DiffusionTrainer, self).end_step_hook()
if self.is_ui_trainer:
self.update_step()
self.maybe_stop()
# Update speed_string every ~10 seconds using timer's rolling average
train_loop_timings = self.timer.timers.get('train_loop')
if train_loop_timings and (time.time() - self._last_speed_update) >= 10.0:
seconds_per_iter = sum(train_loop_timings) / len(train_loop_timings)
if seconds_per_iter <= 0:
pass
elif seconds_per_iter < 1:
self.update_db_key("speed_string", f"{1 / seconds_per_iter:.2f} iter/sec")
else:
self.update_db_key("speed_string", f"{seconds_per_iter:.2f} sec/iter")
self._last_speed_update = time.time()
# FSDP: broadcast stop flag so all ranks exit together for collective save.
# Only uses the watcher's is_stopping flag (no DB query in hot path).
if self.accelerator.num_processes > 1:
import torch.distributed as dist
flag = 1 if self.is_stopping else 0
stop_tensor = torch.tensor([flag], device=self.accelerator.device)
dist.broadcast(stop_tensor, src=0)
if stop_tensor.item() == 1:
self.is_stopping = True
raise Exception("Job stopped")
else:
self.maybe_stop()

def hook_before_model_load(self):
super().hook_before_model_load()
Expand All @@ -306,8 +329,6 @@ def hook_before_train_loop(self):
self.maybe_stop()
self.update_step()
self.update_status("running", "Training")
self.timer.add_after_print_hook(self.handle_timing_print_hook)

def status_update_hook_func(self, string):
self.update_status("running", string)

Expand All @@ -332,9 +353,16 @@ def sample(self, step=None, is_first=False):
self.maybe_stop()
self.update_status("running", "Training")

def save(self, step=None):
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
self.maybe_stop()
self.update_status("running", "Training")
def save(self, step=None, is_temp=False):
if not is_temp:
# Under FSDP, skip pre-save maybe_stop() — it could throw on rank 0
# before the collective save, deadlocking other ranks in full_tensor().
# The post-save check and stop watcher handle stop signals safely.
if not self.use_fsdp:
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step, is_temp=is_temp)
if not is_temp:
if not self.use_fsdp:
self.maybe_stop()
self.update_status("running", "Training")
Loading