Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/examples/train_lora_chroma_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flex2_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flex_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flux_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_flux_kontext_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
5 changes: 3 additions & 2 deletions config/examples/train_lora_flux_schnell_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
# hf_private: true #whether the repo is private or public
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_omnigen2_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_sd35_large_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_wan21_14b_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
3 changes: 2 additions & 1 deletion config/examples/train_lora_wan21_1b_24gb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ config:
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
push_to_hub_every_save: false # push to Hugging Face Hub every time a checkpoint is saved (requires push_to_hub: true)
# You can either set up a HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
Expand Down
86 changes: 77 additions & 9 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from torch.utils.data import DataLoader
import torch
import torch.backends.cuda
from huggingface_hub import HfApi, Repository, interpreter_login
from huggingface_hub.utils import HfFolder
from huggingface_hub import HfApi, get_token, interpreter_login
from huggingface_hub.utils import HfHubHTTPError
from toolkit.memory_management import MemoryManager

from toolkit.basic import value_map
Expand Down Expand Up @@ -482,8 +482,66 @@ def clean_up_saves(self):
return latest_item

def post_save_hook(self, save_path):
# override in subclass
pass
if self.save_config.push_to_hub and self.save_config.push_to_hub_every_save:
# Unlike the end-of-training push (which can prompt interactively via
# interpreter_login), intermediate pushes must not block the training
# loop, so we silently check for an existing token and bail out if
# none is found. The result is cached to avoid repeated checks.
if getattr(self, "_hub_push_disabled", False):
return
if "HF_TOKEN" not in os.environ and get_token() is None:
print_acc("No HF token available, skipping intermediate Hub pushes")
self._hub_push_disabled = True
return
if not os.path.exists(save_path):
print_acc(f"Checkpoint not found, skipping Hub push: {save_path}")
return
repo_id = self.save_config.hf_repo_id
api = HfApi()
if not getattr(self, "_hub_repo_created", False):
try:
api.create_repo(repo_id, private=self.save_config.hf_private, exist_ok=True)
self._hub_repo_created = True
except HfHubHTTPError as e:
print_acc(f"Failed to create Hub repo '{repo_id}': {e}")
print_acc(traceback.format_exc())
status = getattr(e.response, "status_code", None)
if status in (401, 403):
print_acc("Disabling intermediate Hub pushes for this run (auth/permission error)")
self._hub_push_disabled = True
return
except Exception as e:
print_acc(f"Failed to create Hub repo '{repo_id}': {e}")
print_acc(traceback.format_exc())
return
try:
# Upload only the new checkpoint, not the entire save_root.
# The full folder (with README) is pushed at end of training.
if os.path.isdir(save_path):
api.upload_folder(
repo_id=repo_id,
folder_path=save_path,
path_in_repo=os.path.basename(save_path),
repo_type="model",
)
else:
api.upload_file(
repo_id=repo_id,
path_or_fileobj=save_path,
path_in_repo=os.path.basename(save_path),
repo_type="model",
)
print_acc(f"Pushed checkpoint to Hub: {os.path.basename(save_path)}")
except HfHubHTTPError as e:
print_acc(f"Failed to upload checkpoint to Hub: {e}")
print_acc(traceback.format_exc())
status = getattr(e.response, "status_code", None)
if status in (401, 403):
print_acc("Disabling intermediate Hub pushes for this run (auth/permission error)")
self._hub_push_disabled = True
except Exception as e:
print_acc(f"Failed to upload checkpoint to Hub: {e}")
print_acc(traceback.format_exc())

def done_hook(self):
pass
Expand Down Expand Up @@ -665,6 +723,9 @@ def save(self, step=None):
get_torch_dtype(self.save_config.dtype)
)

# Capture checkpoint path; file_path is reassigned below for SNR/optimizer saves
checkpoint_path = file_path

# save learnable params as json if we have them
if self.snr_gos:
json_data = {
Expand Down Expand Up @@ -695,7 +756,7 @@ def save(self, step=None):
print_acc("Could not save optimizer")

self.clean_up_saves()
self.post_save_hook(file_path)
self.post_save_hook(checkpoint_path)

if self.ema is not None:
self.ema.train()
Expand Down Expand Up @@ -2404,10 +2465,17 @@ def run(self):
if self.save_config.push_to_hub:
if("HF_TOKEN" not in os.environ):
interpreter_login(new_session=False, write_permission=True)
self.push_to_hub(
repo_id=self.save_config.hf_repo_id,
private=self.save_config.hf_private
)
try:
self.push_to_hub(
repo_id=self.save_config.hf_repo_id,
private=self.save_config.hf_private
)
except Exception as e:
print_acc("=" * 60)
print_acc(f"Failed to push final model to Hub: {e}")
print_acc(traceback.format_exc())
print_acc(f"Model saved locally at: {self.save_root}")
print_acc("=" * 60)
del (
self.sd,
unet,
Expand Down
7 changes: 6 additions & 1 deletion toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ def __init__(self, **kwargs):
if self.save_format not in ['safetensors', 'diffusers']:
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
self.push_to_hub: bool = kwargs.get("push_to_hub", False)
self.push_to_hub_every_save: bool = kwargs.get("push_to_hub_every_save", False)
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
self.hf_private: Optional[str] = kwargs.get("hf_private", False)
self.hf_private: bool = kwargs.get("hf_private", False)
if self.push_to_hub and not self.hf_repo_id:
raise ValueError("hf_repo_id must be provided when push_to_hub is enabled")
if self.push_to_hub_every_save and not self.push_to_hub:
raise ValueError("push_to_hub must be enabled when push_to_hub_every_save is True")

class LoggingConfig:
def __init__(self, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions ui/src/app/jobs/new/jobConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export const defaultJobConfig: JobConfig = {
max_step_saves_to_keep: 4,
save_format: 'diffusers',
push_to_hub: false,
push_to_hub_every_save: false,
},
datasets: [defaultDatasetConfig],
train: {
Expand Down Expand Up @@ -192,6 +193,10 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
delete jobConfig.config.process[0].model.auto_memory;
}

if (jobConfig.config.process[0]?.save && !('push_to_hub_every_save' in jobConfig.config.process[0].save)) {
jobConfig.config.process[0].save.push_to_hub_every_save = false;
}

if (!('logging' in jobConfig.config.process[0])) {
//@ts-ignore
jobConfig.config.process[0].logging = {
Expand Down
1 change: 1 addition & 0 deletions ui/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export interface SaveConfig {
max_step_saves_to_keep: number;
save_format: string;
push_to_hub: boolean;
push_to_hub_every_save: boolean;
}

export interface DatasetConfig {
Expand Down
Loading