diff --git a/config/examples/train_lora_flux_24gb_mlflow.yaml b/config/examples/train_lora_flux_24gb_mlflow.yaml new file mode 100644 index 000000000..9a786f6d7 --- /dev/null +++ b/config/examples/train_lora_flux_24gb_mlflow.yaml @@ -0,0 +1,89 @@ +--- +# Example: Training a FLUX LoRA with MLflow experiment tracking +# +# Prerequisites: +# pip install "mlflow>=3,<4" +# +# To view results, start the MLflow UI: +# mlflow ui # local (default: http://localhost:5000) +# mlflow server --host 0.0.0.0 # remote +# +# You can also point to an existing tracking server: +# logging.mlflow_tracking_uri: "http://your-server:5000" + +job: extension +config: + name: "my_flux_lora_mlflow" + process: + - type: 'sd_trainer' + training_folder: "output" + device: cuda:0 + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 + save_every: 250 + max_step_saves_to_keep: 4 + datasets: + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 + shuffle_tokens: false + cache_latents_to_disk: true + resolution: [ 512, 768, 1024 ] + train: + batch_size: 1 + steps: 2000 + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false + gradient_checkpointing: true + noise_scheduler: "flowmatch" + optimizer: "adamw8bit" + lr: 1e-4 + ema_config: + use_ema: true + ema_decay: 0.99 + dtype: bf16 + model: + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true + + # ---- Logging & Experiment Tracking ---- + logging: + log_every: 10 + verbose: false + + # MLflow + use_mlflow: true + project_name: "flux-lora-training" + run_name: "my_flux_lora_mlflow" + mlflow_tracking_uri: null # null = local ./mlruns directory + mlflow_experiment_name: "flux-lora" + mlflow_log_artifacts: true # log checkpoint .safetensors as MLflow artifacts + mlflow_register_model: true # register final LoRA in MLflow (appears in Models tab) + mlflow_registered_model_name: null # optional: versioned name in Model Registry (null = model logged as artifact only, not registered) + + # W&B (can be enabled simultaneously with MLflow) + # use_wandb: true + + sample: + sampler: "flowmatch" + sample_every: 250 + width: 1024 + height: 1024 + prompts: + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +meta: + name: "[name]" + version: '1.0' diff --git a/extensions_built_in/sd_trainer/DiffusionTrainer.py b/extensions_built_in/sd_trainer/DiffusionTrainer.py index f39611b13..1debc1932 100644 --- a/extensions_built_in/sd_trainer/DiffusionTrainer.py +++ b/extensions_built_in/sd_trainer/DiffusionTrainer.py @@ -310,6 +310,7 @@ def sample(self, step=None, is_first=False): def save(self, step=None): self.maybe_stop() self.update_status("running", "Saving model") - super().save(step) + result = super().save(step) self.maybe_stop() self.update_status("running", "Training") + return result diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index 8b5fa796c..5194190cd 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -290,6 +290,7 @@ def sample(self, step=None, is_first=False): def save(self, step=None): self.maybe_stop() self.update_status("running", "Saving model") - super().save(step) + result = super().save(step) self.maybe_stop() self.update_status("running", "Training") + return result diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 01418c74f..4f580642c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -350,6 +350,7 @@ def sample(self, step=None, is_first=False): ctrl_img_2=sample_item.ctrl_img_2, ctrl_img_3=sample_item.ctrl_img_3, do_cfg_norm=sample_config.do_cfg_norm, + log_step=step, **extra_args )) @@ -482,8 +483,7 @@ def clean_up_saves(self): return latest_item def post_save_hook(self, save_path): - # override in subclass - pass + self.logger.log_checkpoint(save_path) def done_hook(self): pass @@ -493,7 +493,7 @@ def end_step_hook(self): def save(self, step=None): if not self.accelerator.is_main_process: - return + return None flush() if self.ema is not None: # always save params as ema @@ -614,6 +614,7 @@ def save(self, step=None): yaml.dump(self.meta, f) # move it back self.adapter = self.adapter.to(orig_device, dtype=orig_dtype) + file_path = name_or_path else: direct_save = False if self.adapter_config.train_only_image_encoder: @@ -665,6 +666,9 @@ def save(self, step=None): get_torch_dtype(self.save_config.dtype) ) + # snapshot the checkpoint path before it gets overwritten by SNR/optimizer saves + checkpoint_path = file_path + # save learnable params as json if we have thim if self.snr_gos: json_data = { @@ -677,7 +681,7 @@ def save(self, step=None): with open(path_to_save, 'w') as f: json.dump(json_data, f, indent=4) - print_acc(f"Saved checkpoint to {file_path}") + print_acc(f"Saved checkpoint to {checkpoint_path}") # save optimizer if self.optimizer is not None: @@ -695,11 +699,12 @@ def save(self, step=None): print_acc("Could not save optimizer") self.clean_up_saves() - self.post_save_hook(file_path) + self.post_save_hook(checkpoint_path) if self.ema is not None: self.ema.train() flush() + return checkpoint_path # Called before the model is loaded def hook_before_model_load(self): @@ -717,6 +722,8 @@ def hook_add_extra_train_params(self, params): def hook_before_train_loop(self): if self.accelerator.is_main_process: self.logger.start() + if self.dataset_configs: + self.logger.log_datasets(self.dataset_configs) self.prepare_accelerator() def sample_step_hook(self, img_num, total_imgs): @@ -2395,8 +2402,24 @@ def run(self): self.logger.commit(step=self.step_num) print_acc("") if self.accelerator.is_main_process: - self.save() - self.logger.finish() + try: + final_path = self.save() + # Register model (only for adapter training, not merged saves) + if ( + final_path + and self.network_config is not None + and not self.train_config.merge_network_on_save + ): + self.logger.log_model( + lora_path=final_path, + base_model=self.model_config.name_or_path_original, + model_type=self.model_config.arch or "sd1", + network_type=self.network_config.type, + lora_rank=self.network_config.rank, + lora_alpha=self.network_config.linear_alpha, + ) + finally: + self.logger.finish() self.accelerator.end_training() if self.accelerator.is_main_process: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 228ad3f38..4dbe52d2c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -35,9 +35,15 @@ def __init__(self, **kwargs): self.log_every: int = kwargs.get('log_every', 100) self.verbose: bool = kwargs.get('verbose', False) self.use_wandb: bool = kwargs.get('use_wandb', False) + self.use_mlflow: bool = kwargs.get('use_mlflow', False) self.use_ui_logger: bool = kwargs.get('use_ui_logger', False) self.project_name: str = kwargs.get('project_name', 'ai-toolkit') self.run_name: str = kwargs.get('run_name', None) + self.mlflow_tracking_uri: str = kwargs.get('mlflow_tracking_uri', None) + self.mlflow_experiment_name: str = kwargs.get('mlflow_experiment_name', None) + self.mlflow_log_artifacts: bool = kwargs.get('mlflow_log_artifacts', False) + self.mlflow_register_model: bool = kwargs.get('mlflow_register_model', False) + self.mlflow_registered_model_name: str = kwargs.get('mlflow_registered_model_name', None) class SampleItem: def __init__( @@ -1052,6 +1058,7 @@ def __init__( fps: int = 15, ctrl_idx: int = 0, do_cfg_norm: bool = False, + log_step: Optional[int] = None, ): self.width: int = width self.height: int = height @@ -1091,6 +1098,8 @@ def __init__( self.ctrl_img_1 = ctrl_img_1 self.ctrl_img_2 = ctrl_img_2 self.ctrl_img_3 = ctrl_img_3 + # Runtime-only metadata used to align sample artifacts with the training step. + self.log_step = log_step # prompt string will override any settings above self._process_prompt_string() @@ -1315,7 +1324,7 @@ def log_image(self, image, count: int = 0, max_count=0): if self.logger is None: return - self.logger.log_image(image, count, self.prompt) + self.logger.log_image(image, count, self.prompt, step=self.log_step) def validate_configs( diff --git a/toolkit/logging_aitk.py b/toolkit/logging_aitk.py index a84ad6ea6..678f4f394 100644 --- a/toolkit/logging_aitk.py +++ b/toolkit/logging_aitk.py @@ -1,11 +1,14 @@ -from typing import OrderedDict, Optional +from typing import Any, Dict, List, OrderedDict, Optional, Tuple from PIL import Image from toolkit.config_modules import LoggingConfig import os +import random +import re import sqlite3 +import string import time -from typing import Any, Dict, Tuple, List +import uuid # Base logger class @@ -30,11 +33,46 @@ def commit(self, step: Optional[int] = None): def log_image(self, *args, **kwargs): pass + # log checkpoint artifact + def log_checkpoint(self, file_path: str): + pass + + # register trained model + def log_model(self, **kwargs): + pass + + # log training datasets + def log_datasets(self, dataset_configs): + pass + # finish logging def finish(self): pass +def _make_lora_pyfunc_stub(): + """Build a minimal pyfunc stub so the LoRA can be registered in the Model Registry. + + This is a stopgap until MLflow gets a native diffusers adapter flavor — + see https://github.com/mlflow/mlflow/issues/22122. + The stub stores the LoRA weights as an artifact for lineage tracking but + does not implement inference. To use the LoRA, load it with diffusers, e.g.: + pipe.load_lora_weights(context.artifacts['lora_weights']) + """ + import mlflow.pyfunc + + class LoRAModelStub(mlflow.pyfunc.PythonModel): + def predict(self, context, model_input, params=None): + raise NotImplementedError( + "This LoRA adapter cannot be served directly via MLflow. " + "Load the weights from context.artifacts['lora_weights'] " + "with your preferred diffusers pipeline. " + "See https://github.com/mlflow/mlflow/issues/22122" + ) + + return LoRAModelStub + + # Wandb logger class # This class logs the data to wandb class WandbLogger(EmptyLogger): @@ -76,6 +114,8 @@ def log_image( **kwargs, ): # create a wandb image object and log it + # W&B associates images with the step passed to the next commit() call + kwargs.pop("step", None) image = self._image(image, caption=caption, *args, **kwargs) self._log({f"sample_{id}": image}, commit=False) @@ -83,7 +123,362 @@ def finish(self): self.run.finish() -class UILogger: +class MLflowLogger(EmptyLogger): + """MLflow experiment tracking logger. + + Follows the same two-phase pattern as WandbLogger: + - log() accumulates metrics into a buffer + - commit(step) flushes the buffer to MLflow as a single batched call + + Uses the MLflow fluent API (mlflow.start_run / mlflow.log_metrics / etc.). + All MLflow API calls are wrapped in try/except so a tracking server failure + never kills a training run. + """ + + # MLflow metric keys: alphanumerics, underscores, dashes, periods, spaces, slashes + _METRIC_KEY_RE = re.compile(r"[^a-zA-Z0-9_/.\- ]+") + + def __init__( + self, + project: str, + run_name: str | None, + config: OrderedDict, + tracking_uri: str | None = None, + experiment_name: str | None = None, + log_artifacts: bool = False, + register_model: bool = False, + registered_model_name: str | None = None, + ) -> None: + self.project = project + self.run_name = run_name + self.config = config + self.tracking_uri = tracking_uri + self.experiment_name = experiment_name or project + self.log_artifacts = log_artifacts + self.register_model = register_model + self.registered_model_name = registered_model_name + + self._pending: Dict[str, float] = {} + self._mlflow = None + self._run = None + self._started = False + self._last_step: Optional[int] = None + self._logged_images_tag_set = False + + def start(self): + if self._started: + return + + try: + import mlflow + except ImportError: + raise ImportError( + "Failed to import mlflow. Please install a compatible version by running `pip install \"mlflow>=3,<4\"`" + ) + + self._mlflow = mlflow + + try: + if self.tracking_uri: + mlflow.set_tracking_uri(self.tracking_uri) + mlflow.set_experiment(self.experiment_name) + self._run = mlflow.start_run(run_name=self.run_name) + except Exception as e: + print(f"[MLflowLogger] Failed to start MLflow run: {e}") + self._mlflow = None + return + + self._started = True + + # Log flattened config as params for comparison + flat_params = self._flatten_config(self.config) + if flat_params: + try: + items = list(flat_params.items()) + for i in range(0, len(items), 100): + batch = dict(items[i : i + 100]) + mlflow.log_params(batch) + except Exception as e: + print(f"[MLflowLogger] Warning: failed to log params: {e}") + + def log(self, log_dict=None, *args, **kwargs): + if log_dict is None: + if args: + log_dict = args[0] + else: + return + + if not isinstance(log_dict, dict): + return + + for k, v in log_dict.items(): + key = self._sanitize_key(k) + try: + self._pending[key] = float(v) + except (TypeError, ValueError): + pass + + def commit(self, step: Optional[int] = None): + if self._mlflow is None: + return + + self._last_step = step + + if self._pending: + try: + self._mlflow.log_metrics(self._pending, step=step) + self._pending.clear() + except Exception as e: + print(f"[MLflowLogger] Warning: failed to log metrics at step {step}: {e}") + self._pending.clear() + + def log_image( + self, + image, + id, # sample index + caption: str | None = None, + step: Optional[int] = None, + *args, + **kwargs, + ): + if self._mlflow is None: + return + + # handle video frames (list of images) — log only the first frame + if isinstance(image, list): + if len(image) == 0: + return + image = image[0] + + try: + import numpy as np + is_loggable = isinstance(image, (Image.Image, np.ndarray)) + except ImportError: + is_loggable = isinstance(image, Image.Image) + np = None + + if not is_loggable: + return + + # Convert numpy to PIL for consistent handling + if np is not None and isinstance(image, np.ndarray): + image = Image.fromarray(image) + + if step is None: + step = self._last_step if self._last_step is not None else 0 + try: + # Step-aligned images for the Image Grid chart in Model Metrics. + # We build the artifact path manually using '+' as the separator + # instead of relying on mlflow.log_image(key=, step=) because + # MLflow <= 3.10.1 uses '%' which breaks URL encoding for certain + # step numbers (e.g. step 23 → '%23' → '#'). Fixed on master via + # mlflow/mlflow#21269 but not yet in a stable release. + # The JS parser (ImageReducer.ts) already supports both '+' and '%'. + ts = int(time.time() * 1000) + file_uuid = f"{random.choice(string.ascii_lowercase[6:])}{str(uuid.uuid4())[1:]}" + safe_key = str(id).replace("/", "#") + + base = f"images/{safe_key}+step+{step}+timestamp+{ts}+{file_uuid}" + self._mlflow.log_image(image, artifact_file=f"{base}.png") + + if not self._logged_images_tag_set and self._run: + self._mlflow.set_tag("mlflow.loggedImages", "true") + self._logged_images_tag_set = True + except Exception as e: + print(f"[MLflowLogger] Warning: failed to log image sample_{id}: {e}") + + def finish(self): + if self._mlflow is None or not self._started: + return + + if self._pending: + try: + self._mlflow.log_metrics(self._pending, step=self._last_step) + except Exception as e: + print(f"[MLflowLogger] Warning: failed to flush final metrics: {e}") + self._pending.clear() + + try: + self._mlflow.end_run() + except Exception as e: + print(f"[MLflowLogger] Warning: failed to end run: {e}") + + self._run = None + self._started = False + self._mlflow = None + + @property + def run_id(self) -> str | None: + """Return the MLflow run ID.""" + if self._run is not None: + return self._run.info.run_id + return None + + def log_checkpoint(self, file_path: str): + """Log a saved checkpoint as an MLflow artifact.""" + if self._mlflow is None or not self.log_artifacts: + return + if not os.path.exists(file_path): + print(f"[MLflowLogger] Warning: checkpoint path does not exist, skipping: {file_path}") + return + + try: + if os.path.isdir(file_path): + self._mlflow.log_artifacts(file_path, artifact_path="checkpoints") + else: + self._mlflow.log_artifact(file_path, artifact_path="checkpoints") + except Exception as e: + print(f"[MLflowLogger] Warning: failed to log checkpoint artifact {file_path}: {e}") + + def log_model( + self, + lora_path: str, + base_model: str, + model_type: str = "sd1", + network_type: str = "lora", + lora_rank: int | None = None, + lora_alpha: float | None = None, + **kwargs, + ): + """Register the LoRA adapter in the MLflow Model Registry. + + Uses a minimal pyfunc stub so the LoRA appears in the Models section + with versioning and lineage. Inference is not supported — load the + weights with diffusers directly. + See https://github.com/mlflow/mlflow/issues/22122. + """ + if not self.register_model: + return + if self._mlflow is None or not self._started: + return + if not os.path.exists(lora_path): + print(f"[MLflowLogger] Warning: LoRA path does not exist, skipping registration: {lora_path}") + return + + try: + import mlflow.pyfunc + + LoRAModelStub = _make_lora_pyfunc_stub() + + model_config = { + "base_model": base_model, + "model_type": model_type, + "network_type": network_type, + } + if lora_rank is not None: + model_config["lora_rank"] = lora_rank + if lora_alpha is not None: + model_config["lora_alpha"] = lora_alpha + + mlflow.pyfunc.log_model( + name="lora_model", + python_model=LoRAModelStub(), + artifacts={"lora_weights": lora_path}, + model_config=model_config, + registered_model_name=self.registered_model_name, + ) + + print(f"[MLflowLogger] Registered LoRA model from {lora_path}") + if self.registered_model_name: + print(f"[MLflowLogger] Model registered as '{self.registered_model_name}'") + except Exception as e: + print(f"[MLflowLogger] Warning: failed to register LoRA model: {e}") + + def log_datasets(self, dataset_configs): + """Log training datasets to MLflow so they appear in the Datasets section.""" + if self._mlflow is None or not self._started: + return + + try: + import pandas as pd + + for i, ds in enumerate(dataset_configs): + source_path = ds.folder_path or ds.dataset_path or "unknown" + name = os.path.basename(source_path) if source_path != "unknown" else f"dataset_{i}" + + info = { + "source_path": [source_path], + "resolution": [str(ds.resolution)], + "type": [ds.type], + } + if ds.caption_ext: + info["caption_ext"] = [ds.caption_ext] + if ds.caption_dropout_rate is not None: + info["caption_dropout_rate"] = [ds.caption_dropout_rate] + if ds.trigger_word: + info["trigger_word"] = [ds.trigger_word] + + df = pd.DataFrame(info) + dataset = self._mlflow.data.from_pandas(df, name=name, source=source_path) + self._mlflow.log_input(dataset, context="training") + except Exception as e: + print(f"[MLflowLogger] Warning: failed to log datasets: {e}") + + # ---- internal helpers ---- + + def _sanitize_key(self, key: str) -> str: + return self._METRIC_KEY_RE.sub("_", key) + + @staticmethod + def _flatten_config(config, prefix: str = "", sep: str = ".") -> Dict[str, str]: + """Flatten a nested dict/OrderedDict into dot-separated keys with string values.""" + flat = {} + if not isinstance(config, (dict, OrderedDict)): + return flat + + for k, v in config.items(): + full_key = f"{prefix}{sep}{k}" if prefix else str(k) + if isinstance(v, (dict, OrderedDict)): + flat.update(MLflowLogger._flatten_config(v, full_key, sep)) + else: + flat[full_key] = str(v)[:250] + return flat + + +class CompositeLogger(EmptyLogger): + """Dispatches all logging calls to multiple loggers simultaneously. + + Enables running W&B + MLflow (+ UILogger) at the same time. + """ + + def __init__(self, loggers: list) -> None: + self._loggers = [lg for lg in loggers if type(lg) is not EmptyLogger] + + def _safe_call(self, method_name, *args, **kwargs): + for lg in self._loggers: + try: + getattr(lg, method_name)(*args, **kwargs) + except ImportError: + raise # missing package = config error, never swallow + except Exception as e: + print(f"[CompositeLogger] {type(lg).__name__}.{method_name}() failed: {e}") + + def start(self): + self._safe_call("start") + + def log(self, *args, **kwargs): + self._safe_call("log", *args, **kwargs) + + def commit(self, step: Optional[int] = None): + self._safe_call("commit", step=step) + + def log_image(self, *args, **kwargs): + self._safe_call("log_image", *args, **kwargs) + + def log_checkpoint(self, file_path: str): + self._safe_call("log_checkpoint", file_path) + + def log_model(self, **kwargs): + self._safe_call("log_model", **kwargs) + + def log_datasets(self, dataset_configs): + self._safe_call("log_datasets", dataset_configs) + + def finish(self): + self._safe_call("finish") + + +class UILogger(EmptyLogger): def __init__( self, log_file: str, @@ -190,6 +585,7 @@ def log_image(self, *args, **kwargs): # this doesnt log images for now pass + # finish logging def finish(self): if not self._started: @@ -303,14 +699,40 @@ def create_logger( all_config: OrderedDict, save_root: Optional[str] = None, ): + loggers: List[EmptyLogger] = [] + if logging_config.use_wandb: - project_name = logging_config.project_name - run_name = logging_config.run_name - return WandbLogger(project=project_name, run_name=run_name, config=all_config) - elif logging_config.use_ui_logger: + loggers.append( + WandbLogger( + project=logging_config.project_name, + run_name=logging_config.run_name, + config=all_config, + ) + ) + + if logging_config.use_mlflow: + loggers.append( + MLflowLogger( + project=logging_config.project_name, + run_name=logging_config.run_name, + config=all_config, + tracking_uri=logging_config.mlflow_tracking_uri, + experiment_name=logging_config.mlflow_experiment_name, + log_artifacts=logging_config.mlflow_log_artifacts, + register_model=logging_config.mlflow_register_model, + registered_model_name=logging_config.mlflow_registered_model_name, + ) + ) + + if logging_config.use_ui_logger: if save_root is None: raise ValueError("save_root must be provided when using UILogger") log_file = os.path.join(save_root, "loss_log.db") - return UILogger(log_file=log_file) - else: + loggers.append(UILogger(log_file=log_file)) + + if len(loggers) == 0: return EmptyLogger() + elif len(loggers) == 1: + return loggers[0] + else: + return CompositeLogger(loggers)