5
5
import sys
6
6
from io import BytesIO
7
7
8
+ import numpy as np
8
9
import requests
9
10
from PIL import Image
10
11
13
14
14
15
class DonutProcessor :
15
16
def __init__ (self , model_path = "naver-clova-ix/donut-base-finetuned-cord-v2" ):
16
-
17
17
self .ensure_installed ("torch" )
18
18
self .ensure_installed ("transformers" )
19
19
@@ -36,13 +36,31 @@ def ensure_installed(self, package_name):
36
36
[sys .executable , "-m" , "pip" , "install" , package_name ]
37
37
)
38
38
39
- async def parse_image (self , image : Image ) -> str :
39
+ def preprocess_image (self , image : Image .Image ) -> np .ndarray :
40
+ # Convert to RGB if the image is not already in RGB mode
41
+ if image .mode != "RGB" :
42
+ image = image .convert ("RGB" )
43
+
44
+ # Convert to numpy array
45
+ image_np = np .array (image )
46
+
47
+ # Ensure the image is 3D (height, width, channels)
48
+ if image_np .ndim == 2 :
49
+ image_np = np .expand_dims (image_np , axis = - 1 )
50
+ image_np = np .repeat (image_np , 3 , axis = - 1 )
51
+
52
+ return image_np
53
+
54
+ async def parse_image (self , image : Image .Image ) -> str :
40
55
"""Process w/ DonutProcessor and VisionEncoderDecoderModel"""
56
+ # Preprocess the image
57
+ image_np = self .preprocess_image (image )
58
+
41
59
task_prompt = "<s_cord-v2>"
42
60
decoder_input_ids = self .processor .tokenizer (
43
61
task_prompt , add_special_tokens = False , return_tensors = "pt"
44
62
).input_ids
45
- pixel_values = self .processor (image , return_tensors = "pt" ).pixel_values
63
+ pixel_values = self .processor (images = image_np , return_tensors = "pt" ).pixel_values
46
64
47
65
outputs = self .model .generate (
48
66
pixel_values .to (self .device ),
@@ -71,7 +89,7 @@ def process_url(self, url: str) -> str:
71
89
image = self .downloader .download_image (url )
72
90
return self .parse_image (image )
73
91
74
- def download_image (self , url : str ) -> Image :
92
+ def download_image (self , url : str ) -> Image . Image :
75
93
"""Download an image from URL."""
76
94
response = requests .get (url )
77
95
image = Image .open (BytesIO (response .content ))
0 commit comments