-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Adds functionality to populate torch generator using torch.thread_safe_generator #9371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
273b95c
52d3353
30b165a
6277f11
a6bff65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,11 +15,9 @@ | |
| import numpy as np | ||
| import PIL.Image | ||
| import pytest | ||
|
|
||
| import torch | ||
| import torchvision.ops | ||
| import torchvision.transforms.v2 as transforms | ||
|
|
||
| from common_utils import ( | ||
| assert_equal, | ||
| cache, | ||
|
|
@@ -40,14 +38,12 @@ | |
| needs_cvcuda, | ||
| set_rng_seed, | ||
| ) | ||
|
|
||
| from torch import nn | ||
| from torch.testing import assert_close | ||
| from torch.utils._pytree import tree_flatten, tree_map | ||
| from torch.utils.data import DataLoader, default_collate | ||
| from torchvision import tv_tensors | ||
| from torchvision.ops.boxes import box_iou | ||
|
|
||
| from torchvision.transforms._functional_tensor import _max_value as get_max_value | ||
| from torchvision.transforms.functional import pil_modes_mapping, to_pil_image | ||
| from torchvision.transforms.v2 import functional as F | ||
|
|
@@ -63,7 +59,6 @@ | |
| ) | ||
| from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal | ||
|
|
||
|
|
||
| # turns all warnings into errors for this module | ||
| pytestmark = [pytest.mark.filterwarnings("error")] | ||
|
|
||
|
|
@@ -8120,3 +8115,64 @@ def test_different_sizes(self, make_input1, make_input2, query): | |
| def test_no_valid_input(self, query): | ||
| with pytest.raises(TypeError, match="No image"): | ||
| query(["blah"]) | ||
|
|
||
|
|
||
| class TestThreadSafeGenerator: | ||
| """Test that transforms correctly use torch.thread_safe_generator(). | ||
|
|
||
| For multiprocessing workers, thread_safe_generator() returns None, | ||
| so transforms use the default process global RNG, | ||
| i.e. for a multiprocessing worker the RNG of that process. | ||
| For thread workers, it returns a thread-local torch.Generator. | ||
| """ | ||
|
|
||
| TRANSFORMS = [ | ||
| transforms.RandomResizedCrop(size=(24, 24)), | ||
| transforms.RandomRotation(degrees=10), | ||
| transforms.RandomAffine(degrees=10), | ||
| transforms.RandomCrop(size=(24, 24), pad_if_needed=True), | ||
| transforms.RandomPerspective(p=1.0), | ||
| transforms.RandomErasing(p=1.0), | ||
| transforms.ScaleJitter(target_size=(24, 24)), | ||
| ] | ||
|
|
||
| class TransformDataset(torch.utils.data.Dataset): | ||
| def __init__(self, size, transform): | ||
| self.size = size | ||
| self.transform = transform | ||
| self.image = make_image((32, 32)) | ||
|
|
||
| def __getitem__(self, idx): | ||
| return self.transform(self.image) | ||
|
|
||
| def __len__(self): | ||
| return self.size | ||
|
|
||
| @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) | ||
| def test_multiprocessing_workers(self, transform): | ||
| """With multiprocessing DataLoader workers, thread_safe_generator() | ||
| returns None and transforms use the per-process global RNG. | ||
| Each worker gets a different seed, so results should differ.""" | ||
| dataset = self.TransformDataset(size=2, transform=transform) | ||
| dl = DataLoader(dataset, batch_size=1, num_workers=2) | ||
| batch0, batch1 = list(dl) | ||
| assert not torch.equal(batch0, batch1) | ||
|
|
||
| @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) | ||
| def test_thread_worker_uses_thread_local_generator(self, transform): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For this multi-threading test, is there a way to test the actual multi-threaded behavior without the mocking? I.e., ideally I'd like to test the public-facing APIs when a user requests multi-threading from the DataLoader. I'm not sure what the public entry point is, though.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree. I am using the mocks only because the threading workers feature hasn't landed yet and I want to land this PR first. Once it's in, I can update these tests. With respect to the transforms, we are not simplifying anything, though: with the mocking, each transform function is getting a different generator, as would happen if done through the data-loading workers. |
||
| """In thread workers, thread_safe_generator() returns a thread-local | ||
| Generator. Mimic two workers with differently seeded generators | ||
| and verify they produce different results.""" | ||
| image = make_image((32, 32)) | ||
|
|
||
| g0 = torch.Generator() | ||
| g0.manual_seed(0) | ||
| with mock.patch("torch.thread_safe_generator", return_value=g0): | ||
| result_worker0 = transform(image) | ||
|
|
||
| g1 = torch.Generator() | ||
| g1.manual_seed(1) | ||
| with mock.patch("torch.thread_safe_generator", return_value=g1): | ||
| result_worker1 = transform(image) | ||
|
|
||
| assert not torch.equal(result_worker0, result_worker1) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have a few more random transforms in TV that we'll also want to update and test. I think the list you'll find in https://github.com/pytorch/vision/pull/7848/changes should have the proper coverage (but
claude should be able to find all the relevant ones). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was missing a lot of random transforms. Updated the PR to cover the transforms in these files in
torchvision/transforms/v2:_augment.py
_auto_augment.py
_color.py
_container.py
_geometry.py
_misc.py
_transform.py
This matches the files touched in your PR above.
RandomHorizontalFlip and RandomVerticalFlip were noisy for the test I want, i.e. that two batches are different for differently seeded workers. This is because the flipped outputs can be the same for different seeds. So I added another reproducibility test which covers the flips and the other transforms. It checks whether two transform outputs are the same for the same torch.thread_safe_generator value. This should add extra coverage.