3
3
#include " decode_gif.h"
4
4
#include " decode_jpeg.h"
5
5
#include " decode_png.h"
6
+ #include " decode_webp.h"
6
7
7
8
namespace vision {
8
9
namespace image {
@@ -20,29 +21,43 @@ torch::Tensor decode_image(
20
21
data.dim () == 1 && data.numel () > 0 ,
21
22
" Expected a non empty 1-dimensional tensor" );
22
23
24
+ auto err_msg =
25
+ " Unsupported image file. Only jpeg, png and gif are currently supported." ;
26
+
23
27
auto datap = data.data_ptr <uint8_t >();
24
28
25
29
const uint8_t jpeg_signature[3 ] = {255 , 216 , 255 }; // == "\xFF\xD8\xFF"
30
+ TORCH_CHECK (data.numel () >= 3 , err_msg);
31
+ if (memcmp (jpeg_signature, datap, 3 ) == 0 ) {
32
+ return decode_jpeg (data, mode, apply_exif_orientation);
33
+ }
34
+
26
35
const uint8_t png_signature[4 ] = {137 , 80 , 78 , 71 }; // == "\211PNG"
36
+ TORCH_CHECK (data.numel () >= 4 , err_msg);
37
+ if (memcmp (png_signature, datap, 4 ) == 0 ) {
38
+ return decode_png (data, mode, apply_exif_orientation);
39
+ }
40
+
27
41
const uint8_t gif_signature_1[6 ] = {
28
42
0x47 , 0x49 , 0x46 , 0x38 , 0x39 , 0x61 }; // == "GIF89a"
29
43
const uint8_t gif_signature_2[6 ] = {
30
44
0x47 , 0x49 , 0x46 , 0x38 , 0x37 , 0x61 }; // == "GIF87a"
31
-
32
- if (memcmp (jpeg_signature, datap, 3 ) == 0 ) {
33
- return decode_jpeg (data, mode, apply_exif_orientation);
34
- } else if (memcmp (png_signature, datap, 4 ) == 0 ) {
35
- return decode_png (data, mode, apply_exif_orientation);
36
- } else if (
37
- memcmp (gif_signature_1, datap, 6 ) == 0 ||
45
+ TORCH_CHECK (data.numel () >= 6 , err_msg);
46
+ if (memcmp (gif_signature_1, datap, 6 ) == 0 ||
38
47
memcmp (gif_signature_2, datap, 6 ) == 0 ) {
39
48
return decode_gif (data);
40
- } else {
41
- TORCH_CHECK (
42
- false ,
43
- " Unsupported image file. Only jpeg, png and gif " ,
44
- " are currently supported." );
45
49
}
50
+
51
+ const uint8_t webp_signature_begin[4 ] = {0x52 , 0x49 , 0x46 , 0x46 }; // == "RIFF"
52
+ const uint8_t webp_signature_end[7 ] = {
53
+ 0x57 , 0x45 , 0x42 , 0x50 , 0x56 , 0x50 , 0x38 }; // == "WEBPVP8"
54
+ TORCH_CHECK (data.numel () >= 15 , err_msg);
55
+ if ((memcmp (webp_signature_begin, datap, 4 ) == 0 ) &&
56
+ (memcmp (webp_signature_end, datap + 8 , 7 ) == 0 )) {
57
+ return decode_webp (data);
58
+ }
59
+
60
+ TORCH_CHECK (false , err_msg);
46
61
}
47
62
48
63
} // namespace image
0 commit comments