Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 121 additions & 5 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
import numpy as np
import PIL.Image
import pytest

import torch
import torchvision.ops
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_equal,
cache,
Expand All @@ -40,14 +38,12 @@
needs_cvcuda,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou

from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
from torchvision.transforms.v2 import functional as F
Expand All @@ -63,7 +59,6 @@
)
from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal


# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]

Expand Down Expand Up @@ -8120,3 +8115,124 @@ def test_different_sizes(self, make_input1, make_input2, query):
def test_no_valid_input(self, query):
    # The query helper must raise when the flat input contains no
    # image-like entry at all (here: a bare string).
    with pytest.raises(TypeError, match="No image"):
        query(["blah"])


class TestThreadSafeGenerator:
"""Test that transforms correctly use torch.thread_safe_generator().

For multiprocessing workers, thread_safe_generator() returns None,
so transforms use the default process global RNG,
i.e. for a multiprocessing worker the RNG of that process.
For thread workers, it returns a thread-local torch.Generator.
"""

# Random transforms exercised by the worker-divergence tests below.
# Flips are deliberately excluded here: different seeds can still
# produce identical flip outputs, which would make divergence
# assertions flaky (see ALL_TRANSFORMS for the reproducibility list).
TRANSFORMS = [
    transforms.RandomResizedCrop(size=(24, 24)),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=10),
    transforms.RandomCrop(size=(24, 24), pad_if_needed=True),
    transforms.RandomPerspective(p=1.0),
    transforms.RandomErasing(p=1.0),
    transforms.ScaleJitter(target_size=(24, 24)),
    transforms.RandomZoomOut(),
    transforms.ElasticTransform(),
    transforms.RandomShortestSize(min_size=(20, 24)),
    transforms.RandomResize(min_size=20, max_size=28),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.RandomChannelPermutation(),
    transforms.RandomPhotometricDistort(),
    transforms.AutoAugment(),
    transforms.RandAugment(),
    transforms.TrivialAugmentWide(),
    transforms.AugMix(),
    transforms.JPEG(quality=(1, 100)),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a few more random transforms in TV that we'll also want to update and test. I think the list you'll find in https://github.com/pytorch/vision/pull/7848/changes should have the proper coverage (but claude should be able to find all the relevant ones)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was missing a lot of random transforms. Updated the PR to cover the transforms in these files in torchvision/transforms/v2:
_augment.py
_auto_augment.py
_color.py
_container.py
_geometry.py
_misc.py
_transform.py

This matches the files touched in your PR above.

RandomHorizontalFlip and RandomVerticalFlip were noisy for the test I want, i.e. that two batches are different for differently seeded workers. This is because the flipped outputs can be the same for different seeds. So I added another reproducibility test, which covers the flips and the other transforms. It checks whether two transform outputs are the same for the same torch.thread_safe_generator value, which should add extra coverage.


class TransformDataset(torch.utils.data.Dataset):
    """Minimal dataset applying ``transform`` to one fixed, cached image.

    Every item is the transform applied to the very same image, so any
    difference between items comes purely from the RNG, not the data.
    """

    def __init__(self, size, transform):
        self.size = size
        self.transform = transform
        self.image = make_image((32, 32))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.transform(self.image)

@pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
def test_multiprocessing_workers(self, transform):
    """With multiprocessing DataLoader workers, thread_safe_generator()
    returns None, so each transform falls back to the per-process global
    RNG. Workers are seeded differently, hence outputs must differ."""
    loader = DataLoader(
        self.TransformDataset(size=2, transform=transform),
        batch_size=1,
        num_workers=2,
    )
    first, second = loader
    assert not torch.equal(first, second)

@pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
def test_thread_worker_uses_thread_local_generator(self, transform):
    """In thread workers, thread_safe_generator() returns a thread-local
    Generator. Mimic two workers with differently seeded generators
    and verify they produce different results.

    NOTE(review): thread_safe_generator is mocked here because threaded
    DataLoader workers are not available yet; once they land, this should
    exercise the public multi-threaded DataLoader path instead.
    """
    image = make_image((32, 32))

    # "Worker 0": generator seeded with 0.
    g0 = torch.Generator()
    g0.manual_seed(0)
    with mock.patch("torch.thread_safe_generator", return_value=g0):
        result_worker0 = transform(image)

    # "Worker 1": differently seeded generator, so sampled params diverge.
    g1 = torch.Generator()
    g1.manual_seed(5)
    with mock.patch("torch.thread_safe_generator", return_value=g1):
        result_worker1 = transform(image)

    assert not torch.equal(result_worker0, result_worker1)

def test_thread_generator_random_iou_crop(self):
    """RandomIoUCrop needs bounding boxes as input, so it is exercised
    here instead of in the parametrized divergence tests."""
    image = make_image((32, 32))
    bboxes = make_bounding_boxes(canvas_size=(32, 32), format="XYXY", num_boxes=3)
    transform = transforms.RandomIoUCrop()

    outputs = []
    for seed in (0, 1):
        gen = torch.Generator()
        gen.manual_seed(seed)
        with mock.patch("torch.thread_safe_generator", return_value=gen):
            outputs.append(transform(image, bboxes))

    # Different seeds must yield different cropped images.
    assert not torch.equal(outputs[0][0], outputs[1][0])

# Reproducibility test list: includes flips which are excluded from
# the divergence tests above.
# (Flips can produce identical outputs under different seeds, so they
# are only usable for same-seed reproducibility checks.)
ALL_TRANSFORMS = TRANSFORMS + [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
]

