@@ -492,6 +492,34 @@ def get_num_channels(color_space):
492
492
return num_channels
493
493
494
494
495
+ def make_image (
496
+ spatial_size ,
497
+ * ,
498
+ color_space = "RGB" ,
499
+ batch_dims = (),
500
+ dtype = torch .float32 ,
501
+ device = "cpu" ,
502
+ constant_alpha = True ,
503
+ memory_format = torch .contiguous_format ,
504
+ ):
505
+ spatial_size = _parse_spatial_size (spatial_size )
506
+ num_channels = get_num_channels (color_space )
507
+ max_value = get_max_value (dtype )
508
+
509
+ data = torch .testing .make_tensor (
510
+ (* batch_dims , num_channels , * spatial_size ),
511
+ low = 0 ,
512
+ high = max_value ,
513
+ dtype = dtype ,
514
+ device = device ,
515
+ memory_format = memory_format ,
516
+ )
517
+ if color_space in {"GRAY_ALPHA" , "RGBA" } and constant_alpha :
518
+ data [..., - 1 , :, :] = max_value
519
+
520
+ return datapoints .Image (data )
521
+
522
+
495
523
def make_image_loader (
496
524
size = "random" ,
497
525
* ,
@@ -505,20 +533,20 @@ def make_image_loader(
505
533
num_channels = get_num_channels (color_space )
506
534
507
535
def fn (shape , dtype , device , memory_format ):
508
- max_value = get_max_value (dtype )
509
- data = torch .testing .make_tensor (
510
- shape , low = 0 , high = max_value , dtype = dtype , device = device , memory_format = memory_format
536
+ * batch_dims , _ , spatial_size = shape
537
+ return make_image (
538
+ spatial_size ,
539
+ color_space = color_space ,
540
+ batch_dims = batch_dims ,
541
+ dtype = dtype ,
542
+ device = device ,
543
+ constant_alpha = constant_alpha ,
544
+ memory_format = memory_format ,
511
545
)
512
- if color_space in {"GRAY_ALPHA" , "RGBA" } and constant_alpha :
513
- data [..., - 1 , :, :] = max_value
514
- return datapoints .Image (data )
515
546
516
547
return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype , memory_format = memory_format )
517
548
518
549
519
- make_image = from_loader (make_image_loader )
520
-
521
-
522
550
def make_image_loaders (
523
551
* ,
524
552
sizes = DEFAULT_SPATIAL_SIZES ,
0 commit comments