-
Notifications
You must be signed in to change notification settings - Fork 7.1k
remove private imports from torch.testing #7525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -21,8 +21,8 @@ | |||||||||||||||||
import torch | ||||||||||||||||||
import torch.testing | ||||||||||||||||||
from PIL import Image | ||||||||||||||||||
from torch.utils._pytree import tree_flatten | ||||||||||||||||||
|
||||||||||||||||||
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair | ||||||||||||||||||
from torchvision import datapoints, io | ||||||||||||||||||
from torchvision.transforms._functional_tensor import _max_value as get_max_value | ||||||||||||||||||
from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor | ||||||||||||||||||
|
@@ -270,84 +270,57 @@ def combinations_grid(**kwargs): | |||||||||||||||||
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class ImagePair(TensorLikePair): | ||||||||||||||||||
def __init__( | ||||||||||||||||||
self, | ||||||||||||||||||
actual, | ||||||||||||||||||
expected, | ||||||||||||||||||
*, | ||||||||||||||||||
mae=False, | ||||||||||||||||||
**other_parameters, | ||||||||||||||||||
): | ||||||||||||||||||
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): | ||||||||||||||||||
actual, expected = [to_image_tensor(input) for input in [actual, expected]] | ||||||||||||||||||
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) | ||||||||||||||||||
|
||||||||||||||||||
super().__init__(actual, expected, **other_parameters) | ||||||||||||||||||
self.mae = mae | ||||||||||||||||||
|
||||||||||||||||||
def compare(self) -> None: | ||||||||||||||||||
actual, expected = self.actual, self.expected | ||||||||||||||||||
def assert_close_with_image_support(actual, expected, *, mae=False, atol=None, msg=None, **kwargs): | ||||||||||||||||||
def compare(actual, expected): | ||||||||||||||||||
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): | ||||||||||||||||||
actual, expected = [to_image_tensor(input) for input in [actual, expected]] | ||||||||||||||||||
|
||||||||||||||||||
self._compare_attributes(actual, expected) | ||||||||||||||||||
actual, expected = self._equalize_attributes(actual, expected) | ||||||||||||||||||
enable_mae_comparison = all(isinstance(input, torch.Tensor) for input in [actual, expected]) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be
Suggested change
? Regardless I don't really understand the reason for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wrote this as a general function that one can use like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But are we ever going to use it when the inputs aren't images (either PIL or Tensors)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we have such usages:
|
||||||||||||||||||
|
||||||||||||||||||
value_comparison_failure = False | ||||||||||||||||||
|
||||||||||||||||||
def msg_callback(default_msg): | ||||||||||||||||||
# This is a dirty hack that let's us "hook" into the comparison logic of `torch.testing.assert_close`. It | ||||||||||||||||||
# will only be triggered in case of failed value comparison. This let's us reuse all of the attribute | ||||||||||||||||||
# checking that `torch.testing.assert_close` does while giving us the ability to ignore the regular value | ||||||||||||||||||
# comparison. | ||||||||||||||||||
Comment on lines
+286
to
+289
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per comment. Note while this is a hack, it doesn't depend on any non-public functionality or assumptions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT
Suggested change
|
||||||||||||||||||
nonlocal value_comparison_failure | ||||||||||||||||||
value_comparison_failure = True | ||||||||||||||||||
Comment on lines
+290
to
+291
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC this is the "hack" part and the rest below is just the same as the implementation for the default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Basically this is our way of detecting that there was value comparison failure since this callback is only called in this specific case. |
||||||||||||||||||
if msg is None: | ||||||||||||||||||
return default_msg | ||||||||||||||||||
elif isinstance(msg, str): | ||||||||||||||||||
return msg | ||||||||||||||||||
elif callable(msg): | ||||||||||||||||||
return msg(default_msg) | ||||||||||||||||||
else: | ||||||||||||||||||
raise pytest.UsageError(f"`msg` can be either be `None`, a `str` or a callable, but got {msg}.") | ||||||||||||||||||
|
||||||||||||||||||
if self.mae: | ||||||||||||||||||
if actual.dtype is torch.uint8: | ||||||||||||||||||
actual, expected = actual.to(torch.int), expected.to(torch.int) | ||||||||||||||||||
mae = float(torch.abs(actual - expected).float().mean()) | ||||||||||||||||||
if mae > self.atol: | ||||||||||||||||||
self._fail( | ||||||||||||||||||
AssertionError, | ||||||||||||||||||
f"The MAE of the images is {mae}, but only {self.atol} is allowed.", | ||||||||||||||||||
) | ||||||||||||||||||
else: | ||||||||||||||||||
super()._compare_values(actual, expected) | ||||||||||||||||||
try: | ||||||||||||||||||
return torch.testing.assert_close( | ||||||||||||||||||
actual, expected, atol=atol, msg=msg_callback if enable_mae_comparison else msg, **kwargs | ||||||||||||||||||
) | ||||||||||||||||||
except AssertionError: | ||||||||||||||||||
if not (value_comparison_failure and mae): | ||||||||||||||||||
raise | ||||||||||||||||||
Comment on lines
+306
to
+307
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be misunderstanding, but why do we even bother There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||
|
||||||||||||||||||
if actual.dtype is torch.uint8: | ||||||||||||||||||
actual, expected = actual.to(torch.int), expected.to(torch.int) | ||||||||||||||||||
mae_value = float(torch.abs(actual - expected).float().mean()) | ||||||||||||||||||
Comment on lines
+309
to
+311
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Hopefully correct?) NIT
Suggested change
|
||||||||||||||||||
if mae_value > atol: | ||||||||||||||||||
raise AssertionError(f"The MAE of the images is {mae_value}, but only {atol} is allowed.") | ||||||||||||||||||
|
||||||||||||||||||
def assert_close( | ||||||||||||||||||
actual, | ||||||||||||||||||
expected, | ||||||||||||||||||
*, | ||||||||||||||||||
allow_subclasses=True, | ||||||||||||||||||
rtol=None, | ||||||||||||||||||
atol=None, | ||||||||||||||||||
equal_nan=False, | ||||||||||||||||||
check_device=True, | ||||||||||||||||||
check_dtype=True, | ||||||||||||||||||
check_layout=True, | ||||||||||||||||||
check_stride=False, | ||||||||||||||||||
msg=None, | ||||||||||||||||||
**kwargs, | ||||||||||||||||||
): | ||||||||||||||||||
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison""" | ||||||||||||||||||
__tracebackhide__ = True | ||||||||||||||||||
|
||||||||||||||||||
error_metas = not_close_error_metas( | ||||||||||||||||||
actual, | ||||||||||||||||||
expected, | ||||||||||||||||||
pair_types=( | ||||||||||||||||||
NonePair, | ||||||||||||||||||
BooleanPair, | ||||||||||||||||||
NumberPair, | ||||||||||||||||||
ImagePair, | ||||||||||||||||||
TensorLikePair, | ||||||||||||||||||
), | ||||||||||||||||||
allow_subclasses=allow_subclasses, | ||||||||||||||||||
rtol=rtol, | ||||||||||||||||||
atol=atol, | ||||||||||||||||||
equal_nan=equal_nan, | ||||||||||||||||||
check_device=check_device, | ||||||||||||||||||
check_dtype=check_dtype, | ||||||||||||||||||
check_layout=check_layout, | ||||||||||||||||||
check_stride=check_stride, | ||||||||||||||||||
**kwargs, | ||||||||||||||||||
) | ||||||||||||||||||
actual_flat, actual_spec = tree_flatten(actual) | ||||||||||||||||||
expected_flat, expected_spec = tree_flatten(expected) | ||||||||||||||||||
assert actual_spec, expected_spec | ||||||||||||||||||
|
||||||||||||||||||
if error_metas: | ||||||||||||||||||
raise error_metas[0].to_error(msg) | ||||||||||||||||||
for actual_item, expected_item in zip(actual_flat, expected_flat): | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only "regression" compared to the patched We can re-implement this behavior if needed, but not sure its worth it. I guesstimate that 99% of our comparisons are single elements and thus no extra traceback is needed. Apart from this, |
||||||||||||||||||
compare(actual_item, expected_item) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
assert_equal = functools.partial(assert_close, rtol=0, atol=0) | ||||||||||||||||||
assert_equal_with_image_support = functools.partial(assert_close_with_image_support, rtol=0, atol=0) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def parametrized_error_message(*args, **kwargs): | ||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not convert all inputs to tensors, regardless of whether they're all PIL Images? Does that mean we can't do
assert_close_with_image_support(some_tensor, some_pil_img)
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #7525 (comment) for one reason. Plus, we originally had the option to pass in one tensor and one PIL image and let the function handle the conversion to tensor. However, we often saw extra flakiness in the tests just to this conversion and thus we scraped that feature. Since this irregular anyway, the few tests that need this, perform the conversion themselves now. E.g.
vision/test/transforms_v2_kernel_infos.py
Line 117 in 5579995