Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/examples/train_lora_chroma_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flex2_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flex_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flux_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flux_kontext_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
5 changes: 3 additions & 2 deletions config/examples/train_lora_flux_schnell_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
# hf_private: true #whether the repo is private or public
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_omnigen2_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_sd35_large_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_wan21_14b_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_wan21_1b_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to HuggingFace Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
87 changes: 78 additions & 9 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from torch.utils.data import DataLoader
import torch
import torch.backends.cuda
from huggingface_hub import HfApi, Repository, interpreter_login
from huggingface_hub.utils import HfFolder
from huggingface_hub import HfApi, get_token, interpreter_login
from huggingface_hub.utils import HfHubHTTPError
from toolkit.memory_management import MemoryManager

from toolkit.basic import value_map
Expand Down Expand Up @@ -482,8 +482,66 @@ def clean_up_saves(self):
return latest_item

def post_save_hook(self, save_path):
# override in subclass
pass
if self.save_config.push_to_hub and self.save_config.push_to_hub_every_save:
# Unlike the end-of-training push (which can prompt interactively via
# interpreter_login), intermediate pushes must not block the training
# loop, so we silently check for an existing token and bail out if
# none is found. The result is cached to avoid repeated checks.
if getattr(self, "_hub_push_disabled", False):
return
if "HF_TOKEN" not in os.environ and get_token() is None:
print_acc("No HF token available, skipping intermediate Hub pushes")
self._hub_push_disabled = True
return
if not os.path.exists(save_path):
print_acc(f"Checkpoint not found, skipping Hub push: {save_path}")
return
repo_id = self.save_config.hf_repo_id
api = HfApi()
if not getattr(self, "_hub_repo_created", False):
try:
api.create_repo(repo_id, private=self.save_config.hf_private, exist_ok=True)
self._hub_repo_created = True
except HfHubHTTPError as e:
print_acc(f"Failed to create Hub repo '{repo_id}': {e}")
print_acc(traceback.format_exc())
status = getattr(e.response, "status_code", None)
if status in (401, 403):
print_acc("Disabling intermediate Hub pushes for this run (auth/permission error)")
self._hub_push_disabled = True
return
except Exception as e:
print_acc(f"Failed to create Hub repo '{repo_id}': {e}")
print_acc(traceback.format_exc())
return
try:
# Upload only the new checkpoint, not the entire save_root.
# The full folder (with README) is pushed at end of training.
if os.path.isdir(save_path):
api.upload_folder(
repo_id=repo_id,
folder_path=save_path,
path_in_repo=os.path.basename(save_path),
repo_type="model",
)
else:
api.upload_file(
repo_id=repo_id,
path_or_fileobj=save_path,
path_in_repo=os.path.basename(save_path),
repo_type="model",
)
print_acc(f"Pushed checkpoint to Hub: {os.path.basename(save_path)}")
except HfHubHTTPError as e:
print_acc(f"Failed to upload checkpoint to Hub: {e}")
print_acc(traceback.format_exc())
status = getattr(e.response, "status_code", None)
if status in (401, 403):
print_acc("Disabling intermediate Hub pushes for this run (auth/permission error)")
self._hub_push_disabled = True
except Exception as e:
print_acc(f"Failed to upload checkpoint to Hub: {e}")
print_acc(traceback.format_exc())

def done_hook(self):
pass
Expand Down Expand Up @@ -600,6 +658,7 @@ def save(self, step=None):
elif self.adapter_config.type == 'control_net':
# save in diffusers format
name_or_path = file_path.replace('.safetensors', '')
file_path = name_or_path
# move it to the new dtype and cpu
orig_device = self.adapter.device
orig_dtype = self.adapter.dtype
Expand Down Expand Up @@ -665,6 +724,9 @@ def save(self, step=None):
get_torch_dtype(self.save_config.dtype)
)

# Capture checkpoint path; file_path is reassigned below for SNR/optimizer saves
checkpoint_path = file_path

# save learnable params as json if we have them
if self.snr_gos:
json_data = {
Expand Down Expand Up @@ -695,7 +757,7 @@ def save(self, step=None):
print_acc("Could not save optimizer")

self.clean_up_saves()
self.post_save_hook(file_path)
self.post_save_hook(checkpoint_path)

if self.ema is not None:
self.ema.train()
Expand Down Expand Up @@ -2404,10 +2466,17 @@ def run(self):
if self.save_config.push_to_hub:
if("HF_TOKEN" not in os.environ):
interpreter_login(new_session=False, write_permission=True)
self.push_to_hub(
repo_id=self.save_config.hf_repo_id,
private=self.save_config.hf_private
)
try:
self.push_to_hub(
repo_id=self.save_config.hf_repo_id,
private=self.save_config.hf_private
)
except Exception as e:
print_acc("=" * 60)
print_acc(f"Failed to push final model to Hub: {e}")
print_acc(traceback.format_exc())
print_acc(f"Model saved locally at: {self.save_root}")
print_acc("=" * 60)
del (
self.sd,
unet,
Expand Down
7 changes: 6 additions & 1 deletion toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ def __init__(self, **kwargs):
if self.save_format not in ['safetensors', 'diffusers']:
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
self.push_to_hub: bool = kwargs.get("push_to_hub", False)
self.push_to_hub_every_save: bool = kwargs.get("push_to_hub_every_save", False)
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
self.hf_private: Optional[str] = kwargs.get("hf_private", False)
self.hf_private: bool = kwargs.get("hf_private", False)
if self.push_to_hub and not self.hf_repo_id:
raise ValueError("hf_repo_id must be provided when push_to_hub is enabled")
if self.push_to_hub_every_save and not self.push_to_hub:
raise ValueError("push_to_hub must be enabled when push_to_hub_every_save is True")

class LoggingConfig:
def __init__(self, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions ui/src/app/jobs/new/jobConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export const defaultJobConfig: JobConfig = {
max_step_saves_to_keep: 4,
save_format: 'diffusers',
push_to_hub: false,
push_to_hub_every_save: false,
},
datasets: [defaultDatasetConfig],
train: {
Expand Down Expand Up @@ -192,6 +193,10 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
delete jobConfig.config.process[0].model.auto_memory;
}

if (jobConfig.config.process[0]?.save && !('push_to_hub_every_save' in jobConfig.config.process[0].save)) {
jobConfig.config.process[0].save.push_to_hub_every_save = false;
}

if (!('logging' in jobConfig.config.process[0])) {
//@ts-ignore
jobConfig.config.process[0].logging = {
Expand Down
1 change: 1 addition & 0 deletions ui/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export interface SaveConfig {
max_step_saves_to_keep: number;
save_format: string;
push_to_hub: boolean;
push_to_hub_every_save: boolean;
}

export interface DatasetConfig {
Expand Down