Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,6 @@ cython_debug/
._.DS_Store
aitk_db.db
/notes.md
/data
/data
masks/*
MULTITRIGGER_DOP_REFACTOR_PLAN.md
8 changes: 6 additions & 2 deletions extensions_built_in/diffusion_models/z_image/z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@

scheduler_config = {
"num_train_timesteps": 1000,
"use_dynamic_shifting": False,
"shift": 3.0,
"use_dynamic_shifting": True,
"base_image_seq_len": 256, # 512x512 with VAE/16 and patch/2
"max_image_seq_len": 1024, # 1024x1024
"base_shift": 0.5,
"max_shift": 0.85,
"min_shift": 0.33, # Floor to prevent very low noise sampling
}


Expand Down
97 changes: 78 additions & 19 deletions extensions_built_in/sd_trainer/SDTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,32 @@ def hook_before_train_loop(self):
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
if self.trigger_word is not None:
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs)

# DOP: Precompute embeddings via dataloader (new pattern)
if self.train_config.diff_output_preservation:
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)

from toolkit.prompt_utils import build_dop_replacement_pairs

triggers_csv = self.trigger_word
classes_csv = self.train_config.diff_output_preservation_class

# Build replacement pairs for fallback encoding
self._dop_replacement_pairs = build_dop_replacement_pairs(
triggers_csv=triggers_csv,
classes_csv=classes_csv,
case_insensitive=False
)

# Delegate precompute to dataloader
datasets = get_dataloader_datasets(self.data_loader)
for dataset in datasets:
dataset.precompute_dop_embeddings(
triggers_csv=triggers_csv,
classes_csv=classes_csv,
encode_fn=lambda caption: self.sd.encode_prompt(caption, **encode_kwargs),
case_insensitive=False,
debug=getattr(self.train_config, 'diff_output_preservation_debug', False)
)

self.cache_sample_prompts()

print_acc("\n***** UNLOADING TEXT ENCODER *****")
Expand Down Expand Up @@ -1522,6 +1545,13 @@ def get_adapter_multiplier():
unconditional_embeds = concat_prompt_embeds(
[unconditional_embeds] * noisy_latents.shape[0]
)
if self.train_config.diff_output_preservation:

if batch.dop_prompt_embeds is not None:
# use the cached embeds
self.diff_output_preservation_embeds = batch.dop_prompt_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)

if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
Expand Down Expand Up @@ -1587,18 +1617,39 @@ def get_adapter_multiplier():
self.adapter.is_unconditional_run = False

if self.train_config.diff_output_preservation:
dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts]
dop_prompts_2 = None
if prompt_2 is not None:
dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2]
self.diff_output_preservation_embeds = self.sd.encode_prompt(
dop_prompts, dop_prompts_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
# Use dataloader-provided DOP embeddings (new pattern)
self.diff_output_preservation_embeds = batch.dop_prompt_embeds

# Fallback: encode on-the-fly if cache missing (text encoder still available in this branch)
if self.diff_output_preservation_embeds is None:
from toolkit.prompt_utils import apply_dop_replacements
print_acc("[DOP] Cache missing for batch - encoding on-the-fly")

# Apply trigger→class replacements using utility
dop_prompts = [
apply_dop_replacements(p, self._dop_replacement_pairs, debug=False)
for p in conditioned_prompts
]
dop_prompts_2 = None
if prompt_2 is not None:
dop_prompts_2 = [
apply_dop_replacements(p, self._dop_replacement_pairs, debug=False)
for p in prompt_2
]
self.diff_output_preservation_embeds = self.sd.encode_prompt(
dop_prompts, dop_prompts_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts,
**prompt_kwargs
).to(
self.device_torch,
dtype=dtype)
else:
# Move cached embeddings to device
self.diff_output_preservation_embeds = self.diff_output_preservation_embeds.to(
self.device_torch,
dtype=dtype
)
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
if self.train_config.do_cfg:
Expand Down Expand Up @@ -1784,7 +1835,10 @@ def get_adapter_multiplier():
prior_embeds_to_use = conditional_embeds
# use diff_output_preservation embeds if doing dfe
if self.train_config.diff_output_preservation:
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
if self.diff_output_preservation_embeds is None:
print_acc("[DOP WARNING] diff_output_preservation enabled but embeddings missing - skipping DOP this step")
else:
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])

if self.train_config.blank_prompt_preservation:
blank_embeds = self.cached_blank_embeds.clone().detach().to(
Expand Down Expand Up @@ -1983,13 +2037,18 @@ def get_adapter_multiplier():
prior_pred=prior_to_calculate_loss,
)

if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
# Check if DOP is actually available this step (embeds might be missing if cache failed)
do_dop_this_step = self.train_config.diff_output_preservation and self.diff_output_preservation_embeds is not None
if self.train_config.diff_output_preservation and self.diff_output_preservation_embeds is None:
print_acc("[DOP WARNING] diff_output_preservation enabled but embeddings missing - skipping preservation loss this step")

if do_dop_this_step or self.train_config.blank_prompt_preservation:
# send the loss backwards otherwise checkpointing will fail
self.accelerator.backward(loss)
normal_loss = loss.detach() # dont send backward again

with torch.no_grad():
if self.train_config.diff_output_preservation:
if do_dop_this_step:
preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
elif self.train_config.blank_prompt_preservation:
blank_embeds = self.cached_blank_embeds.clone().detach().to(
Expand All @@ -2006,7 +2065,7 @@ def get_adapter_multiplier():
batch=batch,
**pred_kwargs
)
multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
multiplier = self.train_config.diff_output_preservation_multiplier if do_dop_this_step else self.train_config.blank_prompt_preservation_multiplier
preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
self.accelerator.backward(preservation_loss)

Expand Down
4 changes: 2 additions & 2 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,8 +1174,8 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
timestep_type = 'shift'

patch_size = 1
if self.sd.is_flux or 'flex' in self.sd.arch:
# flux is a patch size of 1, but latents are divided by 2, so we need to double it
if self.sd.is_flux or 'flex' in self.sd.arch or self.sd.arch == 'zimage':
# flux/zimage is a patch size of 1, but latents are divided by 2, so we need to double it
patch_size = 2
elif hasattr(self.sd.unet.config, 'patch_size'):
patch_size = self.sd.unet.config.patch_size
Expand Down
148 changes: 148 additions & 0 deletions testing/test_dop_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Tests for DOP (Differential Output Preservation) dataloader integration.

Tests the new clean separation: dataloader handles all embedding I/O,
trainer consumes batch.dop_prompt_embeds directly.
"""

import pytest
from toolkit.prompt_utils import build_dop_replacement_pairs, apply_dop_replacements


class TestDOPTextTransformations:
    """Exercise the trigger→class text-rewrite helpers used by DOP."""

    def test_build_dop_replacement_pairs_simple(self):
        """Equal-length trigger/class lists pair up one-to-one."""
        result, fingerprint = build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun")

        # Pairs come back ordered by trigger length, longest first.
        assert list(result) == [("Zapper", "Gun"), ("Jinx", "Woman")]
        assert isinstance(fingerprint, str)
        assert fingerprint != ""

    def test_build_dop_replacement_pairs_more_triggers(self):
        """Extra triggers without a matching class map to the empty string."""
        result, _ = build_dop_replacement_pairs("Jinx, Zapper, Vest", "Woman, Gun")

        assert list(result) == [("Zapper", "Gun"), ("Jinx", "Woman"), ("Vest", "")]

    def test_build_dop_replacement_pairs_sorting(self):
        """Longer triggers sort first so a substring never clobbers them."""
        result, _ = build_dop_replacement_pairs("Jinx, Jinx Master", "Woman, Veteran")

        # "Jinx Master" must precede its substring "Jinx".
        assert list(result) == [("Jinx Master", "Veteran"), ("Jinx", "Woman")]

    def test_apply_dop_replacements_simple(self):
        """Each trigger occurrence is swapped for its class word."""
        mapping = [("Jinx", "Woman"), ("Zapper", "Gun")]

        assert apply_dop_replacements("Jinx with a Zapper", mapping) == "Woman with a Gun"

    def test_apply_dop_replacements_multiple_occurrences(self):
        """Every occurrence of a trigger is replaced, not just the first."""
        mapping = [("Jinx", "Woman")]

        assert apply_dop_replacements("Jinx likes Jinx", mapping) == "Woman likes Woman"

    def test_apply_dop_replacements_longest_first(self):
        """Pre-sorted longer triggers win over their substrings."""
        mapping = [("Jinx Master", "Veteran"), ("Jinx", "Woman")]  # longest first

        assert apply_dop_replacements("Jinx Master and Jinx", mapping) == "Veteran and Woman"

    def test_apply_dop_replacements_empty_class(self):
        """An empty class string removes the trigger from the caption."""
        mapping = [("Vest", ""), ("Jinx", "Woman")]

        assert apply_dop_replacements("Jinx wearing Vest", mapping) == "Woman wearing"

    def test_apply_dop_replacements_normalizes_whitespace(self):
        """Whitespace runs left behind by deletions are collapsed."""
        mapping = [("item1", ""), ("item2", "foo")]

        assert apply_dop_replacements("item1 item2", mapping) == "foo"

    def test_digest_changes_with_inputs(self):
        """Changing either triggers or classes yields a different digest."""
        digests = [
            build_dop_replacement_pairs(triggers, classes)[1]
            for triggers, classes in [
                ("Jinx", "Woman"),
                ("Jinx", "Girl"),
                ("Zapper", "Woman"),
            ]
        ]

        # All three digests must be pairwise distinct.
        assert len(set(digests)) == 3

    def test_digest_stable_for_same_inputs(self):
        """Identical inputs always produce the identical digest."""
        first = build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun")[1]
        second = build_dop_replacement_pairs("Jinx, Zapper", "Woman, Gun")[1]

        assert first == second


class TestDOPIntegration:
    """Integration-level checks for the dataloader-side DOP patterns."""

    def test_dop_embeds_collation_pattern(self):
        """Verify the all-or-nothing collation rule for DOP embeddings.

        Mirrors the pattern in DataLoaderBatchDTO: a batch only carries
        DOP embeddings when EVERY file item supplies them; one missing
        item voids the whole list.
        """
        from toolkit.prompt_utils import PromptEmbeds
        import torch

        class FakeItem:
            # Stand-in for a dataloader file item with/without cached embeds.
            def __init__(self, with_embeds):
                self.dop_prompt_embeds = (
                    PromptEmbeds(torch.randn(77, 768)) if with_embeds else None
                )

        def collate(items):
            # Replicates the all-or-nothing gathering loop under test.
            gathered = []
            for item in items:
                if getattr(item, 'dop_prompt_embeds', None) is None:
                    return None
                gathered.append(item.dop_prompt_embeds)
            return gathered

        # Case 1: every item carries embeddings -> a full list comes back.
        complete = collate([FakeItem(True) for _ in range(3)])
        assert complete is not None
        assert len(complete) == 3

        # Case 2: a single missing item voids the entire batch's embeds.
        partial = collate([FakeItem(True), FakeItem(False), FakeItem(True)])
        assert partial is None


# Allow running this module directly (e.g. `python test_dop_dataloader.py`)
# instead of through the pytest CLI; "-v" prints one line per test.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Loading