Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
import numpy as np
import PIL.Image
import pytest

import torch
import torchvision.ops
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_equal,
cache,
Expand All @@ -40,14 +38,12 @@
needs_cvcuda,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou

from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
from torchvision.transforms.v2 import functional as F
Expand All @@ -63,7 +59,6 @@
)
from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal


# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]

Expand Down Expand Up @@ -8120,3 +8115,64 @@ def test_different_sizes(self, make_input1, make_input2, query):
def test_no_valid_input(self, query):
with pytest.raises(TypeError, match="No image"):
query(["blah"])


class TestThreadSafeGenerator:
"""Test that transforms correctly use torch.thread_safe_generator().

For multiprocessing workers, thread_safe_generator() returns None,
so transforms use the default process global RNG,
i.e. for a multiprocessing worker the RNG of that process.
For thread workers, it returns a thread-local torch.Generator.
"""

TRANSFORMS = [
transforms.RandomResizedCrop(size=(24, 24)),
transforms.RandomRotation(degrees=10),
transforms.RandomAffine(degrees=10),
transforms.RandomCrop(size=(24, 24), pad_if_needed=True),
transforms.RandomPerspective(p=1.0),
transforms.RandomErasing(p=1.0),
transforms.ScaleJitter(target_size=(24, 24)),
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a few more random transforms in TV that we'll also want to update and test. I think the list you'll find in https://github.com/pytorch/vision/pull/7848/changes should have the proper coverage (but claude should be able to find all the relevant ones)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was missing a lot of random transforms. Updated the PR to cover the transforms in these files in torchvision/transforms/v2:
_augment.py
_auto_augment.py
_color.py
_container.py
_geometry.py
_misc.py
_transform.py

This matches the files touched in your PR above.

RandomHorizontalFlip and RandomVerticalFlip were noisy for the test I want, i.e. that two batches differ across differently seeded workers. This is because the flipped outputs can be the same for different seeds. So I added another reproducibility test which covers the flips and the other transforms. It checks that two transform outputs are identical when the same torch.thread_safe_generator value is used. This should add extra coverage.


class TransformDataset(torch.utils.data.Dataset):
    """Minimal dataset: every item is ``transform`` applied to one fixed image.

    A single source image is created once at construction time, so any
    difference between items comes purely from the transform's RNG draws.
    """

    def __init__(self, size, transform):
        self.size = size
        self.transform = transform
        # Shared source image; each __getitem__ re-applies the random transform.
        self.image = make_image((32, 32))

    def __getitem__(self, index):
        return self.transform(self.image)

    def __len__(self):
        return self.size

@pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
def test_multiprocessing_workers(self, transform):
    """Multiprocessing workers fall back to the per-process global RNG
    (thread_safe_generator() returns None there). Each worker process is
    seeded differently by the DataLoader, so the two produced batches
    must not be identical."""
    dataset = self.TransformDataset(size=2, transform=transform)
    loader = DataLoader(dataset, batch_size=1, num_workers=2)
    first, second = list(loader)
    assert not torch.equal(first, second)

@pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
def test_thread_worker_uses_thread_local_generator(self, transform):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this multi-threading test, is there a way to test the actual multi-threaded behavior, without the mocking? I.e. ideally I'd like to test the public-facing APIs when a user requests multi-threaded from the DataLoader. I'm not sure what the public entry point is though?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I am using the mocks only because the thread-worker support hasn't landed yet and I want to land this PR first. Once it's in, I can update these tests. With respect to the transforms, we are not simplifying anything, though: with the mocking, each transform call gets a different generator, just as it would if it went through the DataLoader's thread workers.

"""In thread workers, thread_safe_generator() returns a thread-local
Generator. Mimic two workers with differently seeded generators
and verify they produce different results."""
image = make_image((32, 32))

g0 = torch.Generator()
g0.manual_seed(0)
with mock.patch("torch.thread_safe_generator", return_value=g0):
result_worker0 = transform(image)

g1 = torch.Generator()
g1.manual_seed(1)
with mock.patch("torch.thread_safe_generator", return_value=g1):
result_worker1 = transform(image)

assert not torch.equal(result_worker0, result_worker1)
84 changes: 52 additions & 32 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import PIL.Image
import torch

from torchvision import transforms as _transforms, tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
Expand Down Expand Up @@ -281,22 +280,25 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
height, width = query_size(flat_inputs)
area = height * width

g = torch.thread_safe_generator()

log_ratio = self._log_ratio
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
generator=g,
)
).item()

w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))

