Skip to content

Commit 1450781

Browse files
committed
extract make_* functions out of make_*_loader
1 parent ce7075c commit 1450781

File tree

1 file changed

+37
-9
lines changed

1 file changed

+37
-9
lines changed

test/common_utils.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,34 @@ def get_num_channels(color_space):
492492
return num_channels
493493

494494

495+
def make_image(
496+
spatial_size,
497+
*,
498+
color_space="RGB",
499+
batch_dims=(),
500+
dtype=torch.float32,
501+
device="cpu",
502+
constant_alpha=True,
503+
memory_format=torch.contiguous_format,
504+
):
505+
spatial_size = _parse_spatial_size(spatial_size)
506+
num_channels = get_num_channels(color_space)
507+
max_value = get_max_value(dtype)
508+
509+
data = torch.testing.make_tensor(
510+
(*batch_dims, num_channels, *spatial_size),
511+
low=0,
512+
high=max_value,
513+
dtype=dtype,
514+
device=device,
515+
memory_format=memory_format,
516+
)
517+
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
518+
data[..., -1, :, :] = max_value
519+
520+
return datapoints.Image(data)
521+
522+
495523
def make_image_loader(
496524
size="random",
497525
*,
@@ -505,20 +533,20 @@ def make_image_loader(
505533
num_channels = get_num_channels(color_space)
506534

507535
def fn(shape, dtype, device, memory_format):
508-
max_value = get_max_value(dtype)
509-
data = torch.testing.make_tensor(
510-
shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
536+
*batch_dims, _, spatial_size = shape
537+
return make_image(
538+
spatial_size,
539+
color_space=color_space,
540+
batch_dims=batch_dims,
541+
dtype=dtype,
542+
device=device,
543+
constant_alpha=constant_alpha,
544+
memory_format=memory_format,
511545
)
512-
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
513-
data[..., -1, :, :] = max_value
514-
return datapoints.Image(data)
515546

516547
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
517548

518549

519-
make_image = from_loader(make_image_loader)
520-
521-
522550
def make_image_loaders(
523551
*,
524552
sizes=DEFAULT_SPATIAL_SIZES,

0 commit comments

Comments
 (0)