Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions config/examples/train_lora_flux_24gb_mlflow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
---
# Example: Training a FLUX LoRA with MLflow experiment tracking
#
# Prerequisites:
# pip install "mlflow>=3,<4"
#
# To view results, start the MLflow UI:
# mlflow ui # local (default: http://localhost:5000)
# mlflow server --host 0.0.0.0 # remote
#
# You can also point to an existing tracking server:
# logging.mlflow_tracking_uri: "http://your-server:5000"

job: extension
config:
name: "my_flux_lora_mlflow"
process:
- type: 'sd_trainer'
training_folder: "output"
device: cuda:0
network:
type: "lora"
linear: 16
linear_alpha: 16
save:
dtype: float16
save_every: 250
max_step_saves_to_keep: 4
datasets:
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05
shuffle_tokens: false
cache_latents_to_disk: true
resolution: [ 512, 768, 1024 ]
train:
batch_size: 1
steps: 2000
gradient_accumulation_steps: 1
train_unet: true
train_text_encoder: false
gradient_checkpointing: true
noise_scheduler: "flowmatch"
optimizer: "adamw8bit"
lr: 1e-4
ema_config:
use_ema: true
ema_decay: 0.99
dtype: bf16
model:
name_or_path: "black-forest-labs/FLUX.1-dev"
is_flux: true
quantize: true

# ---- Logging & Experiment Tracking ----
logging:
log_every: 10
verbose: false

# MLflow
use_mlflow: true
project_name: "flux-lora-training"
run_name: "my_flux_lora_mlflow"
mlflow_tracking_uri: null # null = local ./mlruns directory
mlflow_experiment_name: "flux-lora"
mlflow_log_artifacts: true # log checkpoint .safetensors as MLflow artifacts
mlflow_register_model: true # register final LoRA in MLflow (appears in Models tab)
      mlflow_registered_model_name: null # optional: name to register under in the MLflow Model Registry (null = final model is logged as a run artifact but not added to the registry)

# W&B (can be enabled simultaneously with MLflow)
# use_wandb: true

sample:
sampler: "flowmatch"
sample_every: 250
width: 1024
height: 1024
prompts:
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
neg: ""
seed: 42
walk_seed: true
guidance_scale: 4
sample_steps: 20
meta:
name: "[name]"
version: '1.0'
3 changes: 2 additions & 1 deletion extensions_built_in/sd_trainer/DiffusionTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def sample(self, step=None, is_first=False):
def save(self, step=None):
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
result = super().save(step)
self.maybe_stop()
self.update_status("running", "Training")
return result
3 changes: 2 additions & 1 deletion extensions_built_in/sd_trainer/UITrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def sample(self, step=None, is_first=False):
def save(self, step=None):
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
result = super().save(step)
self.maybe_stop()
self.update_status("running", "Training")
return result
37 changes: 30 additions & 7 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def sample(self, step=None, is_first=False):
ctrl_img_2=sample_item.ctrl_img_2,
ctrl_img_3=sample_item.ctrl_img_3,
do_cfg_norm=sample_config.do_cfg_norm,
log_step=step,
**extra_args
))

Expand Down Expand Up @@ -482,8 +483,7 @@ def clean_up_saves(self):
return latest_item

def post_save_hook(self, save_path):
# override in subclass
pass
self.logger.log_checkpoint(save_path)

def done_hook(self):
pass
Expand All @@ -493,7 +493,7 @@ def end_step_hook(self):

def save(self, step=None):
if not self.accelerator.is_main_process:
return
return None
flush()
if self.ema is not None:
# always save params as ema
Expand Down Expand Up @@ -614,6 +614,7 @@ def save(self, step=None):
yaml.dump(self.meta, f)
# move it back
self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
file_path = name_or_path
else:
direct_save = False
if self.adapter_config.train_only_image_encoder:
Expand Down Expand Up @@ -665,6 +666,9 @@ def save(self, step=None):
get_torch_dtype(self.save_config.dtype)
)

