diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 759f1f44643..03bb6800b16 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -15,11 +15,9 @@ import numpy as np import PIL.Image import pytest - import torch import torchvision.ops import torchvision.transforms.v2 as transforms - from common_utils import ( assert_equal, cache, @@ -40,14 +38,12 @@ needs_cvcuda, set_rng_seed, ) - from torch import nn from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors from torchvision.ops.boxes import box_iou - from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping, to_pil_image from torchvision.transforms.v2 import functional as F @@ -63,7 +59,6 @@ ) from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal - # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -8120,3 +8115,124 @@ def test_different_sizes(self, make_input1, make_input2, query): def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): query(["blah"]) + + +class TestThreadSafeGenerator: + """Test that transforms correctly use torch.thread_safe_generator(). + + For multiprocessing workers, thread_safe_generator() returns None, + so transforms use the default process global RNG, + i.e. for a multiprocessing worker the RNG of that process. + For thread workers, it returns a thread-local torch.Generator. + """ + + TRANSFORMS = [ + transforms.RandomResizedCrop(size=(24, 24)), + transforms.RandomRotation(degrees=10), + transforms.RandomAffine(degrees=10), + transforms.RandomCrop(size=(24, 24), pad_if_needed=True), + transforms.RandomPerspective(p=1.0), + transforms.RandomErasing(p=1.0), + transforms.ScaleJitter(target_size=(24, 24)), + transforms.RandomZoomOut(), + transforms.ElasticTransform(), + transforms.RandomShortestSize(min_size=(20, 24)), + transforms.RandomResize(min_size=20, max_size=28), + transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), + transforms.RandomChannelPermutation(), + transforms.RandomPhotometricDistort(), + transforms.AutoAugment(), + transforms.RandAugment(), + transforms.TrivialAugmentWide(), + transforms.AugMix(), + transforms.JPEG(quality=(1, 100)), + transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), + ] + + class TransformDataset(torch.utils.data.Dataset): + def __init__(self, size, transform): + self.size = size + self.transform = transform + self.image = make_image((32, 32)) + + def __getitem__(self, idx): + return self.transform(self.image) + + def __len__(self): + return self.size + + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_multiprocessing_workers(self, transform): + """With multiprocessing DataLoader workers, thread_safe_generator() + returns None and transforms use the per-process global RNG. + Each worker gets a different seed, so results should differ.""" + dataset = self.TransformDataset(size=2, transform=transform) + dl = DataLoader(dataset, batch_size=1, num_workers=2) + batch0, batch1 = list(dl) + assert not torch.equal(batch0, batch1) + + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_thread_worker_uses_thread_local_generator(self, transform): + """In thread workers, thread_safe_generator() returns a thread-local + Generator. Mimic two workers with differently seeded generators + and verify they produce different results.""" + image = make_image((32, 32)) + + g0 = torch.Generator() + g0.manual_seed(0) + with mock.patch("torch.thread_safe_generator", return_value=g0): + result_worker0 = transform(image) + + g1 = torch.Generator() + g1.manual_seed(5) + with mock.patch("torch.thread_safe_generator", return_value=g1): + result_worker1 = transform(image) + + assert not torch.equal(result_worker0, result_worker1) + + def test_thread_generator_random_iou_crop(self): + """RandomIoUCrop requires bounding boxes, so test it separately.""" + image = make_image((32, 32)) + bboxes = make_bounding_boxes(canvas_size=(32, 32), format="XYXY", num_boxes=3) + + transform = transforms.RandomIoUCrop() + + results = [] + for seed in (0, 1): + g = torch.Generator() + g.manual_seed(seed) + with mock.patch("torch.thread_safe_generator", return_value=g): + result = transform(image, bboxes) + results.append(result) + + # The image output should differ between different seeds + assert not torch.equal(results[0][0], results[1][0]) + + # Reproducibility test list: includes flips which are excluded from + # the divergence tests above. + ALL_TRANSFORMS = TRANSFORMS + [ + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.5), + ] + + @pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=lambda t: type(t).__name__) + def test_thread_generator_reproducibility(self, transform): + """Verify transforms use the provided generator, not the global RNG. + Same seeded generator should produce identical results even when + the global RNG state changes between calls.""" + image = make_image((32, 32)) + + g1 = torch.Generator() + g1.manual_seed(42) + with mock.patch("torch.thread_safe_generator", return_value=g1): + result1 = transform(image) + + # Advance global RNG so it's in a different state + torch.rand(100) + + g2 = torch.Generator() + g2.manual_seed(42) + with mock.patch("torch.thread_safe_generator", return_value=g2): + result2 = transform(image) + + torch.testing.assert_close(result1, result2) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index c6da9aba98b..0a0c94579e5 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -108,12 +108,14 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: area = img_h * img_w log_ratio = self._log_ratio + g = torch.thread_safe_generator() for _ in range(10): - erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=g, ) ).item() @@ -123,12 +125,12 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: continue if self.value is None: - v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_(generator=g) else: v = torch.tensor(self.value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() + i = torch.randint(0, img_h - h + 1, size=(1,), generator=g).item() + j = torch.randint(0, img_w - w + 1, size=(1,), generator=g).item() break else: i, j, h, w, v = 0, 0, img_h, img_w, None @@ -300,8 +302,9 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: H, W = query_size(flat_inputs) - r_x = torch.randint(W, size=(1,)) - r_y = torch.randint(H, size=(1,)) + g = torch.thread_safe_generator() + r_x = torch.randint(W, size=(1,), generator=g) + r_y = torch.randint(H, size=(1,), generator=g) r = 0.5 * math.sqrt(1.0 - lam) r_w_half = int(r * W) @@ -367,7 +370,8 @@ def __init__(self, quality: Union[int, Sequence[int]]): self.quality = quality def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item() + g = torch.thread_safe_generator() + quality = torch.randint(self.quality[0], self.quality[1] + 1, (), generator=g).item() return dict(quality=quality) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 52707af1f2e..714a18fc823 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -38,9 +38,11 @@ def _extract_params_for_v1_transform(self) -> dict[str, Any]: return params - def _get_random_item(self, dct: dict[str, tuple[Callable, bool]]) -> tuple[str, tuple[Callable, bool]]: + def _get_random_item( + self, dct: dict[str, tuple[Callable, bool]], generator: torch.Generator = None + ) -> tuple[str, tuple[Callable, bool]]: keys = tuple(dct.keys()) - key = keys[int(torch.randint(len(keys), ()))] + key = keys[int(torch.randint(len(keys), (), generator=generator))] return key, dct[key] def _flatten_and_extract_image_or_video( @@ -327,10 +329,11 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) # type: ignore[arg-type] - policy = self._policies[int(torch.randint(len(self._policies), ()))] + g = torch.thread_safe_generator() + policy = self._policies[int(torch.randint(len(self._policies), (), generator=g))] for transform_id, probability, magnitude_idx in policy: - if not torch.rand(()) <= probability: + if not torch.rand((), generator=g) <= probability: continue magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] @@ -338,7 +341,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes = magnitudes_fn(10, height, width) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) - if signed and torch.rand(()) <= 0.5: + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -419,12 +422,13 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) # type: ignore[arg-type] + g = torch.thread_safe_generator() for _ in range(self.num_ops): - transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE, generator=g) magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: magnitude = float(magnitudes[self.magnitude]) - if signed and torch.rand(()) <= 0.5: + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -488,12 +492,13 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) # type: ignore[arg-type] - transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + g = torch.thread_safe_generator() + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE, generator=g) magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: - magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) - if signed and torch.rand(()) <= 0.5: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, (), generator=g))]) + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -572,9 +577,9 @@ def __init__( self.alpha = alpha self.all_ops = all_ops - def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + def _sample_dirichlet(self, params: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor: # Must be on a separate method so that we can overwrite it in tests. - return torch._sample_dirichlet(params) + return torch._sample_dirichlet(params, generator) def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs) @@ -595,26 +600,33 @@ def forward(self, *inputs: Any) -> Any: # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of # augmented image or video. + g = torch.thread_safe_generator() m = self._sample_dirichlet( - torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1), + generator=g, ) # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos. combined_weights = self._sample_dirichlet( - torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1), + generator=g, ) * m[:, 1].reshape([batch_dims[0], -1]) mix = m[:, 0].reshape(batch_dims) * batch for i in range(self.mixture_width): aug = batch - depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + depth = ( + self.chain_depth + if self.chain_depth > 0 + else int(torch.randint(low=1, high=4, size=(1,), generator=g).item()) + ) for _ in range(depth): - transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) + transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space, generator=g) magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width) if magnitudes is not None: - magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) - if signed and torch.rand(()) <= 0.5: + magnitude = float(magnitudes[int(torch.randint(self.severity, (), generator=g))]) + if signed and torch.rand((), generator=g) <= 0.5: magnitude *= -1 else: magnitude = 0.0 diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index bf4ae55d232..1948dd9e47f 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -140,16 +140,17 @@ def _check_input( return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) @staticmethod - def _generate_value(left: float, right: float) -> float: - return torch.empty(1).uniform_(left, right).item() + def _generate_value(left: float, right: float, generator: torch.Generator = None) -> float: + return torch.empty(1).uniform_(left, right, generator=generator).item() def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - fn_idx = torch.randperm(4) + g = torch.thread_safe_generator() + fn_idx = torch.randperm(4, generator=g) - b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) - c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) - s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) - h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1], g) + c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1], g) + s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1], g) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1], g) return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) @@ -176,7 +177,8 @@ class RandomChannelPermutation(Transform): def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) - return dict(permutation=torch.randperm(num_channels)) + g = torch.thread_safe_generator() + return dict(permutation=torch.randperm(num_channels, generator=g)) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.permute_channels, inpt, params["permutation"]) @@ -223,8 +225,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) + g = torch.thread_safe_generator() params: dict[str, Any] = { - key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None + key: ColorJitter._generate_value(range[0], range[1], g) if torch.rand(1, generator=g) < self.p else None for key, range in [ ("brightness_factor", self.brightness), ("contrast_factor", self.contrast), @@ -232,8 +235,10 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: ("hue_factor", self.hue), ] } - params["contrast_before"] = bool(torch.rand(()) < 0.5) - params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None + params["contrast_before"] = bool(torch.rand((), generator=g) < 0.5) + params["channel_permutation"] = ( + torch.randperm(num_channels, generator=g) if torch.rand(1, generator=g) < self.p else None + ) return params def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index 95ec25a22f8..dee01f4802e 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -101,7 +101,8 @@ def _extract_params_for_v1_transform(self) -> dict[str, Any]: def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - if torch.rand(1) >= self.p: + g = torch.thread_safe_generator() + if torch.rand(1, generator=g) >= self.p: return inputs if needs_unpacking else inputs[0] for transform in self.transforms: @@ -149,7 +150,8 @@ def __init__( self.p = [prob / total for prob in p] def forward(self, *inputs: Any) -> Any: - idx = int(torch.multinomial(torch.tensor(self.p), 1)) + g = torch.thread_safe_generator() + idx = int(torch.multinomial(torch.tensor(self.p), 1, generator=g)) transform = self.transforms[idx] return transform(*inputs) @@ -173,7 +175,8 @@ def __init__(self, transforms: Sequence[Callable]) -> None: def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - for idx in torch.randperm(len(self.transforms)): + g = torch.thread_safe_generator() + for idx in torch.randperm(len(self.transforms), generator=g): transform = self.transforms[idx] outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index c88f3d9a504..eedda4c4c6a 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -6,7 +6,6 @@ import PIL.Image import torch - from torchvision import transforms as _transforms, tv_tensors from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs @@ -281,13 +280,16 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) area = height * width + g = torch.thread_safe_generator() + log_ratio = self._log_ratio for _ in range(10): - target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=g, ) ).item() @@ -295,8 +297,8 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: - i = torch.randint(0, height - h + 1, size=(1,)).item() - j = torch.randint(0, width - w + 1, size=(1,)).item() + i = torch.randint(0, height - h + 1, size=(1,), generator=g).item() + j = torch.randint(0, width - w + 1, size=(1,), generator=g).item() break else: # Fallback to central crop @@ -547,11 +549,13 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + g = torch.thread_safe_generator() + + r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) canvas_height = int(orig_h * r) - r = torch.rand(2) + r = torch.rand(2, generator=g) left = int((canvas_width - orig_w) * r[0]) top = int((canvas_height - orig_h) * r[1]) right = canvas_width - (left + orig_w) @@ -628,7 +632,8 @@ def __init__( self.center = center def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + g = torch.thread_safe_generator() + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item() return dict(angle=angle) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: @@ -728,26 +733,28 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + g = torch.thread_safe_generator() + + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item() if self.translate is not None: max_dx = float(self.translate[0] * width) max_dy = float(self.translate[1] * height) - tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) - ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item())) translate = (tx, ty) else: translate = (0, 0) if self.scale is not None: - scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() else: scale = 1.0 shear_x = shear_y = 0.0 if self.shear is not None: - shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item() if len(self.shear) == 4: - shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item() shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) @@ -885,13 +892,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: padding = [pad_left, pad_top, pad_right, pad_bottom] needs_pad = any(padding) + g = torch.thread_safe_generator() + needs_vert_crop, top = ( - (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g))) if padded_height > cropped_height else (False, 0) ) needs_horz_crop, left = ( - (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g))) if padded_width > cropped_width else (False, 0) ) @@ -970,21 +979,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: half_width = width // 2 bound_height = int(distortion_scale * half_height) + 1 bound_width = int(distortion_scale * half_width) + 1 + + g = torch.thread_safe_generator() + topleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=g)), + int(torch.randint(0, bound_height, size=(1,), generator=g)), ] topright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=g)), + int(torch.randint(0, bound_height, size=(1,), generator=g)), ] botright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=g)), + int(torch.randint(height - bound_height, height, size=(1,), generator=g)), ] botleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=g)), + int(torch.randint(height - bound_height, height, size=(1,), generator=g)), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] @@ -1065,7 +1077,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) - dx = torch.rand(1, 1, height, width) * 2 - 1 + g = torch.thread_safe_generator() + + dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1 if self.sigma[0] > 0.0: kx = int(8 * self.sigma[0] + 1) # if kernel size is even we have to make it odd @@ -1074,7 +1088,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) dx = dx * self.alpha[0] / width - dy = torch.rand(1, 1, height, width) * 2 - 1 + dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1 if self.sigma[1] > 0.0: ky = int(8 * self.sigma[1] + 1) # if kernel size is even we have to make it odd @@ -1157,16 +1171,18 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) bboxes = get_bounding_boxes(flat_inputs) + g = torch.thread_safe_generator() + while True: # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g)) min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() for _ in range(self.trials): # check the aspect ratio limitations - r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g) new_w = int(orig_w * r[0]) new_h = int(orig_h * r[1]) aspect_ratio = new_w / new_h @@ -1174,7 +1190,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: continue # check for 0 area crops - r = torch.rand(2) + r = torch.rand(2, generator=g) left = int((orig_w - new_w) * r[0]) top = int((orig_h - new_h) * r[1]) right = left + new_w @@ -1206,7 +1222,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: - if len(params) < 1: return inpt @@ -1276,7 +1291,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + g = torch.thread_safe_generator() + + scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0]) r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale new_width = int(orig_width * r) new_height = int(orig_height * r) @@ -1341,7 +1358,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + g = torch.thread_safe_generator() + + min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))] r = min_size / min(orig_height, orig_width) if self.max_size is not None: r = min(r, self.max_size / max(orig_height, orig_width)) @@ -1418,7 +1437,8 @@ def __init__( self.antialias = antialias def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - size = int(torch.randint(self.min_size, self.max_size, ())) + g = torch.thread_safe_generator() + size = int(torch.randint(self.min_size, self.max_size, (), generator=g)) return dict(size=[size]) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..66f4f8e18cf 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -207,7 +207,8 @@ def __init__( raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}") def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + g = torch.thread_safe_generator() + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1], generator=g).item() return dict(sigma=[sigma, sigma]) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..ae02e736f05 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any: self.check_inputs(flat_inputs) - if torch.rand(1) >= self.p: + g = torch.thread_safe_generator() + if torch.rand(1, generator=g) >= self.p: return inputs needs_transform_list = self._needs_transform_list(flat_inputs)