remove private imports from torch.testing #7525

Draft · wants to merge 3 commits into base: main
113 changes: 43 additions & 70 deletions test/common_utils.py
@@ -21,8 +21,8 @@
import torch
import torch.testing
from PIL import Image
from torch.utils._pytree import tree_flatten

from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor
@@ -270,84 +270,57 @@ def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


class ImagePair(TensorLikePair):
def __init__(
self,
actual,
expected,
*,
mae=False,
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = [to_image_tensor(input) for input in [actual, expected]]
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)

super().__init__(actual, expected, **other_parameters)
self.mae = mae

def compare(self) -> None:
actual, expected = self.actual, self.expected
def assert_close_with_image_support(actual, expected, *, mae=False, atol=None, msg=None, **kwargs):
def compare(actual, expected):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
Member:
Why not convert all inputs to tensors, regardless of whether they're all PIL Images? Does that mean we can't do assert_close_with_image_support(some_tensor, some_pil_img)?

Collaborator Author:
See #7525 (comment) for one reason. Plus, we originally had the option to pass in one tensor and one PIL image and let the function handle the conversion to tensor. However, we often saw extra flakiness in the tests due to this conversion, so we scrapped that feature. Since this is an irregular case anyway, the few tests that need it perform the conversion themselves now. E.g.

def pil_reference_wrapper(pil_kernel):
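
For illustration, a hypothetical test following that pattern would do the conversion itself before calling the helper (standalone sketch; the concrete values are made up, only to_image_tensor and assert_close_with_image_support come from this PR):

import PIL.Image
import torch

from common_utils import assert_close_with_image_support  # helper added in this diff
from torchvision.transforms.v2.functional import to_image_tensor

# made-up reference produced by a PIL code path
expected_pil = PIL.Image.new("RGB", (32, 32), color=(10, 20, 30))
expected = to_image_tensor(expected_pil)  # the test converts the PIL side explicitly
actual = expected.clone()                 # stand-in for the tensor kernel output

# only tensors reach the helper, so no implicit PIL -> tensor conversion is needed
assert_close_with_image_support(actual, expected, mae=True, rtol=0, atol=1)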

actual, expected = [to_image_tensor(input) for input in [actual, expected]]

self._compare_attributes(actual, expected)
actual, expected = self._equalize_attributes(actual, expected)
enable_mae_comparison = all(isinstance(input, torch.Tensor) for input in [actual, expected])
Member:
Should this be

Suggested change
enable_mae_comparison = all(isinstance(input, torch.Tensor) for input in [actual, expected])
enable_mae_comparison = all(isinstance(input, torch.Tensor) for input in [actual, expected]) and mae

? Regardless, I don't really understand the reason for enable_mae_comparison. Why can't we just convert all inputs to tensors and apply the MAE comparison iff mae = True?

Collaborator Author:
I wrote this as a general function that one can use like assert_close, but with_image_support. Meaning, you can still do assert_close_with_image_support(None, 5) and get a proper error message. If we force everything into a tensor, that would not be possible.

Member:
But are we ever going to use it when the inputs aren't images (either PIL or Tensors)?

Collaborator Author:
Yes, we have such usages:

  • In the functional v2 tests, we are using assert_close (or now assert_close_with_image_support) for all kernels and dispatchers, which of course also means bounding boxes, masks, and videos.
  • In the consistency tests for detection and segmentation, we test full samples against each other, so the inputs are more than just images. This is also somewhat the case in our (fairly limited) transforms tests, and will probably increase there when we ramp them up.


value_comparison_failure = False

def msg_callback(default_msg):
# This is a dirty hack that let's us "hook" into the comparison logic of `torch.testing.assert_close`. It
# will only be triggered in case of failed value comparison. This let's us reuse all of the attribute
# checking that `torch.testing.assert_close` does while giving us the ability to ignore the regular value
# comparison.
Comment on lines +286 to +289
Collaborator Author:
Per the comment. Note that while this is a hack, it doesn't depend on any non-public functionality or assumptions.

Member:
NIT

Suggested change
# This is a dirty hack that let's us "hook" into the comparison logic of `torch.testing.assert_close`. It
# will only be triggered in case of failed value comparison. This let's us reuse all of the attribute
# checking that `torch.testing.assert_close` does while giving us the ability to ignore the regular value
# comparison.
# This is a dirty hack that lets us "hook" into the comparison logic of `torch.testing.assert_close`. It
# will only be triggered in case of failed value comparison. This lets us reuse all of the attribute
# checking that `torch.testing.assert_close` does while giving us the ability to ignore the regular value
# comparison.

nonlocal value_comparison_failure
value_comparison_failure = True
Comment on lines +290 to +291
Member:
IIUC this is the "hack" part and the rest below is just the same as the implementation for the default msg_callback?

Collaborator Author:
Yes. Basically this is our way of detecting that there was a value comparison failure, since this callback is only called in that specific case.
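
A minimal standalone sketch of that hook pattern (hypothetical helper name; the assumption that the callable only fires for a failed value comparison is the one stated above):

import torch

def value_comparison_failed(actual, expected, **kwargs):
    failed = False

    def msg_callback(default_msg):
        # record that assert_close asked us to build an error message ...
        nonlocal failed
        failed = True
        return default_msg  # ... and keep the default text unchanged

    try:
        torch.testing.assert_close(actual, expected, msg=msg_callback, **kwargs)
        return False
    except AssertionError:
        return failed

print(value_comparison_failed(torch.tensor([1.0]), torch.tensor([2.0])))  # True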

if msg is None:
return default_msg
elif isinstance(msg, str):
return msg
elif callable(msg):
return msg(default_msg)
else:
raise pytest.UsageError(f"`msg` can be either be `None`, a `str` or a callable, but got {msg}.")

if self.mae:
if actual.dtype is torch.uint8:
actual, expected = actual.to(torch.int), expected.to(torch.int)
mae = float(torch.abs(actual - expected).float().mean())
if mae > self.atol:
self._fail(
AssertionError,
f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
)
else:
super()._compare_values(actual, expected)
try:
return torch.testing.assert_close(
actual, expected, atol=atol, msg=msg_callback if enable_mae_comparison else msg, **kwargs
)
except AssertionError:
if not (value_comparison_failure and mae):
raise
Comment on lines +306 to +307
Member:
I might be misunderstanding, but why do we even bother trying torch.testing.assert_close when mae is True?

Collaborator Author:
assert_close does a lot more checks than just the value comparison. For us the most important ones are the shape, dtype, and device checks. If we don't invoke assert_close, we have to do these checks ourselves or live with the fact that check_dtype=True is ignored when mae=True is set.
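
For illustration (standalone sketch, values made up): even when the values match exactly, torch.testing.assert_close still enforces the attribute checks mentioned above:

import torch

actual = torch.zeros(3, dtype=torch.uint8)
expected = torch.zeros(3, dtype=torch.float32)

try:
    # identical values, but check_dtype=True (the default) still fails
    torch.testing.assert_close(actual, expected)
except AssertionError as exc:
    print(exc)  # reports the dtype mismatch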


if actual.dtype is torch.uint8:
actual, expected = actual.to(torch.int), expected.to(torch.int)
mae_value = float(torch.abs(actual - expected).float().mean())
Comment on lines +309 to +311
Member:
(Hopefully correct?) NIT

Suggested change
if actual.dtype is torch.uint8:
actual, expected = actual.to(torch.int), expected.to(torch.int)
mae_value = float(torch.abs(actual - expected).float().mean())
mae_value = torch.abs(actual.float() - expected.float()).mean()
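
Either version works because it promotes before subtracting; a quick standalone illustration (made-up values) of why subtracting uint8 tensors directly would be wrong:

import torch

a = torch.tensor([0, 10], dtype=torch.uint8)
b = torch.tensor([1, 200], dtype=torch.uint8)

print(a - b)                                    # tensor([255,  66], dtype=torch.uint8) -- wraps around
print(torch.abs(a.int() - b.int()))             # tensor([  1, 190], dtype=torch.int32)
print(torch.abs(a.float() - b.float()).mean())  # tensor(95.5000) -- the intended MAE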

if mae_value > atol:
raise AssertionError(f"The MAE of the images is {mae_value}, but only {atol} is allowed.")

def assert_close(
actual,
expected,
*,
allow_subclasses=True,
rtol=None,
atol=None,
equal_nan=False,
check_device=True,
check_dtype=True,
check_layout=True,
check_stride=False,
msg=None,
**kwargs,
):
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
__tracebackhide__ = True

error_metas = not_close_error_metas(
actual,
expected,
pair_types=(
NonePair,
BooleanPair,
NumberPair,
ImagePair,
TensorLikePair,
),
allow_subclasses=allow_subclasses,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
check_device=check_device,
check_dtype=check_dtype,
check_layout=check_layout,
check_stride=check_stride,
**kwargs,
)
actual_flat, actual_spec = tree_flatten(actual)
expected_flat, expected_spec = tree_flatten(expected)
assert actual_spec == expected_spec

if error_metas:
raise error_metas[0].to_error(msg)
for actual_item, expected_item in zip(actual_flat, expected_flat):
Collaborator Author:
The only "regression" compared to the patched assert_close is that we don't get any information where inside a sample an error happened. Consider the input sample to a detection model, i.e. a two-tuple with an image and a target dictionary. If the comparison of the bounding boxes target fails, assert_close would include [1]["boxes"] in the error message, whereas now we will only see the regular message and thus have to figure out ourselves what part of the input caused this.

We can re-implement this behavior if needed, but not sure its worth it. I guesstimate that 99% of our comparisons are single elements and thus no extra traceback is needed.

Apart from this, assert_close_with_image_support should be a faithful reproduction of the patched assert_close we had before.
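
For illustration (standalone sketch, made-up sample): tree_flatten only yields the leaves, so the path information lives solely in the spec:

import torch
from torch.utils._pytree import tree_flatten

sample = (torch.zeros(3, 4, 4), {"boxes": torch.zeros(1, 4), "labels": torch.zeros(1)})
leaves, spec = tree_flatten(sample)
# leaves is just [image, boxes, labels]; the [1]["boxes"] path is only encoded in `spec`,
# which is why a failing leaf comparison can no longer say where in the sample it occurred.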

compare(actual_item, expected_item)


assert_equal = functools.partial(assert_close, rtol=0, atol=0)
assert_equal_with_image_support = functools.partial(assert_close_with_image_support, rtol=0, atol=0)


def parametrized_error_message(*args, **kwargs):
3 changes: 2 additions & 1 deletion test/test_prototype_transforms.py
@@ -8,6 +8,7 @@

from common_utils import (
assert_equal,
assert_equal_with_image_support,
DEFAULT_EXTRA_DIMS,
make_bounding_box,
make_detection_mask,
@@ -532,4 +533,4 @@ def make_datapoints():
torch.manual_seed(12)
expected_output = t_ref(*dp)

assert_equal(expected_output, output)
assert_equal_with_image_support(expected_output, output)
3 changes: 2 additions & 1 deletion test/test_transforms_v2.py
@@ -15,6 +15,7 @@

from common_utils import (
assert_equal,
assert_equal_with_image_support,
assert_run_python_script,
cpu_and_gpu,
make_bounding_box,
@@ -383,7 +384,7 @@ def was_applied(output, inpt):
return False

# Make sure nothing fishy is going on
assert_equal(output, inpt)
assert_equal_with_image_support(output, inpt)
return True

first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
23 changes: 12 additions & 11 deletions test/test_transforms_v2_consistency.py
@@ -15,8 +15,9 @@
import torchvision.transforms.v2 as v2_transforms
from common_utils import (
ArgsKwargs,
assert_close,
assert_close_with_image_support,
assert_equal,
assert_equal_with_image_support,
make_bounding_box,
make_detection_mask,
make_image,
@@ -580,7 +581,7 @@ def check_call_consistency(
f"`is_simple_tensor` path in `_transform`."
) from exc

assert_close(
assert_close_with_image_support(
output_prototype_tensor,
output_legacy_tensor,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
@@ -597,7 +598,7 @@ def check_call_consistency(
f"`datapoints.Image` path in `_transform`."
) from exc

assert_close(
assert_close_with_image_support(
output_prototype_image,
output_prototype_tensor,
msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
@@ -627,7 +628,7 @@ def check_call_consistency(
f"`PIL.Image.Image` path in `_transform`."
) from exc

assert_close(
assert_close_with_image_support(
output_prototype_pil,
output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
@@ -757,7 +758,7 @@ def test_jit_consistency(config, args_kwargs):
torch.manual_seed(0)
output_prototype_scripted = prototype_transform_scripted(image)

assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
torch.testing.assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


class TestContainerTransforms:
@@ -899,7 +900,7 @@ def test_randaug(self, inpt, interpolation, mocker):
expected_output = t_ref(inpt)
output = t(inpt)

assert_close(expected_output, output, atol=1, rtol=0.1)
assert_close_with_image_support(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"inpt",
@@ -951,7 +952,7 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
expected_output = t_ref(inpt)
output = t(inpt)

assert_close(expected_output, output, atol=1, rtol=0.1)
assert_close_with_image_support(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"inpt",
@@ -1004,7 +1005,7 @@ def test_augmix(self, inpt, interpolation, mocker):
expected_output = t_ref(inpt)
output = t(inpt)

assert_equal(expected_output, output)
assert_equal_with_image_support(expected_output, output)

@pytest.mark.parametrize(
"inpt",
@@ -1033,7 +1034,7 @@ def test_aa(self, inpt, interpolation):
torch.manual_seed(12)
output = t(inpt)

assert_equal(expected_output, output)
assert_equal_with_image_support(expected_output, output)


def import_transforms_from_references(reference):
@@ -1127,7 +1128,7 @@ def test_transform(self, t_ref, t, data_kwargs):
torch.manual_seed(12)
expected_output = t_ref(*dp)

assert_equal(expected_output, output)
assert_equal_with_image_support(expected_output, output)


seg_transforms = import_transforms_from_references("segmentation")
@@ -1196,7 +1197,7 @@ def check(self, t, t_ref, data_kwargs=None):
expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
expected = (expected_image, expected_mask)

assert_equal(actual, expected)
assert_equal_with_image_support(actual, expected)

@pytest.mark.parametrize(
("t_ref", "t", "data_kwargs"),
12 changes: 6 additions & 6 deletions test/test_transforms_v2_functional.py
@@ -12,7 +12,7 @@
import torch

from common_utils import (
assert_close,
assert_close_with_image_support,
cache,
cpu_and_gpu,
DEFAULT_SQUARE_SPATIAL_SIZE,
@@ -142,7 +142,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
actual = kernel_scripted(input, *other_args, **kwargs)
expected = kernel_eager(input, *other_args, **kwargs)

assert_close(
assert_close_with_image_support(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
@@ -200,7 +200,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
single_inputs = self._unbatch(batched_input, data_dims=data_dims)
expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

assert_close(
assert_close_with_image_support(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
@@ -231,7 +231,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

assert_close(
assert_close_with_image_support(
output_cuda,
output_cpu,
check_device=False,
@@ -262,7 +262,7 @@ def test_against_reference(self, test_id, info, args_kwargs):
# metadata regardless of whether the kernel takes it explicitly or not
expected = info.reference_fn(input, *other_args, **kwargs)

assert_close(
assert_close_with_image_support(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
@@ -290,7 +290,7 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs):

expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)

assert_close(
assert_close_with_image_support(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),