@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=lambda t: type(t).__name__)
def test_thread_generator_reproducibility(self, transform):
    """Verify transforms draw randomness from the provided generator, not
    the global RNG: two identically seeded generators must produce
    identical outputs even when the global RNG state changes between
    the two calls."""
    image = make_image((32, 32))

    def run_with_seed(seed):
        gen = torch.Generator()
        gen.manual_seed(seed)
        with mock.patch("torch.thread_safe_generator", return_value=gen):
            return transform(image)

    result1 = run_with_seed(42)
    # Perturb the global RNG; it must not influence the transform.
    torch.rand(100)
    result2 = run_with_seed(42)

    torch.testing.assert_close(result1, result2)
18 changes: 11 additions & 7 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
area = img_h * img_w

log_ratio = self._log_ratio
g = torch.thread_safe_generator()
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
generator=g,
)
).item()

Expand All @@ -123,12 +125,12 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
continue

if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_(generator=g)
else:
v = torch.tensor(self.value)[:, None, None]

i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
i = torch.randint(0, img_h - h + 1, size=(1,), generator=g).item()
j = torch.randint(0, img_w - w + 1, size=(1,), generator=g).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None
Expand Down Expand Up @@ -300,8 +302,9 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:

H, W = query_size(flat_inputs)

r_x = torch.randint(W, size=(1,))
r_y = torch.randint(H, size=(1,))
g = torch.thread_safe_generator()
r_x = torch.randint(W, size=(1,), generator=g)
r_y = torch.randint(H, size=(1,), generator=g)

r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
Expand Down Expand Up @@ -367,7 +370,8 @@ def __init__(self, quality: Union[int, Sequence[int]]):
self.quality = quality

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
g = torch.thread_safe_generator()
quality = torch.randint(self.quality[0], self.quality[1] + 1, (), generator=g).item()
return dict(quality=quality)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down
48 changes: 30 additions & 18 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def _extract_params_for_v1_transform(self) -> dict[str, Any]:

return params

def _get_random_item(self, dct: dict[str, tuple[Callable, bool]]) -> tuple[str, tuple[Callable, bool]]:
def _get_random_item(
self, dct: dict[str, tuple[Callable, bool]], generator: torch.Generator = None
) -> tuple[str, tuple[Callable, bool]]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
key = keys[int(torch.randint(len(keys), (), generator=generator))]
return key, dct[key]

def _flatten_and_extract_image_or_video(
Expand Down Expand Up @@ -327,18 +329,19 @@ def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video) # type: ignore[arg-type]

policy = self._policies[int(torch.randint(len(self._policies), ()))]
g = torch.thread_safe_generator()
policy = self._policies[int(torch.randint(len(self._policies), (), generator=g))]

for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
if not torch.rand((), generator=g) <= probability:
continue

magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
if signed and torch.rand((), generator=g) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
Expand Down Expand Up @@ -419,12 +422,13 @@ def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video) # type: ignore[arg-type]

g = torch.thread_safe_generator()
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE, generator=g)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
if signed and torch.rand((), generator=g) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
Expand Down Expand Up @@ -488,12 +492,13 @@ def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video) # type: ignore[arg-type]

transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
g = torch.thread_safe_generator()
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE, generator=g)

magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, (), generator=g))])
if signed and torch.rand((), generator=g) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
Expand Down Expand Up @@ -572,9 +577,9 @@ def __init__(
self.alpha = alpha
self.all_ops = all_ops

def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
def _sample_dirichlet(self, params: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor:
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)
return torch._sample_dirichlet(params, generator)

def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
Expand All @@ -595,26 +600,33 @@ def forward(self, *inputs: Any) -> Any:
# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
# Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
# augmented image or video.
g = torch.thread_safe_generator()
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1),
generator=g,
)

# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1),
generator=g,
) * m[:, 1].reshape([batch_dims[0], -1])

mix = m[:, 0].reshape(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
depth = (
self.chain_depth
if self.chain_depth > 0
else int(torch.randint(low=1, high=4, size=(1,), generator=g).item())
)
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space, generator=g)

magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude = float(magnitudes[int(torch.randint(self.severity, (), generator=g))])
if signed and torch.rand((), generator=g) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
Expand Down
27 changes: 16 additions & 11 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,17 @@ def _check_input(
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

@staticmethod
def _generate_value(left: float, right: float) -> float:
return torch.empty(1).uniform_(left, right).item()
def _generate_value(left: float, right: float, generator: torch.Generator = None) -> float:
return torch.empty(1).uniform_(left, right, generator=generator).item()

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
fn_idx = torch.randperm(4)
g = torch.thread_safe_generator()
fn_idx = torch.randperm(4, generator=g)

b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1], g)
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1], g)
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1], g)
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1], g)

return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)

Expand All @@ -176,7 +177,8 @@ class RandomChannelPermutation(Transform):

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(permutation=torch.randperm(num_channels))
g = torch.thread_safe_generator()
return dict(permutation=torch.randperm(num_channels, generator=g))

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.permute_channels, inpt, params["permutation"])
Expand Down Expand Up @@ -223,17 +225,20 @@ def __init__(

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
g = torch.thread_safe_generator()
params: dict[str, Any] = {
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
key: ColorJitter._generate_value(range[0], range[1], g) if torch.rand(1, generator=g) < self.p else None
for key, range in [
("brightness_factor", self.brightness),
("contrast_factor", self.contrast),
("saturation_factor", self.saturation),
("hue_factor", self.hue),
]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
params["contrast_before"] = bool(torch.rand((), generator=g) < 0.5)
params["channel_permutation"] = (
torch.randperm(num_channels, generator=g) if torch.rand(1, generator=g) < self.p else None
)
return params

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down
Loading
Loading