diff --git a/.gitignore b/.gitignore index 04c233ac4..aa4fe1ad4 100644 --- a/.gitignore +++ b/.gitignore @@ -181,4 +181,6 @@ cython_debug/ ._.DS_Store aitk_db.db /notes.md -/data \ No newline at end of file +/data +masks/* +MULTITRIGGER_DOP_REFACTOR_PLAN.md diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index 368ae9e7c..590b0a48f 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -32,8 +32,12 @@ scheduler_config = { "num_train_timesteps": 1000, - "use_dynamic_shifting": False, - "shift": 3.0, + "use_dynamic_shifting": True, + "base_image_seq_len": 256, # 512x512 with VAE/16 and patch/2 + "max_image_seq_len": 1024, # 1024x1024 + "base_shift": 0.5, + "max_shift": 0.85, + "min_shift": 0.33, # Floor to prevent very low noise sampling } diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 152cd131c..6555b7fe5 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -325,9 +325,32 @@ def hook_before_train_loop(self): self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) if self.trigger_word is not None: self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs) + + # DOP: Precompute embeddings via dataloader (new pattern) if self.train_config.diff_output_preservation: - self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) - + from toolkit.prompt_utils import build_dop_replacement_pairs + + triggers_csv = self.trigger_word + classes_csv = self.train_config.diff_output_preservation_class + + # Build replacement pairs for fallback encoding + self._dop_replacement_pairs = build_dop_replacement_pairs( + triggers_csv=triggers_csv, + classes_csv=classes_csv, + case_insensitive=False + ) + + # Delegate 
precompute to dataloader + datasets = get_dataloader_datasets(self.data_loader) + for dataset in datasets: + dataset.precompute_dop_embeddings( + triggers_csv=triggers_csv, + classes_csv=classes_csv, + encode_fn=lambda caption: self.sd.encode_prompt(caption, **encode_kwargs), + case_insensitive=False, + debug=getattr(self.train_config, 'diff_output_preservation_debug', False) + ) + self.cache_sample_prompts() print_acc("\n***** UNLOADING TEXT ENCODER *****") @@ -1522,6 +1545,13 @@ def get_adapter_multiplier(): unconditional_embeds = concat_prompt_embeds( [unconditional_embeds] * noisy_latents.shape[0] ) + if self.train_config.diff_output_preservation: + + if batch.dop_prompt_embeds is not None: + # use the cached embeds + self.diff_output_preservation_embeds = batch.dop_prompt_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False @@ -1587,18 +1617,39 @@ def get_adapter_multiplier(): self.adapter.is_unconditional_run = False if self.train_config.diff_output_preservation: - dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts] - dop_prompts_2 = None - if prompt_2 is not None: - dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] - self.diff_output_preservation_embeds = self.sd.encode_prompt( - dop_prompts, dop_prompts_2, - dropout_prob=self.train_config.prompt_dropout_prob, - long_prompts=self.do_long_prompts, - **prompt_kwargs - ).to( - self.device_torch, - dtype=dtype) + # Use dataloader-provided DOP embeddings (new pattern) + self.diff_output_preservation_embeds = batch.dop_prompt_embeds + + # Fallback: encode on-the-fly if cache missing (text encoder still available in this branch) + if self.diff_output_preservation_embeds is None: + from toolkit.prompt_utils import apply_dop_replacements + print_acc("[DOP] Cache missing for batch - 
encoding on-the-fly") + + # Apply trigger→class replacements using utility + dop_prompts = [ + apply_dop_replacements(p, self._dop_replacement_pairs, debug=False) + for p in conditioned_prompts + ] + dop_prompts_2 = None + if prompt_2 is not None: + dop_prompts_2 = [ + apply_dop_replacements(p, self._dop_replacement_pairs, debug=False) + for p in prompt_2 + ] + self.diff_output_preservation_embeds = self.sd.encode_prompt( + dop_prompts, dop_prompts_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + else: + # Move cached embeddings to device + self.diff_output_preservation_embeds = self.diff_output_preservation_embeds.to( + self.device_torch, + dtype=dtype + ) # detach the embeddings conditional_embeds = conditional_embeds.detach() if self.train_config.do_cfg: @@ -1784,7 +1835,10 @@ def get_adapter_multiplier(): prior_embeds_to_use = conditional_embeds # use diff_output_preservation embeds if doing dfe if self.train_config.diff_output_preservation: - prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + if self.diff_output_preservation_embeds is None: + print_acc("[DOP WARNING] diff_output_preservation enabled but embeddings missing - skipping DOP this step") + else: + prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) if self.train_config.blank_prompt_preservation: blank_embeds = self.cached_blank_embeds.clone().detach().to( @@ -1983,13 +2037,18 @@ def get_adapter_multiplier(): prior_pred=prior_to_calculate_loss, ) - if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: + # Check if DOP is actually available this step (embeds might be missing if cache failed) + do_dop_this_step = self.train_config.diff_output_preservation and self.diff_output_preservation_embeds is not None + if self.train_config.diff_output_preservation and 
self.diff_output_preservation_embeds is None: + print_acc("[DOP WARNING] diff_output_preservation enabled but embeddings missing - skipping preservation loss this step") + + if do_dop_this_step or self.train_config.blank_prompt_preservation: # send the loss backwards otherwise checkpointing will fail self.accelerator.backward(loss) normal_loss = loss.detach() # dont send backward again - + with torch.no_grad(): - if self.train_config.diff_output_preservation: + if do_dop_this_step: preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) elif self.train_config.blank_prompt_preservation: blank_embeds = self.cached_blank_embeds.clone().detach().to( @@ -2006,7 +2065,7 @@ def get_adapter_multiplier(): batch=batch, **pred_kwargs ) - multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier + multiplier = self.train_config.diff_output_preservation_multiplier if do_dop_this_step else self.train_config.blank_prompt_preservation_multiplier preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier self.accelerator.backward(preservation_loss) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 91f62a5ac..7fcaf4a0c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1174,8 +1174,8 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): timestep_type = 'shift' patch_size = 1 - if self.sd.is_flux or 'flex' in self.sd.arch: - # flux is a patch size of 1, but latents are divided by 2, so we need to double it + if self.sd.is_flux or 'flex' in self.sd.arch or self.sd.arch == 'zimage': + # flux/zimage is a patch size of 1, but latents are divided by 2, so we need to double it patch_size = 2 elif hasattr(self.sd.unet.config, 'patch_size'): patch_size = self.sd.unet.config.patch_size diff --git 
"""
Tests for DOP (Differential Output Preservation) dataloader integration.

Tests the new clean separation: dataloader handles all embedding I/O,
trainer consumes batch.dop_prompt_embeds directly.
"""

import pytest

from toolkit.prompt_utils import build_dop_replacement_pairs, apply_dop_replacements


def _collate_all_or_nothing(file_items):
    """Mirror DataLoaderBatchDTO's collation rule for DOP embeddings.

    Returns the list of ``dop_prompt_embeds`` only when EVERY item has one;
    returns None as soon as any item is missing its embedding (all-or-nothing).
    Extracted to one place so the two cases below cannot drift apart.
    """
    collected = []
    for item in file_items:
        if getattr(item, 'dop_prompt_embeds', None) is None:
            return None
        collected.append(item.dop_prompt_embeds)
    return collected


class TestDOPTextTransformations:
    """Test DOP text transformation utilities."""

    def test_build_dop_replacement_pairs_simple(self):
        """Test building replacement pairs with equal triggers and classes."""
        pairs, digest = build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun")

        assert len(pairs) == 2
        # Should be sorted by length DESC
        assert pairs[0] == ("Zapper", "Gun")
        assert pairs[1] == ("Jinx", "Woman")
        assert isinstance(digest, str)
        assert len(digest) > 0

    def test_build_dop_replacement_pairs_more_triggers(self):
        """Test with more triggers than classes - missing classes become empty string."""
        pairs, digest = build_dop_replacement_pairs("Jinx, Zapper, Vest", "Woman, Gun")

        assert len(pairs) == 3
        assert pairs[0] == ("Zapper", "Gun")
        assert pairs[1] == ("Jinx", "Woman")
        assert pairs[2] == ("Vest", "")

    def test_build_dop_replacement_pairs_sorting(self):
        """Test that pairs are sorted by trigger length DESC to avoid substring issues."""
        pairs, _ = build_dop_replacement_pairs("Jinx, Jinx Master", "Woman, Veteran")

        # "Jinx Master" should come first (longer)
        assert len(pairs) == 2
        assert pairs[0] == ("Jinx Master", "Veteran")
        assert pairs[1] == ("Jinx", "Woman")

    def test_apply_dop_replacements_simple(self):
        """Test simple trigger→class replacement."""
        pairs = [("Jinx", "Woman"), ("Zapper", "Gun")]
        result = apply_dop_replacements("Jinx with a Zapper", pairs)

        assert result == "Woman with a Gun"

    def test_apply_dop_replacements_multiple_occurrences(self):
        """Test replacement of multiple occurrences of same trigger."""
        pairs = [("Jinx", "Woman")]
        result = apply_dop_replacements("Jinx likes Jinx", pairs)

        assert result == "Woman likes Woman"

    def test_apply_dop_replacements_longest_first(self):
        """Test that longer triggers are replaced first (substring protection)."""
        pairs = [("Jinx Master", "Veteran"), ("Jinx", "Woman")]  # Already sorted
        result = apply_dop_replacements("Jinx Master and Jinx", pairs)

        # Should replace "Jinx Master" first, then "Jinx"
        assert result == "Veteran and Woman"

    def test_apply_dop_replacements_empty_class(self):
        """Test replacement with empty class (removes trigger)."""
        pairs = [("Vest", ""), ("Jinx", "Woman")]
        result = apply_dop_replacements("Jinx wearing Vest", pairs)

        # "Vest" replaced with empty string
        assert result == "Woman wearing"

    def test_apply_dop_replacements_normalizes_whitespace(self):
        """Test that result has normalized whitespace."""
        pairs = [("item1", ""), ("item2", "foo")]
        result = apply_dop_replacements("item1 item2", pairs)

        # Should collapse repeated whitespace
        assert result == "foo"

    def test_digest_changes_with_inputs(self):
        """Test that digest changes when inputs change."""
        _, digest1 = build_dop_replacement_pairs("Jinx", "Woman")
        _, digest2 = build_dop_replacement_pairs("Jinx", "Girl")
        _, digest3 = build_dop_replacement_pairs("Zapper", "Woman")

        # All should be different
        assert digest1 != digest2
        assert digest1 != digest3
        assert digest2 != digest3

    def test_digest_stable_for_same_inputs(self):
        """Test that digest is stable for identical inputs."""
        _, digest1 = build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun")
        _, digest2 = build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun")

        assert digest1 == digest2


class TestDOPIntegration:
    """Integration tests for DOP dataloader patterns."""

    def test_dop_embeds_collation_pattern(self):
        """Test the all-or-nothing collation pattern for DOP embeddings.

        This tests the pattern used in DataLoaderBatchDTO where DOP embeddings
        are only collated if ALL items have them.
        """
        # Imported lazily so text-only tests in this module don't require torch.
        from toolkit.prompt_utils import PromptEmbeds
        import torch

        # Simulate file items with DOP embeddings
        class MockFileItem:
            def __init__(self, has_dop):
                if has_dop:
                    self.dop_prompt_embeds = PromptEmbeds(torch.randn(77, 768))
                else:
                    self.dop_prompt_embeds = None

        # Case 1: All items have DOP embeds
        file_items_all = [MockFileItem(True), MockFileItem(True), MockFileItem(True)]
        dop_list = _collate_all_or_nothing(file_items_all)
        assert dop_list is not None
        assert len(dop_list) == 3

        # Case 2: Some items missing DOP embeds (should fail all-or-nothing check)
        file_items_partial = [MockFileItem(True), MockFileItem(False), MockFileItem(True)]
        dop_list = _collate_all_or_nothing(file_items_partial)
        assert dop_list is None  # All-or-nothing: should be None


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
+""" +from __future__ import annotations + +import hashlib +import json +import os +import tempfile +from pathlib import Path +from typing import Any, Callable, Iterable, Mapping, Optional + + +def compute_file_sha256(path: Path, chunk_size: int = 1 << 20) -> str: + """Compute full SHA-256 hex digest of file at `path` using chunked reads.""" + h = hashlib.sha256() + with path.open('rb') as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + +def compute_param_digest(params: Mapping[str, Any], length: int = 12) -> str: + """Compute a stable SHA-256 hex digest of params and return truncated prefix. + + Uses deterministic JSON serialization (sorted keys, compact separators). + """ + canonical = json.dumps(params, sort_keys=True, separators=(',', ':')).encode('utf-8') + h = hashlib.sha256(canonical).hexdigest() + return h[:length] + + +def compute_combined_hash(paths: Iterable[Path], chunk_size: int = 1 << 20) -> str: + """Compute combined SHA-256 over the bytes of the files in *deterministic* order. + + Implementation: compute each file's SHA-256, concatenate the hex digests in order + and hash that concatenated string to produce a combined digest. + """ + digests = [] + for p in paths: + if not p.exists(): + digests.append('') + else: + digests.append(compute_file_sha256(p, chunk_size=chunk_size)) + combined = ''.join(digests).encode('utf-8') + return hashlib.sha256(combined).hexdigest() + + +def cache_filename(base: str, param_digest: str, content_hex: str, ext: str) -> str: + if not ext.startswith('.'): + ext = '.' + ext + return f"{base}_{param_digest}_{content_hex}{ext}" + + +def atomic_write(target: Path, write_fn: Callable[[Path], None], fsync: bool = True) -> None: + """Atomically write a file to `target` using a temporary file in the same directory. + + write_fn(tmp_path) should write to the supplied Path. On success, the temporary + file will be moved into place using os.replace(). 
On failure the tmp file will + be removed. + """ + target = Path(target) + tmp_dir = target.parent + tmp_file = None + # create a unique tmp filename in same dir to ensure os.replace is atomic + fd = None + try: + os.makedirs(tmp_dir, exist_ok=True) + fd, tmp_path = tempfile.mkstemp(prefix=f".{target.name}.tmp.", dir=str(tmp_dir)) + tmp_path = Path(tmp_path) + os.close(fd) # we'll let write_fn open the path + # call user-provided writer + write_fn(tmp_path) + if fsync: + # fsync file to ensure data durability. Use best-effort fsync and tolerate + # platforms where a particular fsync call may fail. + try: + with tmp_path.open('r+b') as f: + os.fsync(f.fileno()) + except Exception: + # best-effort: ignore fsync failure on some platforms/handles + pass + try: + dir_fd = os.open(str(tmp_dir), os.O_RDONLY) + try: + os.fsync(dir_fd) + finally: + os.close(dir_fd) + except Exception: + # best-effort: ok if fsync dir fails on some platforms + pass + os.replace(str(tmp_path), str(target)) + except Exception: + # cleanup tmp file on any failure + if tmp_path is not None and tmp_path.exists(): + try: + tmp_path.unlink() + except Exception: + pass + raise + + +def find_cached_file(expected_path: Path, legacy_fallback: bool = True) -> Optional[Path]: + """Return a Path to the cached file. + + If `expected_path` exists return it. Otherwise, if `legacy_fallback` is True, + attempt to find a legacy-named cache file by searching for files that start with + the same base (prefix before first underscore) and return the best candidate + (most recent modification time) or None. 
+ """ + expected_path = Path(expected_path) + if expected_path.exists(): + return expected_path + if not legacy_fallback: + return None + parent = expected_path.parent + if not parent.exists(): + return None + # prefix before first underscore in the expected filename + base_name = expected_path.name.split('_')[0] + # get the expected file extension + expected_ext = expected_path.suffix.lower() + candidates = [] + try: + for p in parent.iterdir(): + if not p.is_file(): + continue + # skip temporary files (atomic_write creates .tmp.* files) + if p.name.startswith('.') and '.tmp.' in p.name: + continue + # only consider files with matching extension to avoid picking up temp files + if expected_ext and p.suffix.lower() != expected_ext: + continue + if p.name.startswith(base_name + '_'): + candidates.append(p) + except Exception: + return None + if not candidates: + return None + # pick most recently modified candidate as heuristic + candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True) + return candidates[0] + + +def wait_for_cached_file(expected_path: Path, timeout: float = 5.0, poll_interval: float = 0.1, legacy_fallback: bool = True) -> Optional[Path]: + """Wait up to `timeout` seconds for a cached file to appear and be stable. + + Returns the Path if found and stable, otherwise None. + Stability heuristic: file exists and its size is unchanged across two polls. 
+ """ + import time + + deadline = time.time() + float(timeout) + last_size = None + while time.time() < deadline: + candidate = find_cached_file(Path(expected_path), legacy_fallback=legacy_fallback) + if candidate is None: + time.sleep(poll_interval) + continue + try: + stat = candidate.stat() + size = stat.st_size + # If size is 0, it's still possibly being written; wait for stability + if last_size is None: + last_size = size + time.sleep(poll_interval) + continue + if size == last_size: + return candidate + last_size = size + except FileNotFoundError: + # race: file disappeared between discovery and stat — retry + time.sleep(poll_interval) + continue + except Exception: + time.sleep(poll_interval) + continue + return None diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 974e0cb64..780f0d422 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1330,11 +1330,6 @@ def validate_configs( # see if any datasets are caching text embeddings is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs) if is_caching_text_embeddings: - - # check if they are doing differential output preservation - if train_config.diff_output_preservation: - raise ValueError("Cannot use differential output preservation with caching text embeddings. 
Please set diff_output_preservation to False.") - # make sure they are all cached for dataset in dataset_configs: if not dataset.cache_text_embeddings: diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 95075a61a..758905e20 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -578,6 +578,10 @@ def __len__(self): def _get_single_item(self, index) -> 'FileItemDTO': file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index]) + # Inject DOP params from dataset to file_item (ensures survival through pickle/deepcopy) + if hasattr(self, '_dop_enabled') and self._dop_enabled: + file_item._dop_replacement_pairs = getattr(self, '_dop_replacement_pairs', None) + file_item._dop_case_insensitive = getattr(self, '_dop_case_insensitive', False) file_item.load_and_process_image(self.transform) file_item.load_caption(self.caption_dict) return file_item diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 7af8de016..55535b8c4 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -5,7 +5,7 @@ from PIL import Image from PIL.ImageOps import exif_transpose - +from toolkit.print import print_acc from toolkit import image_utils from toolkit.basic import get_quick_signature_string from toolkit.dataloader_mixins import ( @@ -401,6 +401,19 @@ def __init__(self, **kwargs): prompt_embeds_list.append(y) self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list) + # DOP (Differential Output Preservation) embeddings collation + self.dop_prompt_embeds: Union[PromptEmbeds, None] = None + if any([getattr(x, 'dop_prompt_embeds', None) is not None for x in self.file_items]): + dop_list = [] + # Only collate if all items have DOP embeddings (all-or-nothing approach) + for x in self.file_items: + if getattr(x, 'dop_prompt_embeds', None) is None: + dop_list = None + break + dop_list.append(x.dop_prompt_embeds) + if dop_list is not None: + 
self.dop_prompt_embeds = concat_prompt_embeds(dop_list) + if any([x.audio_tensor is not None for x in self.file_items]): # find one to use as a base base_audio_tensor = None diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index e4fb95bbb..a6e6d1670 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -32,9 +32,9 @@ from toolkit.accelerator import get_accelerator from toolkit.prompt_utils import PromptEmbeds from torchvision.transforms import functional as TF - +from pathlib import Path from toolkit.train_tools import get_torch_dtype - +from toolkit.print import print_acc if TYPE_CHECKING: from toolkit.data_loader import AiToolkitDataset from toolkit.data_transfer_object.data_loader import FileItemDTO @@ -728,6 +728,21 @@ def load_and_process_image( # handle get_prompt_embedding if self.is_text_embedding_cached: self.load_prompt_embedding() + # Also load DOP embedding if enabled (dataloader pattern: always load if available) + dop_pairs = getattr(self, '_dop_replacement_pairs', None) + if dop_pairs: + from toolkit.prompt_utils import apply_dop_replacements + # Ensure caption is loaded before computing transformation + if self.caption is None: + self.load_caption() + # Compute transformed caption on-the-fly + transformed_caption = apply_dop_replacements( + caption=self.caption, + replacement_pairs=dop_pairs, + case_insensitive=getattr(self, '_dop_case_insensitive', False), + debug=False + ) + self.load_dop_prompt_embedding(dop_caption=transformed_caption) # if we are caching latents, just do that if self.is_latent_cached: self.get_latent() @@ -1930,55 +1945,138 @@ def __init__(self, *args, **kwargs): if hasattr(super(), '__init__'): super().__init__(*args, **kwargs) self.prompt_embeds: Union[PromptEmbeds, None] = None + self.dop_prompt_embeds: Union[PromptEmbeds, None] = None self._text_embedding_path: Union[str, None] = None + self._dop_text_embedding_path: Union[str, None] = None self.is_text_embedding_cached = 
False self.text_embedding_load_device = 'cpu' self.text_embedding_space_version = 'sd1' self.text_embedding_version = 1 + # honor dataset-level preference to keep text embeddings in memory + self.is_caching_text_embeddings_to_memory = getattr(self.dataset_config, 'cache_text_embeddings_to_memory', True) - def get_text_embedding_info_dict(self: 'FileItemDTO'): + def get_text_embedding_info_dict(self: 'FileItemDTO', dop_caption: str = None): # make sure the caption is loaded here - # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible. if self.caption is None: self.load_caption() + # Build a deterministic dict describing the text embedding input. + # For DOP, use the transformed caption instead of the original. + caption_to_use = dop_caption if dop_caption is not None else self.caption item = OrderedDict([ - ("caption", self.caption), + ("caption", caption_to_use), ("text_embedding_space_version", self.text_embedding_space_version), ("text_embedding_version", self.text_embedding_version), ]) # if we have a control image, cache the path if self.encode_control_in_text_embeddings and self.control_path is not None: item["control_path"] = self.control_path + # Add a marker for DOP to keep cache keys separate + if dop_caption is not None: + item["is_dop"] = True return item - def get_text_embedding_path(self: 'FileItemDTO', recalculate=False): - if self._text_embedding_path is not None and not recalculate: - return self._text_embedding_path + def get_text_embedding_path(self: 'FileItemDTO', recalculate=False, dop_caption: str = None): + # choose cached path for normal or dop variant + if dop_caption is None: + if self._text_embedding_path is not None and not recalculate: + return self._text_embedding_path else: - # we store text embeddings in a folder in same path as image called _text_embedding_cache - img_dir = os.path.dirname(self.path) - te_dir = os.path.join(img_dir, '_t_e_cache') - 
hash_dict = self.get_text_embedding_info_dict() - filename_no_ext = os.path.splitext(os.path.basename(self.path))[0] - # get base64 hash of md5 checksum of hash_dict - hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') - hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') - hash_str = hash_str.replace('=', '') - self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors') + if self._dop_text_embedding_path is not None and not recalculate: + return self._dop_text_embedding_path + + # we store text embeddings in a folder in same path as image called _text_embedding_cache + img_dir = os.path.dirname(self.path) + te_dir = os.path.join(img_dir, '_t_e_cache') + hash_dict = self.get_text_embedding_info_dict(dop_caption=dop_caption) + filename_no_ext = os.path.splitext(os.path.basename(self.path))[0] + # get base64 hash of md5 checksum of hash_dict + # compute param digest (stable) and content digest (caption or transformed caption) + from toolkit.cache_utils import compute_param_digest + param_digest = compute_param_digest(hash_dict) + # For DOP, hash the transformed caption; otherwise use original caption + caption_to_hash = dop_caption if dop_caption is not None else self.caption + content_digest = hashlib.sha256(caption_to_hash.encode('utf-8')).hexdigest() + path = os.path.join(te_dir, f'{filename_no_ext}_{param_digest}_{content_digest}.safetensors') + if dop_caption is None: + self._text_embedding_path = path + else: + self._dop_text_embedding_path = path - return self._text_embedding_path + return path def cleanup_text_embedding(self): + # Respect memory caching preference: if enabled, keep embeddings in-memory (move to CPU), + # otherwise clear them so subsequent batches will reload from disk when needed. 
if self.prompt_embeds is not None: - # we are caching on disk, don't save in memory - self.prompt_embeds = None + if not getattr(self, 'is_caching_text_embeddings_to_memory', False): + self.prompt_embeds = None + else: + try: + self.prompt_embeds = self.prompt_embeds.to('cpu') + except Exception: + # best-effort: do not fail cleanup + pass + if self.dop_prompt_embeds is not None: + if not getattr(self, 'is_caching_text_embeddings_to_memory', False): + self.dop_prompt_embeds = None + else: + try: + self.dop_prompt_embeds = self.dop_prompt_embeds.to('cpu') + except Exception: + pass def load_prompt_embedding(self, device=None): if not self.is_text_embedding_cached: return if self.prompt_embeds is None: - # load it from disk - self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path()) + # load it from disk using robust lookup (hashed + legacy fallback) + from toolkit.cache_utils import find_cached_file, wait_for_cached_file + try: + expected = Path(self.get_text_embedding_path()) + except Exception: + return + cached = wait_for_cached_file(expected, timeout=float(os.getenv('CACHE_WAIT_TIMEOUT', 5.0))) + if not cached: + return + try: + self.prompt_embeds = PromptEmbeds.load(str(cached)) + except FileNotFoundError: + # file was removed racing with load; treat as missing + return + except Exception as e: + # don't let a corrupted cache crash the flow; log and skip + try: + print_acc(f"Warning: failed to load prompt embedding {cached}: {e}") + except Exception: + pass + return + + def load_dop_prompt_embedding(self, dop_caption: str, device=None): + """Load a precomputed DOP variant prompt embedding (if present on disk). + + DOP prompt embeddings are persisted under the `_t_e_cache` directory. 
+ + Args: + dop_caption: The transformed DOP caption (with triggers replaced by classes) + device: Optional device to load to + """ + if self.dop_prompt_embeds is None: + from toolkit.cache_utils import wait_for_cached_file + try: + dop_path = Path(self.get_text_embedding_path(recalculate=False, dop_caption=dop_caption)) + except Exception: + return + + cached = wait_for_cached_file(dop_path, timeout=float(os.getenv('CACHE_WAIT_TIMEOUT', 5.0))) + if not cached: + return + try: + self.dop_prompt_embeds = PromptEmbeds.load(str(cached)) + except FileNotFoundError: + return + except Exception: + return + class TextEmbeddingCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): @@ -1992,62 +2090,286 @@ def cache_text_embeddings(self: 'AiToolkitDataset'): print_acc(f"Caching text_embeddings for {self.dataset_path}") print_acc(" - Saving text embeddings to disk") - did_move = False + # If a per-dataset SplitPrompt is configured, encode and save it now (so the trainer or later stages can load it) + ds_cfg = self.dataset_config + sp_enabled = bool(getattr(ds_cfg, 'split_prompt_enabled', False)) + sp_text = getattr(ds_cfg, 'split_prompt', None) + if sp_enabled and sp_text and str(sp_text).strip() != '': + encode_kwargs = {} + if self.sd.encode_control_in_text_embeddings: + # use a blank control image placeholder similar to generator + control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + encode_kwargs['control_images'] = control_image + + # Encode the split prompt. If encoding fails, raise a RuntimeError. 
+ try: + sp_emb = self.sd.encode_prompt(sp_text, **encode_kwargs) + except Exception as e: + raise RuntimeError(f"SplitPrompt encoding failed for dataset {self.dataset_path}: {e}") from e - # use tqdm to show progress - i = 0 - for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'): + # Save to disk and cache in-memory. Failures during save should raise as well. + try: + sp_emb = sp_emb.to('cpu') + split_path = os.path.join(self.dataset_path, 'split_prompt.safetensors') + from toolkit.cache_utils import atomic_write + atomic_write(Path(split_path), lambda p: sp_emb.save(str(p))) + # also cache in-memory on the dataset object for immediate use + self.split_prompt_embeds = sp_emb + print_acc(f"[SplitPrompt] Saved split prompt embedding to {split_path}") + except Exception as e: + raise RuntimeError(f"Failed to save split prompt embedding for dataset {self.dataset_path}: {e}") from e + + # PRE-CHECK: Count how many files already have valid text embedding caches + from toolkit.cache_utils import find_cached_file + files_needing_encode = [] + files_cached = 0 + + for file_item in self.file_list: file_item.text_embedding_space_version = self.sd.model_config.arch file_item.latent_load_device = self.sd.device + + text_embedding_path = Path(file_item.get_text_embedding_path(recalculate=True)) + cached = find_cached_file(text_embedding_path) + if cached: + files_cached += 1 + file_item.is_text_embedding_cached = True + else: + files_needing_encode.append(file_item) + + # Report cache hit rate + total_files = len(self.file_list) + print_acc(f"Text embedding cache: {files_cached}/{total_files} files cached, {len(files_needing_encode)} need encoding") + + # If everything is cached, we're done! 
+ if len(files_needing_encode) == 0: + print_acc("All text embeddings already cached, skipping encoding") + return - text_embedding_path = file_item.get_text_embedding_path(recalculate=True) - # only process if not saved to disk - if not os.path.exists(text_embedding_path): - # load if not loaded - if not did_move: - self.sd.set_device_state_preset('cache_text_encoder') - did_move = True - - if file_item.encode_control_in_text_embeddings: - if file_item.control_path is None: - raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model") - ctrl_img_list = [] - control_path_list = file_item.control_path - if not isinstance(file_item.control_path, list): - control_path_list = [control_path_list] - for i in range(len(control_path_list)): - try: - img = Image.open(control_path_list[i]).convert("RGB") - img = exif_transpose(img) - # convert to 0 to 1 tensor - img = ( - TF.to_tensor(img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) - ctrl_img_list.append(img) - except Exception as e: - print_acc(f"Error: {e}") - print_acc(f"Error loading control image: {control_path_list[i]}") - - if len(ctrl_img_list) == 0: - ctrl_img = None - elif not self.sd.has_multiple_control_images: - ctrl_img = ctrl_img_list[0] - else: - ctrl_img = ctrl_img_list - prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) + did_move = False + + # use tqdm to show progress (only for files that need encoding) + i = 0 + for file_item in tqdm(files_needing_encode, desc='Caching text embeddings to disk'): + # Note: text_embedding_space_version already set during pre-check + text_embedding_path = Path(file_item.get_text_embedding_path(recalculate=True)) + + # We know this file needs encoding (pre-check determined it's not cached) + # load if not loaded + if not did_move: + self.sd.set_device_state_preset('cache_text_encoder') + did_move = True + + if file_item.encode_control_in_text_embeddings: + 
if file_item.control_path is None: + raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model") + ctrl_img_list = [] + control_path_list = file_item.control_path + if not isinstance(file_item.control_path, list): + control_path_list = [control_path_list] + for i in range(len(control_path_list)): + try: + img = Image.open(control_path_list[i]).convert("RGB") + img = exif_transpose(img) + # convert to 0 to 1 tensor + img = ( + TF.to_tensor(img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(img) + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading control image: {control_path_list[i]}") + + if len(ctrl_img_list) == 0: + ctrl_img = None + elif not self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list[0] else: - prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) - # save it - prompt_embeds.save(text_embedding_path) - del prompt_embeds - file_item.is_text_embedding_cached = True + ctrl_img = ctrl_img_list + prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) + else: + prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) + + # atomic write to avoid partial files (no race re-check needed since we pre-checked) + from toolkit.cache_utils import atomic_write + try: + atomic_write(text_embedding_path, lambda p: prompt_embeds.save(str(p))) + file_item.is_text_embedding_cached = True + except Exception as e: + # Surface a clear error instead of failing silently — include dataset and file for diagnostics + try: + import traceback + print_acc(f"Error: failed to save text embedding for {file_item.path} to {text_embedding_path}: {e}") + print_acc(traceback.format_exc()) + except Exception: + pass + raise RuntimeError(f"Text embedding caching failed for dataset {self.dataset_path} on file {file_item.path}: {e}") from e + finally: + # ensure we drop the prompt embeds 
reference to avoid leaking tensors even on error + try: + del prompt_embeds + except Exception: + pass i += 1 # restore device state # if did_move: # self.sd.restore_device_state() + def precompute_dop_embeddings( + self: 'AiToolkitDataset', + triggers_csv: str, + classes_csv: str, + encode_fn: 'Callable[[str], PromptEmbeds]', + case_insensitive: bool = False, + debug: bool = False + ): + """Precompute DOP (Differential Output Preservation) embeddings for all file items. + + This method applies trigger→class text replacements and encodes the transformed + captions, saving them to disk for use during training. Follows the dataloader + pattern: dataloader loads from cache, trainer provides encoding function. + + Args: + triggers_csv: Comma-separated trigger words, e.g., "Jinx, Zapper" + classes_csv: Comma-separated class names, e.g., "Woman, Gun" + encode_fn: Function to encode text → PromptEmbeds (from trainer's sd.encode_prompt) + case_insensitive: Match triggers case-insensitively (default: False) + debug: Log each trigger→class replacement (default: False) + + Process: + 1. Build replacement pairs and compute digest + 2. Check cache hit rate for DOP embeddings + 3. For files missing DOP cache: + a. Apply trigger→class replacements to caption + b. Call encode_fn(transformed_caption) → PromptEmbeds + c. Save to _dop_text_embedding_path with atomic write + d. 
Store digest in file_item for cache key stability + """ + from toolkit.prompt_utils import build_dop_replacement_pairs, apply_dop_replacements + from toolkit.cache_utils import atomic_write, find_cached_file + from pathlib import Path + + with accelerator.main_process_first(): + print_acc(f"Precomputing DOP embeddings for {self.dataset_path}") + print_acc(f" - Triggers: {triggers_csv}") + print_acc(f" - Classes: {classes_csv}") + + # Build replacement pairs (shared for all files) + replacement_pairs = build_dop_replacement_pairs( + triggers_csv=triggers_csv, + classes_csv=classes_csv, + case_insensitive=case_insensitive + ) + + # Store DOP params on dataset for use in __getitem__ + self._dop_enabled = True + self._dop_replacement_pairs = replacement_pairs + self._dop_case_insensitive = case_insensitive + + # PRE-CHECK: Count how many files already have valid DOP embedding caches + files_needing_encode = [] + files_cached = 0 + + for file_item in self.file_list: + # Store DOP params on file_item so it can compute transformed caption during loading + file_item._dop_replacement_pairs = replacement_pairs + file_item._dop_case_insensitive = case_insensitive + + # Apply trigger→class replacements to get transformed caption for caching + transformed_caption = apply_dop_replacements( + caption=file_item.caption, + replacement_pairs=replacement_pairs, + case_insensitive=case_insensitive, + debug=False + ) + + # Store transformed caption on file_item for later use during encoding + file_item._dop_transformed_caption = transformed_caption + + # Check if DOP embedding already cached (hash based on transformed caption) + # Use strict path check (not find_cached_file which does fuzzy matching) + dop_path = Path(file_item.get_text_embedding_path( + recalculate=True, + dop_caption=transformed_caption + )) + if dop_path.exists(): + files_cached += 1 + else: + files_needing_encode.append(file_item) + + # Report cache hit rate + total_files = len(self.file_list) + print_acc(f"DOP 
embedding cache: {files_cached}/{total_files} files cached, {len(files_needing_encode)} need encoding") + + # Show example transformations for debugging (first 3 files) + if debug or len(files_needing_encode) == 0: + print_acc("\nDOP transformation examples:") + for i, file_item in enumerate(self.file_list[:3]): + transformed = apply_dop_replacements( + caption=file_item.caption, + replacement_pairs=replacement_pairs, + case_insensitive=case_insensitive, + debug=False + ) + print_acc(f" [{i+1}] Original: {file_item.caption}") + print_acc(f" DOP: {transformed}") + + # If everything is cached, we're done! + if len(files_needing_encode) == 0: + print_acc("\nAll DOP embeddings already cached, skipping encoding") + return + + did_move = False + + # Use tqdm to show progress (only for files that need encoding) + from tqdm import tqdm + for file_item in tqdm(files_needing_encode, desc='Precomputing DOP embeddings to disk'): + # Get transformed caption (already computed and stored in pre-check loop) + transformed_caption = file_item._dop_transformed_caption + + dop_path = Path(file_item.get_text_embedding_path( + recalculate=True, + dop_caption=transformed_caption + )) + + # Set device state for encoding (only once) + if not did_move: + self.sd.set_device_state_preset('cache_text_encoder') + did_move = True + + # Encode transformed caption using trainer's encode function + try: + dop_prompt_embeds = encode_fn(transformed_caption) + except Exception as e: + raise RuntimeError( + f"DOP encoding failed for {file_item.path} with transformed caption '{transformed_caption}': {e}" + ) from e + + # Atomic write to avoid partial files + try: + atomic_write(dop_path, lambda p: dop_prompt_embeds.save(str(p))) + except Exception as e: + try: + import traceback + print_acc(f"Error: failed to save DOP embedding for {file_item.path} to {dop_path}: {e}") + print_acc(traceback.format_exc()) + except Exception: + pass + raise RuntimeError( + f"DOP embedding caching failed for dataset 
{self.dataset_path} on file {file_item.path}: {e}" + ) from e + finally: + # Drop reference to avoid memory leak + try: + del dop_prompt_embeds + except Exception: + pass + + print_acc(f"DOP precompute complete: {len(files_needing_encode)} files encoded") + class CLIPCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 1a9f23e7a..49cc79623 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -413,6 +413,21 @@ def generate_images( except: pass + # Enable CacheDiT acceleration for supported DiT pipelines + cache_dit_enabled = False + try: + import cache_dit + import fnmatch + + pipeline_class_name = pipeline.__class__.__name__ + _, supported_patterns = cache_dit.supported_pipelines() + + if any(fnmatch.fnmatch(pipeline_class_name, p) for p in supported_patterns): + cache_dit.enable_cache(pipeline) + cache_dit_enabled = True + except ImportError: + pass + start_multiplier = 1.0 if network is not None: start_multiplier = network.multiplier @@ -674,6 +689,15 @@ def generate_images( torch.cuda.set_rng_state(cuda_rng_state) self.restore_device_state() + + # Disable CacheDiT to remove any forward hooks before training resumes + if cache_dit_enabled: + try: + import cache_dit + cache_dit.disable_cache(pipeline) + except Exception: + pass + if network is not None: network.train() network.multiplier = start_multiplier diff --git a/toolkit/print.py b/toolkit/print.py index e0f6c23b0..cc8a1dc1f 100644 --- a/toolkit/print.py +++ b/toolkit/print.py @@ -3,9 +3,93 @@ from toolkit.accelerator import get_accelerator -def print_acc(*args, **kwargs): - if get_accelerator().is_local_main_process: - print(*args, **kwargs) +def print_acc(*args, sep=' ', end='\n', file=None, flush=False, **kwargs): + """Print messages in a way that is safe for tqdm progress bars *and* ensures + messages are atomically appended to the log file when `setup_log_to_file` is used. 
+ + Behavior: + - Still only prints on the local main process (keeps Accelerate behavior). + - Writes to the terminal using `tqdm.write` (or fallback) so progress bars are not broken. + - If `sys.stdout` has been wrapped by `Logger`, the message is written to the original + terminal stream and an atomic append (os.write) is used to append to the log file. + - If an explicit `file` is provided (and it's not `sys.stdout`) we print to that file + and also append the message to the log file if present so it stays captured. + """ + if not get_accelerator().is_local_main_process: + return + + # Build message and ensure it ends with the provided end + msg = sep.join([str(a) for a in args]) + if not msg.endswith(end): + msg = msg + end + + # Detect if sys.stdout is the Logger wrapper + outer_stdout = sys.stdout + is_logger = hasattr(outer_stdout, 'log') and hasattr(outer_stdout, 'terminal') + + # We will write to the original terminal stream (if wrapped) so that we can + # control whether the Logger.log gets written by us atomically. 
+ terminal_stream = outer_stdout.terminal if is_logger else outer_stdout + + # If an explicit file object is provided (and it's not the stdout wrapper): honor it + if file is not None and file is not outer_stdout: + try: + print(msg, end='', file=file, flush=flush) + except Exception: + try: + print(msg, end='', flush=flush) + except Exception: + pass + + # Also append to the log file (if present) using atomic os.write + if hasattr(outer_stdout, 'log'): + try: + fd = outer_stdout.log.fileno() + os.write(fd, msg.encode('utf-8')) + if flush: + try: + os.fsync(fd) + except Exception: + pass + except Exception: + # fallback to normal python write + try: + outer_stdout.log.write(msg) + outer_stdout.log.flush() + except Exception: + pass + return + + # Default: print to terminal safely using tqdm.write + try: + from tqdm import tqdm + # tqdm.write will append a newline; strip our trailing newline to avoid double blank lines + tqdm.write(msg.rstrip('\n'), file=terminal_stream) + except Exception: + try: + terminal_stream.write(msg) + terminal_stream.flush() + except Exception: + pass + + # If the stdout wrapper has a log file, append atomically to it + if hasattr(outer_stdout, 'log'): + try: + fd = outer_stdout.log.fileno() + os.write(fd, msg.encode('utf-8')) + # Only force an fsync when flush requested (keeps perf reasonable) + if flush: + try: + os.fsync(fd) + except Exception: + pass + except Exception: + # fallback to text write which is what Logger.write used to do + try: + outer_stdout.log.write(msg) + outer_stdout.log.flush() + except Exception: + pass class Logger: diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 0bcbe876e..f59da918e 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -1,5 +1,7 @@ import os from typing import Optional, TYPE_CHECKING, List, Union, Tuple +import re +from pathlib import Path import torch from safetensors.torch import load_file, save_file @@ -135,8 +137,18 @@ def save(self, path: str): 
state_dict[f"attention_mask_{i}"] = attn.cpu() else: state_dict["attention_mask"] = pe.attention_mask.cpu() + from toolkit.cache_utils import atomic_write os.makedirs(os.path.dirname(path), exist_ok=True) - save_file(state_dict, path) + # write via atomic_write to avoid partial/corrupt files + def _writer(p): + # save_file expects a path-like or str + save_file(state_dict, str(p)) + atomic_write(Path(path), _writer) + # record source path for diagnostics + try: + self._source_path = path + except Exception: + pass @classmethod def load(cls, path: str) -> 'PromptEmbeds': @@ -146,6 +158,8 @@ def load(cls, path: str) -> 'PromptEmbeds': :return: An instance of PromptEmbeds. """ state_dict = load_file(path, device='cpu') + # record source path for diagnostics on the created object (set below) + source_path = path text_embeds = [] pooled_embeds = None attention_mask = [] @@ -173,6 +187,10 @@ def load(cls, path: str) -> 'PromptEmbeds': pe.attention_mask = attention_mask[0] else: pe.attention_mask = attention_mask + try: + pe._source_path = source_path + except Exception: + pass return pe @@ -341,6 +359,9 @@ def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]): ) + + + def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]: if num_parts is None: # use batch size @@ -484,6 +505,149 @@ def get_permutations(s, max_permutations=8): return [', '.join(permutation) for permutation in permutations] +def parse_csv_list(value: Optional[str]) -> List[str]: + """Parse a comma-separated string into a list preserving explicit empty entries. 
+ + Examples: + "a, b, c" -> ["a", "b", "c"] + "a, , c" -> ["a", "", "c"] + "" -> [] + None -> [] + """ + if value is None: + return [] + # split on comma, preserve empty strings if present + parts = [p.strip() for p in value.split(',')] + # If the input is empty string, return [] (consistent with previous behavior) + if len(parts) == 1 and parts[0] == "": + return [] + return parts + + +def normalize_caption_separators(text: str) -> str: + """Normalize separators and spacing in caption text. + + Ensures there is a single space after commas and collapses repeated whitespace. + This helps avoid tokenizer merging of adjacent tokens in many cases. + """ + if text is None: + return "" + # Ensure space after commas + text = re.sub(r",\s*", ", ", text) + # Collapse multiple spaces + text = re.sub(r"\s+", " ", text).strip() + return text + + +def build_dop_replacement_pairs( + triggers_csv: str, + classes_csv: str, + case_insensitive: bool = False +) -> List[Tuple[str, str]]: + """Build trigger→class replacement pairs from CSV strings. + + Args: + triggers_csv: Comma-separated triggers, e.g., "Jinx, Zapper" + classes_csv: Comma-separated classes, e.g., "Woman, Gun" + case_insensitive: Whether replacements should be case-insensitive + + Returns: + pairs: [(trigger, class), ...] 
sorted by trigger length DESC + + Examples: + >>> build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun") + [("Zapper", "Gun"), ("Jinx", "Woman")] + + >>> build_dop_replacement_pairs("Jinx, Zapper, Vest", "Woman, Gun") + [("Zapper", "Gun"), ("Jinx", "Woman"), ("Vest", "")] + """ + triggers = parse_csv_list(triggers_csv) + classes = parse_csv_list(classes_csv) + + # Build pairs: if more triggers than classes, missing classes become empty string + pairs = [(t, classes[i] if i < len(classes) else '') for i, t in enumerate(triggers)] + + # Sort by trigger length DESC to avoid substring collision + # Example: "Jinx Master" should be replaced before "Jinx" + pairs.sort(key=lambda x: len(x[0]) if x[0] else 0, reverse=True) + + return pairs + + +def apply_dop_replacements( + caption: str, + replacement_pairs: List[Tuple[str, str]], + case_insensitive: bool = False, + debug: bool = False +) -> str: + """Apply trigger→class replacements to caption text. + + Follows MultiTrigger.md algorithm: + 1. Normalize caption (spaces, punctuation) + 2. For each (trigger, cls) pair: + - Use word-boundary regex: \\btrigger\\b (respects punctuation) + - Fall back to simple string.replace() if regex fails + 3. Return modified caption + + Args: + caption: Original caption text + replacement_pairs: [(trigger, class), ...] 
sorted by length DESC + case_insensitive: Match triggers case-insensitively + debug: Log replacement details + + Returns: + Transformed caption with triggers replaced by classes + + Examples: + >>> pairs = [("Zapper", "Gun"), ("Jinx", "Woman")] + >>> apply_dop_replacements("Jinx with a Zapper", pairs) + "Woman with a Gun" + """ + original = caption + out = normalize_caption_separators(caption) + + if not replacement_pairs: + return out + + for trigger, cls in replacement_pairs: + if not trigger: # Skip empty triggers + continue + + try: + # Word-boundary match to avoid replacing substrings inside words + # Use \b for proper word boundaries (works with punctuation) + escaped_trigger = re.escape(trigger) + if case_insensitive: + pattern = rf"\b{escaped_trigger}\b" + out = re.sub(pattern, cls, out, flags=re.IGNORECASE) + else: + pattern = rf"\b{escaped_trigger}\b" + out = re.sub(pattern, cls, out) + except Exception: + # Fallback to simple replace if regex fails + if case_insensitive: + # Manual case-insensitive replace (less efficient but works) + out = re.sub(re.escape(trigger), cls, out, flags=re.IGNORECASE) + else: + out = out.replace(trigger, cls) + + # Clean up comma artifacts from empty replacements + # Handle cases like "a, , b" → "a, b" and "a,, b" → "a, b" + out = re.sub(r",\s*,", ",", out) # ", ," or ",," → "," + out = re.sub(r",\s*,", ",", out) # Run twice to handle "a, , , b" cases + + # Clean up leading/trailing commas and whitespace + out = out.strip(", \t\n") + + # Collapse repeated whitespace after all cleanup + out = re.sub(r"\s+", " ", out).strip() + + if debug: + print(f"[DOP DEBUG] '{original}' → '{out}'") + + return out + + def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']: from toolkit.config_modules import SliderTargetConfig pos_permutations = get_permutations(target.positive, max_permutations=max_permutations) @@ -723,11 +887,15 @@ def inject_trigger_into_prompt(prompt, 
trigger=None, to_replace_list=None, add_i # replace it output_prompt = output_prompt.replace(to_replace, replace_with) - if trigger.strip() != "": + # If trigger contains multiple CSV entries, do not auto-prepend it for backwards compatibility + parsed_triggers = parse_csv_list(trigger) + allow_prepend = add_if_not_present and len(parsed_triggers) == 1 and parsed_triggers[0].strip() != '' + + if trigger.strip() != "" and allow_prepend: # see how many times replace_with is in the prompt num_instances = output_prompt.count(replace_with) - if num_instances == 0 and add_if_not_present: + if num_instances == 0: # add it to the beginning of the prompt output_prompt = replace_with + " " + output_prompt diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index bac7f3fde..3e6b3caaa 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -22,6 +22,8 @@ def calculate_shift( class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def __init__(self, *args, **kwargs): + # Extract custom params not supported by parent class + self._min_shift = kwargs.pop("min_shift", None) super().__init__(*args, **kwargs) self.init_noise_sigma = 1.0 self.timestep_type = "linear" @@ -156,6 +158,9 @@ def set_train_timesteps( self.config.get("base_shift", 0.5), self.config.get("max_shift", 1.16), ) + # Clamp mu to min_shift floor if configured + if self._min_shift is not None: + mu = max(mu, self._min_shift) sigmas = self.time_shift(mu, 1.0, sigmas) else: sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) diff --git a/ui/package-lock.json b/ui/package-lock.json index 1b7ba9cd4..7eec79698 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1473,6 +1473,7 @@ "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.19.tgz", "integrity": "sha512-LEwC7o1ifqg/6r2gn9Dns0f1rhK+fPFDoMiceTJ6kWmVk6bgXBI/9IOWfVan4WiAavK9pIVWdX0/e3J+eEUh5A==", "dev": 
true, + "peer": true, "dependencies": { "undici-types": "~6.19.2" } @@ -1486,6 +1487,7 @@ "version": "19.0.10", "resolved": "https://registry.npmjs.org/@types/react/-/react-19.0.10.tgz", "integrity": "sha512-JuRQ9KXLEjaUNjTWpzuR231Z2WpIwczOkBEIvbHNCzQefFIT0L8IqE6NV6ULLyC1SI/i234JnDoMkfg+RjQj2g==", + "peer": true, "dependencies": { "csstype": "^3.0.2" } @@ -4620,6 +4622,7 @@ "url": "https://github.com/sponsors/ai" } ], + "peer": true, "dependencies": { "nanoid": "^3.3.8", "picocolors": "^1.1.1", @@ -4799,6 +4802,7 @@ "integrity": "sha512-JKCZWvBC3enxk51tY4TWzS4b5iRt4sSU1uHn2I183giZTvonXaQonzVtjLzpOHE7qu9MxY510kAtFGJwryKe3Q==", "hasInstallScript": true, "license": "Apache-2.0", + "peer": true, "dependencies": { "@prisma/engines": "6.3.1" }, @@ -4918,6 +4922,7 @@ "version": "19.0.0", "resolved": "https://registry.npmjs.org/react/-/react-19.0.0.tgz", "integrity": "sha512-V8AVnmPIICiWpGfm6GLzCR/W5FXLchHop40W4nXBmdlEceh16rCN8O8LNWm5bh5XUX91fh7KpA+W0TgMKmgTpQ==", + "peer": true, "engines": { "node": ">=0.10.0" } @@ -4926,6 +4931,7 @@ "version": "19.0.0", "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.0.0.tgz", "integrity": "sha512-4GV5sHFG0e/0AD4X+ySy6UJd3jVl1iNsNHdpad0qhABJ11twS3TTBnseqsKurKcsNqCEFeGL3uLpVChpIO3QfQ==", + "peer": true, "dependencies": { "scheduler": "^0.25.0" }, @@ -4968,13 +4974,15 @@ "node_modules/react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", - "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "peer": true }, "node_modules/react-redux": { "version": "9.2.0", "resolved": "https://registry.npmjs.org/react-redux/-/react-redux-9.2.0.tgz", "integrity": "sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==", "license": "MIT", + "peer": true, "dependencies": { 
"@types/use-sync-external-store": "^0.0.6", "use-sync-external-store": "^1.4.0" @@ -5116,7 +5124,8 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz", "integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==", - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/redux-thunk": { "version": "3.1.0", @@ -6168,6 +6177,7 @@ "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.3.tgz", "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", "devOptional": true, + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver"