Adds functionality to populate torch generator using torch.thread_safe_generator #9371
divyanshk wants to merge 5 commits into pytorch:main
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9371
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Pending as of commit a6bff65 with merge base 48956e0.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 1e0225e to b15da1c
NicolasHug left a comment
Thanks for the PR @divyanshk . I think the changes look reasonable.
One thing I'm wondering is how does this affect the multiprocess-based dataloaders? Currently, since TV is using the global torch RNG, that global generator will be seeded by torch using a different seed for each process/worker. This is the correct behavior since we want each worker to have a different RNG.
Is that behavior preserved now that we're using torch.thread_safe_generator()?
It'd be good to have tests ensuring that's the case (both for the multiprocess and multithreaded cases).
Force-pushed b15da1c to 18bdef3
The multiprocessing case remains unchanged because torch.thread_safe_generator returns None in the multiprocessing use case, so for MP there is no change: the torch.rand functions previously received None for the generator arg, and they still do. I also added a test case confirming the expected behavior for multiprocessing.
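The fallback described above can be sketched as follows. Note this is a hypothetical stand-in written for illustration, not the real torch API from the companion PR: a thread-based worker would store its generator in thread-local storage, while everywhere else the lookup returns None, so torch.rand falls back to the global RNG, which is exactly the unchanged multiprocessing path.

```python
import threading

import torch

# Hypothetical stand-in for the lookup pattern described above; the real
# torch.thread_safe_generator only exists with the companion PyTorch PR.
_worker_state = threading.local()


def thread_safe_generator():
    # Returns the worker's generator, or None outside a thread worker,
    # in which case torch.rand falls back to the global generator.
    return getattr(_worker_state, "generator", None)


def random_draw():
    # Transforms pass the (possibly None) generator straight through.
    return torch.rand(1, generator=thread_safe_generator())
```

In the main process (or a multiprocessing worker) the lookup yields None, so behavior is identical to calling `torch.rand(1)` directly.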
Force-pushed c416fad to e7da958
Force-pushed e7da958 to 6277f11
transforms.RandomPerspective(p=1.0),
transforms.RandomErasing(p=1.0),
transforms.ScaleJitter(target_size=(24, 24)),
]
We have a few more random transforms in TV that we'll also want to update and test. I think the list you'll find in https://github.com/pytorch/vision/pull/7848/changes should have the proper coverage (but claude should be able to find all the relevant ones)
I was missing a lot of random transforms. Updated the PR to cover the transforms in these files in torchvision/transforms/v2:
_augment.py
_auto_augment.py
_color.py
_container.py
_geometry.py
_misc.py
_transform.py
This matches the files touched in your PR above.
RandomHorizontalFlip and RandomVerticalFlip were noisy for the test I want (i.e. that two batches are different for differently seeded workers), because the flipped outputs can be the same for different seeds. So I added another reproducibility test which covers the flips and the other transforms: it checks that two transform outputs are identical for the same torch.thread_safe_generator value. This should add extra coverage.
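The reproducibility check described above can be sketched like this. It uses plain torch ops as a stand-in for a torchvision transform, since the `generator=` plumbing for the transforms is exactly what this PR adds; the flip stand-in is an assumption for illustration.

```python
import torch


def apply_random(img, generator):
    # Stand-in for a random transform: flip the image when the draw < 0.5.
    # A real v2 transform would consume the generator the same way.
    if torch.rand(1, generator=generator).item() < 0.5:
        return torch.flip(img, dims=[-1])
    return img


img = torch.arange(12.0).reshape(1, 3, 4)
g0 = torch.Generator().manual_seed(0)
g1 = torch.Generator().manual_seed(0)
# Same seed => same random decision => identical outputs, even when the
# individual transform (like a flip) could coincide across different seeds.
out0 = apply_random(img, g0)
out1 = apply_random(img, g1)
```

This is why a same-seed equality check is a more robust signal for flips than a different-seed inequality check.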
assert not torch.equal(batch0, batch1)

@pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
def test_thread_worker_uses_thread_local_generator(self, transform):
For this multi-threading test, is there a way to test the actual multi-threaded behavior, without the mocking? I.e. ideally I'd like to test the public-facing APIs when a user requests multi-threaded workers from the DataLoader. I'm not sure what the public entry point is, though?
I agree. I am using the mocks only because the thread-based workers haven't landed yet, and I want to land this PR first. Once they're in, I can update these tests. With respect to the transforms, we are not simplifying anything, though: with the mocking, each transform function gets a different generator, just as it would through the dataloading workers.
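The mocking approach described here can be sketched as below. Since torch.thread_safe_generator only exists with the companion PyTorch PR, the patch uses `create=True`; each simulated "worker" sees its own seeded generator, and differently seeded workers should draw different batches.

```python
from unittest import mock

import torch


def draw_with(worker_seed):
    # Each simulated worker gets its own seeded generator, mimicking what
    # the thread-based DataLoader workers would do once they land.
    gen = torch.Generator().manual_seed(worker_seed)
    # create=True: torch.thread_safe_generator is assumed here and does not
    # exist in released torch, so the patch must create the attribute.
    with mock.patch("torch.thread_safe_generator", return_value=gen, create=True):
        return torch.rand(8, generator=torch.thread_safe_generator())


batch0 = draw_with(0)
batch1 = draw_with(1)
```

Differently seeded generators produce different draws, which is the property the multi-threading divergence test asserts.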
Added thread-safe random number generation to all V2 torchvision random transforms to prevent race conditions when using DataLoader with thread-based workers (worker_method='thread').
This is based on torch.thread_safe_generator, which returns a dataloader thread-worker-specific RNG, or None otherwise.
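A defensive way for a transform to consume this, sketched under the assumption that torch.thread_safe_generator only exists with the companion PyTorch PR (the getattr guard keeps the sketch runnable on released torch, where it falls back to None and thus the global RNG):

```python
import torch


def _get_generator():
    # getattr guard: torch.thread_safe_generator only exists with the
    # companion PyTorch PR; on released torch this returns None and
    # torch.rand falls back to the global RNG, preserving old behavior.
    fn = getattr(torch, "thread_safe_generator", None)
    return fn() if fn is not None else None


# A random transform would draw its parameters like this; passing
# generator=None is equivalent to using the global generator.
p_draw = torch.rand(1, generator=_get_generator())
```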