if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
i = torch.randint(0, height - h + 1, size=(1,), generator=g).item()
j = torch.randint(0, width - w + 1, size=(1,), generator=g).item()
break
else:
# Fallback to central crop
Expand Down Expand Up @@ -547,11 +549,13 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
g = torch.thread_safe_generator()

r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)

r = torch.rand(2)
r = torch.rand(2, generator=g)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
Expand Down Expand Up @@ -628,7 +632,8 @@ def __init__(
self.center = center

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample the rotation angle for one call.

    The angle is drawn uniformly from ``self.degrees`` using the
    thread-safe generator, so thread-based DataLoader workers each use
    their own RNG stream instead of the shared global one.
    """
    # NOTE: the diff span contained both the old (global-RNG) and new
    # (generator-based) draw; only the generator-based version is kept.
    g = torch.thread_safe_generator()
    angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
    return dict(angle=angle)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down Expand Up @@ -728,26 +733,28 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample affine parameters (angle, translate, scale, shear) for one call.

    Every random draw goes through the thread-safe generator so that
    thread-based DataLoader workers each get an independent RNG stream.
    Returns a dict consumed by ``transform``.
    """
    # NOTE: the diff span contained both the old (global-RNG) and new
    # (generator-based) draws; only the generator-based versions are kept.
    height, width = query_size(flat_inputs)

    g = torch.thread_safe_generator()

    angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
    if self.translate is not None:
        # translate is given as a fraction of image size; convert to pixels.
        max_dx = float(self.translate[0] * width)
        max_dy = float(self.translate[1] * height)
        tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item()))
        ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item()))
        translate = (tx, ty)
    else:
        translate = (0, 0)

    if self.scale is not None:
        scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
    else:
        scale = 1.0

    shear_x = shear_y = 0.0
    if self.shear is not None:
        shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item()
        # A 4-tuple shear also carries an independent y-shear range.
        if len(self.shear) == 4:
            shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item()

    shear = (shear_x, shear_y)
    return dict(angle=angle, translate=translate, scale=scale, shear=shear)
Expand Down Expand Up @@ -885,13 +892,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
padding = [pad_left, pad_top, pad_right, pad_bottom]
needs_pad = any(padding)

g = torch.thread_safe_generator()

needs_vert_crop, top = (
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g)))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g)))
if padded_width > cropped_width
else (False, 0)
)
Expand Down Expand Up @@ -970,21 +979,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1

g = torch.thread_safe_generator()

topleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
int(torch.randint(0, bound_width, size=(1,), generator=g)),
int(torch.randint(0, bound_height, size=(1,), generator=g)),
]
topright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
int(torch.randint(0, bound_height, size=(1,), generator=g)),
]
botright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
]
botleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
int(torch.randint(0, bound_width, size=(1,), generator=g)),
int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
Expand Down Expand Up @@ -1065,7 +1077,9 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
height, width = query_size(flat_inputs)

dx = torch.rand(1, 1, height, width) * 2 - 1
g = torch.thread_safe_generator()

dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
Expand All @@ -1074,7 +1088,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / width

dy = torch.rand(1, 1, height, width) * 2 - 1
dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
Expand Down Expand Up @@ -1157,24 +1171,26 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)
bboxes = get_bounding_boxes(flat_inputs)

g = torch.thread_safe_generator()

while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return dict()

for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue

# check for 0 area crops
r = torch.rand(2)
r = torch.rand(2, generator=g)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
Expand Down Expand Up @@ -1206,7 +1222,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:

if len(params) < 1:
return inpt

Expand Down Expand Up @@ -1276,7 +1291,9 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_height, orig_width = query_size(flat_inputs)

scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
g = torch.thread_safe_generator()

scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
Expand Down Expand Up @@ -1341,7 +1358,9 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_height, orig_width = query_size(flat_inputs)

min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
g = torch.thread_safe_generator()

min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))]
r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
Expand Down Expand Up @@ -1418,7 +1437,8 @@ def __init__(
self.antialias = antialias

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample a target size uniformly from ``[min_size, max_size)``.

    Uses the thread-safe generator so thread-based DataLoader workers
    draw from their own RNG stream.
    """
    # NOTE: the diff span contained both the old (global-RNG) and new
    # (generator-based) draw; only the generator-based version is kept.
    g = torch.thread_safe_generator()
    size = int(torch.randint(self.min_size, self.max_size, (), generator=g))
    return dict(size=[size])

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any:

self.check_inputs(flat_inputs)

if torch.rand(1) >= self.p:
g = torch.thread_safe_generator()
if torch.rand(1, generator=g) >= self.p:
return inputs

needs_transform_list = self._needs_transform_list(flat_inputs)
Expand Down
Loading