Skip to content

Commit 57eede2

Browse files
Sid MohanSid Mohan
Sid Mohan
authored and
Sid Mohan
committed
py311 passed
1 parent 51b30ec commit 57eede2

File tree

5 files changed

+72
-14
lines changed

5 files changed

+72
-14
lines changed

datafog/processing/image_processing/donut_processor.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88
import requests
99
from PIL import Image
10+
import numpy as np
1011

1112
from .image_downloader import ImageDownloader
1213

1314

1415
class DonutProcessor:
1516
def __init__(self, model_path="naver-clova-ix/donut-base-finetuned-cord-v2"):
16-
1717
self.ensure_installed("torch")
1818
self.ensure_installed("transformers")
1919

@@ -36,13 +36,31 @@ def ensure_installed(self, package_name):
3636
[sys.executable, "-m", "pip", "install", package_name]
3737
)
3838

39-
async def parse_image(self, image: Image) -> str:
39+
def preprocess_image(self, image: Image.Image) -> np.ndarray:
40+
# Convert to RGB if the image is not already in RGB mode
41+
if image.mode != 'RGB':
42+
image = image.convert('RGB')
43+
44+
# Convert to numpy array
45+
image_np = np.array(image)
46+
47+
# Ensure the image is 3D (height, width, channels)
48+
if image_np.ndim == 2:
49+
image_np = np.expand_dims(image_np, axis=-1)
50+
image_np = np.repeat(image_np, 3, axis=-1)
51+
52+
return image_np
53+
54+
async def parse_image(self, image: Image.Image) -> str:
4055
"""Process w/ DonutProcessor and VisionEncoderDecoderModel"""
56+
# Preprocess the image
57+
image_np = self.preprocess_image(image)
58+
4159
task_prompt = "<s_cord-v2>"
4260
decoder_input_ids = self.processor.tokenizer(
4361
task_prompt, add_special_tokens=False, return_tensors="pt"
4462
).input_ids
45-
pixel_values = self.processor(image, return_tensors="pt").pixel_values
63+
pixel_values = self.processor(images=image_np, return_tensors="pt").pixel_values
4664

4765
outputs = self.model.generate(
4866
pixel_values.to(self.device),
@@ -71,8 +89,8 @@ def process_url(self, url: str) -> str:
7189
image = self.downloader.download_image(url)
7290
return self.parse_image(image)
7391

74-
def download_image(self, url: str) -> Image:
92+
def download_image(self, url: str) -> Image.Image:
7593
"""Download an image from URL."""
7694
response = requests.get(url)
7795
image = Image.open(BytesIO(response.content))
78-
return image
96+
return image

datafog/services/image_service.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,28 @@
44
from PIL import Image
55

66
from datafog.processing.image_processing.donut_processor import DonutProcessor
7-
from datafog.processing.image_processing.image_downloader import ImageDownloader
7+
# from datafog.processing.image_processing.image_downloader import ImageDownloader
88
from datafog.processing.image_processing.pytesseract_processor import (
99
PytesseractProcessor,
1010
)
11+
import aiohttp
12+
from PIL import Image
13+
import io
14+
import ssl
15+
import certifi
16+
17+
class ImageDownloader:
18+
async def download_image(self, url: str) -> Image.Image:
19+
ssl_context = ssl.create_default_context(cafile=certifi.where())
20+
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=ssl_context)) as session:
21+
async with session.get(url) as response:
22+
if response.status == 200:
23+
image_data = await response.read()
24+
return Image.open(io.BytesIO(image_data))
25+
else:
26+
raise Exception(f"Failed to download image. Status code: {response.status}")
27+
28+
1129

1230

1331
class ImageService:
@@ -20,8 +38,19 @@ def __init__(self, use_donut: bool = False, use_tesseract: bool = True):
2038
PytesseractProcessor() if self.use_tesseract else None
2139
)
2240

41+
# async def download_images(self, urls: List[str]) -> List[Image.Image]:
42+
# async def download_image(url: str) -> Image.Image:
43+
# return await self.downloader.download_image(url)
44+
45+
# tasks = [asyncio.create_task(download_image(url)) for url in urls]
46+
# return await asyncio.gather(*tasks)
47+
2348
async def download_images(self, urls: List[str]) -> List[Image.Image]:
24-
return await self.downloader.download_images(urls)
49+
async def download_image(url: str) -> Image.Image:
50+
return await self.downloader.download_image(url)
51+
52+
tasks = [asyncio.create_task(download_image(url)) for url in urls]
53+
return await asyncio.gather(*tasks, return_exceptions=True)
2554

2655
async def ocr_extract(
2756
self,

requirements-dev.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ just
66
isort
77
black
88
blacken-docs
9+
certifi
910
flake8
1011
prettier
1112
tox
12-
pytest
13+
pytest==7.4.0
14+
pytest-asyncio==0.21.0
1315
pytest-cov
1416
mypy
1517
autoflake

tests/test_image_service.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212

1313
import pytest
14+
import asyncio
1415
from PIL import Image
1516

1617
from datafog.services.image_service import ImageService
@@ -21,13 +22,21 @@
2122
]
2223

2324

25+
# @pytest.mark.asyncio
26+
# async def test_download_images():
27+
# image_service1 = ImageService()
28+
# images = await image_service1.download_images(urls)
29+
# assert len(images) == 2
30+
# assert all(isinstance(image, Image.Image) for image in images)
2431
@pytest.mark.asyncio
2532
async def test_download_images():
26-
image_service1 = ImageService()
27-
images = await image_service1.download_images(urls)
28-
assert len(images) == 2
29-
assert all(isinstance(image, Image.Image) for image in images)
30-
33+
image_service = ImageService()
34+
try:
35+
images = await image_service.download_images(urls)
36+
assert len(images) == 2
37+
assert all(isinstance(image, Image.Image) for image in images)
38+
finally:
39+
await asyncio.sleep(0) # Allow pending callbacks to run
3140

3241
@pytest.mark.asyncio
3342
async def test_ocr_extract_with_tesseract():

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[tox]
2-
envlist = py312
2+
envlist = py311
33
isolated_build = True
44

55
[testenv]

0 commit comments

Comments
 (0)