|
3 | 3 | import logging
|
4 | 4 | from typing import Optional
|
5 | 5 | from requests.models import Response
|
| 6 | +import pandas as pd |
| 7 | +from PIL import Image, ImageOps,ExifTags |
| 8 | +from io import BytesIO |
| 9 | +import base64 |
6 | 10 |
|
7 | 11 | import openai
|
8 | 12 | from openai import APIConnectionError, AuthenticationError, APIStatusError
|
9 | 13 |
|
| 14 | +from msfocr.llm import ocr_functions |
| 15 | + |
| 16 | +'Part1-testing llm_ocr_function' |
| 17 | + |
| 18 | +def test_parse_table_data(): |
| 19 | + result = { |
| 20 | + 'tables': [ |
| 21 | + { |
| 22 | + 'table_name': 'Paediatric vaccination target group', |
| 23 | + 'headers': ['', '0-11m', '12-59m', '5-14y'], |
| 24 | + 'data': [['Paed (0-59m) vacc target population', '', '', '']] |
| 25 | + }, |
| 26 | + { |
| 27 | + 'table_name': 'Routine paediatric vaccinations', |
| 28 | + 'headers': ['', '0-11m', '12-59m', '5-14y'], |
| 29 | + 'data': [ |
| 30 | + ['BCG', '45+29', '-', '-'], |
| 31 | + ['HepB (birth dose, within 24h)', '-', '-', '-'], |
| 32 | + ['HepB (birth dose, 24h or later)', '-', '-', '-'], |
| 33 | + ['Polio (OPV) 0 (birth dose)', '30+18', '-', '-'], |
| 34 | + ['Polio (OPV) 1 (from 6 wks)', '55+29', '-', '-'], |
| 35 | + ['Polio (OPV) 2', '77+19', '8', '-'], |
| 36 | + ['Polio (OPV) 3', '116+8', '15+3', '-'], |
| 37 | + ['Polio (IPV)', '342+42', '-', '-'], |
| 38 | + ['DTP+Hib+HepB (pentavalent) 1', '88+37', '3', '-'], |
| 39 | + ['DTP+Hib+HepB (pentavalent) 2', '125+16', '14+1', '-'], |
| 40 | + ['DTP+Hib+HepB (pentavalent) 3', '107+5', '23+6', '-'] |
| 41 | + ] |
| 42 | + } |
| 43 | + ], |
| 44 | + 'non_table_data': { |
| 45 | + 'Health Structure': 'W14', |
| 46 | + 'Supervisor': 'BKL', |
| 47 | + 'Start Date (YYYY-MM-DD)': '', |
| 48 | + 'End Date (YYYY-MM-DD)': '', |
| 49 | + 'Vaccination': 'paediatric' |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + expected_table_names = [ |
| 54 | + 'Paediatric vaccination target group', |
| 55 | + 'Routine paediatric vaccinations' |
| 56 | + ] |
| 57 | + |
| 58 | + expected_dataframes = [ |
| 59 | + pd.DataFrame([['', '0-11m', '12-59m', '5-14y'], ['Paed (0-59m) vacc target population', '', '', '']]), |
| 60 | + pd.DataFrame([ |
| 61 | + ['', '0-11m', '12-59m', '5-14y'], |
| 62 | + ['BCG', '45+29', '-', '-'], |
| 63 | + ['HepB (birth dose, within 24h)', '-', '-', '-'], |
| 64 | + ['HepB (birth dose, 24h or later)', '-', '-', '-'], |
| 65 | + ['Polio (OPV) 0 (birth dose)', '30+18', '-', '-'], |
| 66 | + ['Polio (OPV) 1 (from 6 wks)', '55+29', '-', '-'], |
| 67 | + ['Polio (OPV) 2', '77+19', '8', '-'], |
| 68 | + ['Polio (OPV) 3', '116+8', '15+3', '-'], |
| 69 | + ['Polio (IPV)', '342+42', '-', '-'], |
| 70 | + ['DTP+Hib+HepB (pentavalent) 1', '88+37', '3', '-'], |
| 71 | + ['DTP+Hib+HepB (pentavalent) 2', '125+16', '14+1', '-'], |
| 72 | + ['DTP+Hib+HepB (pentavalent) 3', '107+5', '23+6', '-'] |
| 73 | + ]) |
| 74 | + ] |
| 75 | + |
| 76 | + table_names, dataframes = ocr_functions.parse_table_data(result) |
| 77 | + |
| 78 | + assert table_names == expected_table_names |
| 79 | + for df, expected_df in zip(dataframes, expected_dataframes): |
| 80 | + pd.testing.assert_frame_equal(df, expected_df) |
| 81 | + |
| 82 | + |
| 83 | +def test_rescale_image(): |
| 84 | + # Create a simple image for testing |
| 85 | + img = Image.new('RGB', (3000, 1500), color='red') |
| 86 | + |
| 87 | + # Test resizing largest dimension |
| 88 | + resized_img = ocr_functions.rescale_image(img, 2048, True) |
| 89 | + assert max(resized_img.size) == 2048 |
| 90 | + assert resized_img.size == (2048, 1024) # Expected resized dimensions |
| 91 | + |
| 92 | + # Test resizing smallest dimension |
| 93 | + resized_img = ocr_functions.rescale_image(img, 768, False) |
| 94 | + assert min(resized_img.size) == 768 |
| 95 | + # 768 / 1024 * 2048 = 1536 |
| 96 | + assert resized_img.size == (1536, 768) |
| 97 | + |
| 98 | + |
| 99 | +def test_encode_image(): |
| 100 | + # Create a simple image for testing |
| 101 | + img = Image.new('RGB', (3000, 1500), color = 'red') |
| 102 | + buffered = BytesIO() |
| 103 | + img.save(buffered, format="PNG") |
| 104 | + buffered.seek(0) |
| 105 | + |
| 106 | + # Encode the image using the encode_image function |
| 107 | + encoded_string = ocr_functions.encode_image(buffered) |
| 108 | + |
| 109 | + # Verify that the encoded string is a valid base64 string |
| 110 | + decoded_image = base64.b64decode(encoded_string) |
| 111 | + assert decoded_image[:8] == b'\x89PNG\r\n\x1a\n' |
| 112 | + |
| 113 | + # Optionally, check if the image can be successfully loaded back |
| 114 | + img_back = Image.open(BytesIO(decoded_image)) |
| 115 | + assert max(img_back.size) == 2048 or min(img_back.size) == 768 |
| 116 | + |
| 117 | + |
| 118 | +def create_test_image_with_orientation(orientation): |
| 119 | + # Create a simple image |
| 120 | + img = Image.new('RGB', (100, 50), color='red') |
| 121 | + buffered = BytesIO() |
| 122 | + img.save(buffered, format="JPEG") |
| 123 | + buffered.seek(0) |
| 124 | + |
| 125 | + # Load the image and manually set the orientation EXIF tag |
| 126 | + img_with_orientation = Image.open(buffered) |
| 127 | + exif = img_with_orientation.getexif() |
| 128 | + exif[274] = orientation # 274 is the EXIF tag code for Orientation |
| 129 | + exif_bytes = exif.tobytes() |
| 130 | + |
| 131 | + # Save the image with the new EXIF data |
| 132 | + buffered = BytesIO() |
| 133 | + img_with_orientation.save(buffered, format="JPEG", exif=exif_bytes) |
| 134 | + buffered.seek(0) |
| 135 | + return buffered |
| 136 | + |
| 137 | + |
| 138 | +def assert_color_within_tolerance(color1, color2, tolerance=1): |
| 139 | + """ |
| 140 | + Assert that two colors are within a given tolerance. |
| 141 | +
|
| 142 | + :param color1: The first color as an (R, G, B) tuple. |
| 143 | + :param color2: The second color as an (R, G, B) tuple. |
| 144 | + :param tolerance: The tolerance for each color component. |
| 145 | + """ |
| 146 | + for c1, c2 in zip(color1, color2): |
| 147 | + assert abs(c1 - c2) <= tolerance |
| 148 | + |
| 149 | + |
| 150 | +def test_correct_image_orientation(): |
| 151 | + # Test for orientation 3 (180 degrees) |
| 152 | + img_data = create_test_image_with_orientation(3) |
| 153 | + corrected_image = ocr_functions.correct_image_orientation(img_data) |
| 154 | + assert corrected_image.size == (100, 50) |
| 155 | + # Check if top-left pixel is red after rotating 180 degrees |
| 156 | + assert_color_within_tolerance(corrected_image.getpixel((0, 0)), (255, 0, 0)) |
| 157 | + |
| 158 | + # Test for orientation 6 (270 degrees) |
| 159 | + img_data = create_test_image_with_orientation(6) |
| 160 | + corrected_image = ocr_functions.correct_image_orientation(img_data) |
| 161 | + assert corrected_image.size == (50, 100) |
| 162 | + # Check if bottom-left pixel is red after rotating 270 degrees |
| 163 | + assert_color_within_tolerance(corrected_image.getpixel((0, corrected_image.size[1] - 1)), (255, 0, 0)) |
| 164 | + |
| 165 | + # Test for orientation 8 (90 degrees) |
| 166 | + img_data = create_test_image_with_orientation(8) |
| 167 | + corrected_image = ocr_functions.correct_image_orientation(img_data) |
| 168 | + assert corrected_image.size == (50, 100) |
| 169 | + # Check if top-right pixel is red after rotating 90 degrees |
| 170 | + assert_color_within_tolerance(corrected_image.getpixel((corrected_image.size[0] - 1, 0)), (255, 0, 0)) |
| 171 | + |
| 172 | + |
| 173 | +'Part2-testing openai api call' |
10 | 174 | # Setup logging
|
11 | 175 | logging.basicConfig(level=logging.INFO)
|
12 | 176 | logger = logging.getLogger(__name__)
|
|
0 commit comments