# snapshot the checkpoint path before it gets overwritten by SNR/optimizer saves
checkpoint_path = file_path

# save learnable params as json if we have them
if self.snr_gos:
json_data = {
Expand All @@ -677,7 +681,7 @@ def save(self, step=None):
with open(path_to_save, 'w') as f:
json.dump(json_data, f, indent=4)

print_acc(f"Saved checkpoint to {file_path}")
print_acc(f"Saved checkpoint to {checkpoint_path}")

# save optimizer
if self.optimizer is not None:
Expand All @@ -695,11 +699,12 @@ def save(self, step=None):
print_acc("Could not save optimizer")

self.clean_up_saves()
self.post_save_hook(file_path)
self.post_save_hook(checkpoint_path)

if self.ema is not None:
self.ema.train()
flush()
return checkpoint_path

# Called before the model is loaded
def hook_before_model_load(self):
Expand All @@ -717,6 +722,8 @@ def hook_add_extra_train_params(self, params):
def hook_before_train_loop(self):
if self.accelerator.is_main_process:
self.logger.start()
if self.dataset_configs:
self.logger.log_datasets(self.dataset_configs)
self.prepare_accelerator()

def sample_step_hook(self, img_num, total_imgs):
Expand Down Expand Up @@ -2395,8 +2402,24 @@ def run(self):
self.logger.commit(step=self.step_num)
print_acc("")
if self.accelerator.is_main_process:
self.save()
self.logger.finish()
try:
final_path = self.save()
# Register model (only for adapter training, not merged saves)
if (
final_path
and self.network_config is not None
and not self.train_config.merge_network_on_save
):
self.logger.log_model(
lora_path=final_path,
base_model=self.model_config.name_or_path_original,
model_type=self.model_config.arch or "sd1",
network_type=self.network_config.type,
lora_rank=self.network_config.rank,
lora_alpha=self.network_config.linear_alpha,
)
finally:
self.logger.finish()
self.accelerator.end_training()

if self.accelerator.is_main_process:
Expand Down
11 changes: 10 additions & 1 deletion toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,15 @@ def __init__(self, **kwargs):
self.log_every: int = kwargs.get('log_every', 100)
self.verbose: bool = kwargs.get('verbose', False)
self.use_wandb: bool = kwargs.get('use_wandb', False)
self.use_mlflow: bool = kwargs.get('use_mlflow', False)
self.use_ui_logger: bool = kwargs.get('use_ui_logger', False)
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
self.run_name: str = kwargs.get('run_name', None)
self.mlflow_tracking_uri: str = kwargs.get('mlflow_tracking_uri', None)
self.mlflow_experiment_name: str = kwargs.get('mlflow_experiment_name', None)
self.mlflow_log_artifacts: bool = kwargs.get('mlflow_log_artifacts', False)
self.mlflow_register_model: bool = kwargs.get('mlflow_register_model', False)
self.mlflow_registered_model_name: str = kwargs.get('mlflow_registered_model_name', None)

class SampleItem:
def __init__(
Expand Down Expand Up @@ -1052,6 +1058,7 @@ def __init__(
fps: int = 15,
ctrl_idx: int = 0,
do_cfg_norm: bool = False,
log_step: Optional[int] = None,
):
self.width: int = width
self.height: int = height
Expand Down Expand Up @@ -1091,6 +1098,8 @@ def __init__(
self.ctrl_img_1 = ctrl_img_1
self.ctrl_img_2 = ctrl_img_2
self.ctrl_img_3 = ctrl_img_3
# Runtime-only metadata used to align sample artifacts with the training step.
self.log_step = log_step

# prompt string will override any settings above
self._process_prompt_string()
Expand Down Expand Up @@ -1315,7 +1324,7 @@ def log_image(self, image, count: int = 0, max_count=0):
if self.logger is None:
return

self.logger.log_image(image, count, self.prompt)
self.logger.log_image(image, count, self.prompt, step=self.log_step)


def validate_configs(
Expand Down
Loading
Loading