Skip to content

Commit ec2b533

Browse files
authoredAug 1, 2024··
Merge pull request #15 from UMassCDS/new-form-retrieval
New form retrieval
2 parents a49590a + 04ccce6 commit ec2b533

File tree

6 files changed

+412
-326
lines changed

6 files changed

+412
-326
lines changed
 

‎CHANGELOG.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
You should also add project tags for each release in Github, see [Managing releases in a repository](https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository).
88

99
## [Unreleased]
10-
1110
### Added
1211
- Merged the MSF-OCR-Streamlit repository into this repository
12+
- User authenticates with DHIS2 password rather than hard coded passkey
13+
- Table headers are corrected based on key-value pairs from DHIS2
14+
- Payload with key-value pairs for DHIS2 is displayed to the user before uploading
15+
- Images are resized before being sent to OpenAI
16+
17+
### Changed
18+
- User must confirm each page and key-value pairs before they're allowed to upload, so upload buttons is not initially selectable
1319

1420
## [1.1.0] - 2024-07-26
1521
### Changed

‎app_llm.py

+242-187
Large diffs are not rendered by default.

‎src/msfocr/data/dhis2.py

+71-77
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import urllib.parse
2-
2+
import json
33
import requests
44

55
# Make sure these are set before trying to make requests
@@ -10,16 +10,12 @@
1010
# TODO It might be clearer to create a Server object class and have this be the __init__() function
1111
def configure_DHIS2_server(username=None, password=None, server_url=None):
1212
global DHIS2_SERVER_URL, DHIS2_USERNAME, DHIS2_PASSWORD
13-
DHIS2_USERNAME = username
14-
DHIS2_PASSWORD = password
15-
DHIS2_SERVER_URL = server_url
16-
17-
18-
def checkResponseStatus(res):
19-
if res.status_code == 401:
20-
raise ValueError("Authentication failed. Check your username and password.")
21-
res.raise_for_status()
22-
13+
if username is not None:
14+
DHIS2_USERNAME = username
15+
if password is not None:
16+
DHIS2_PASSWORD = password
17+
if server_url is not None:
18+
DHIS2_SERVER_URL = server_url
2319

2420
def getAllUIDs(item_type, search_items):
2521
encoded_search_items = [urllib.parse.quote_plus(item) for item in search_items]
@@ -30,10 +26,7 @@ def getAllUIDs(item_type, search_items):
3026
filter_param = 'filter=' + '&filter='.join([f'name:ilike:{term}' for term in encoded_search_items])
3127

3228
url = f'{DHIS2_SERVER_URL}/api/{item_type}?{filter_param}'
33-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
34-
checkResponseStatus(response)
35-
36-
data = response.json()
29+
data = getResponse(url)
3730
items = data[item_type]
3831
print(f"{len(data[item_type])} matches found for {search_items}")
3932
if len(items) > 0:
@@ -43,6 +36,15 @@ def getAllUIDs(item_type, search_items):
4336

4437
return uid
4538

39+
def getResponse(url):
40+
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
41+
42+
if response.status_code == 401:
43+
raise ValueError("Authentication failed. Check your username and password.")
44+
response.raise_for_status()
45+
46+
data = response.json()
47+
return data
4648

4749
def getOrgUnitChildren(uid):
4850
"""
@@ -51,11 +53,7 @@ def getOrgUnitChildren(uid):
5153
:return: List of (org unit child name, org unit child data sets))
5254
"""
5355
url = f'{DHIS2_SERVER_URL}/api/organisationUnits/{uid}?includeChildren=true'
54-
55-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
56-
checkResponseStatus(response)
57-
58-
data = response.json()
56+
data = getResponse(url)
5957
items = data['organisationUnits']
6058
children = [(item['name'], item['dataSets'], item['id']) for item in items if item['id'] != uid]
6159

@@ -73,71 +71,67 @@ def getDataSets(data_sets_uids):
7371
uid = uid_obj['id']
7472
url = f'{DHIS2_SERVER_URL}/api/dataSets/{uid}'
7573

76-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
77-
checkResponseStatus(response)
78-
79-
data = response.json()
74+
data = getResponse(url)
8075
data_set = (data['name'], data['id'], data['periodType'])
8176
data_sets.append(data_set)
8277

8378
return data_sets
8479

85-
def getCategoryUIDs(dataSet_uid):
80+
def getFormJson(dataSet_uid, period, orgUnit_uid):
8681
"""
87-
Hierarchically searches DHIS2 to generate category UIDs for each dataElement. Also used for retreiving all data elements and categories present in a dataset.
88-
:param data_sets_uid: UID of the dataset
89-
:return: dataElement_to_id (dict[str, str]), dataElement_to_categoryCombo (dict[str, str]), categoryCombos (dict[str, str]), category_list (list[str]), dataElement_list (list[str])
90-
category list eg. ['0-11m','<5y'...]
82+
Gets information about all forms associated with a organisation, dataset, period combination in DHIS2.
83+
:param dataset UID, time period, organisation unit UID
84+
:return json response containing hierarchical information about tabs, tables, non-tabular fields
9185
"""
92-
url = f'{DHIS2_SERVER_URL}/api/dataSets/{dataSet_uid}?fields=dataSetElements'
93-
94-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
95-
checkResponseStatus(response)
9686

97-
data = response.json()
98-
99-
items = data['dataSetElements']
100-
101-
dataElement_to_categoryCombo = {}
102-
categoryCombos = {}
103-
categoryOptionCombos = {}
104-
for item in items:
105-
if 'categoryCombo' in item:
106-
dataElement_to_categoryCombo[item['dataElement']['id']] = item['categoryCombo']['id']
107-
categoryCombos[item['categoryCombo']['id']] = {}
108-
109-
for catCombo_id in categoryCombos:
110-
url = f'{DHIS2_SERVER_URL}/api/categoryCombos/{catCombo_id}?fields=categoryOptionCombos'
111-
112-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
113-
checkResponseStatus(response)
114-
115-
data = response.json()
116-
117-
items = data['categoryOptionCombos']
118-
119-
for item in items:
120-
url = f"{DHIS2_SERVER_URL}/api/categoryOptionCombos/{item['id']}?fields=name"
121-
122-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
123-
checkResponseStatus(response)
87+
# POST empty data payload to trigger form generation
88+
json_export = {}
89+
json_export["dataSet"] = dataSet_uid
90+
json_export["period"] = period
91+
json_export["orgUnit"] = orgUnit_uid
92+
json_export["dataValues"] = []
93+
data_payload = json.dumps(json_export)
94+
posturl = f'{DHIS2_SERVER_URL}/api/dataValueSets?dryRun=true'
95+
96+
response = requests.post(
97+
posturl,
98+
auth=(DHIS2_USERNAME, DHIS2_PASSWORD),
99+
headers={'Content-Type': 'application/json'},
100+
data=data_payload
101+
)
102+
response.raise_for_status()
103+
104+
# Get form now
105+
url = f'{DHIS2_SERVER_URL}/api/dataSets/{dataSet_uid}/form.json?pe={period}&ou={orgUnit_uid}'
106+
data = getResponse(url)
107+
return data
124108

125-
data = response.json()
126-
127-
categoryCombos[catCombo_id][data['name']] = item['id']
128-
129-
if data['name'] not in categoryOptionCombos:
130-
categoryOptionCombos[data['name']] = ''
131-
category_list = list(categoryOptionCombos.keys())
132-
133-
url = f'{DHIS2_SERVER_URL}/api/dataElements?filter=dataSetElements.dataSet.id:eq:{dataSet_uid}&fields=id,formName'
134-
response = requests.get(url, auth=(DHIS2_USERNAME, DHIS2_PASSWORD))
135-
checkResponseStatus(response)
136-
data = response.json()
137-
138-
dataElement_to_id = {item["formName"]:item["id"] for item in data['dataElements']}
139-
dataElement_list = [item["formName"] for item in data['dataElements']]
109+
def get_DE_COC_List(form):
110+
"""
111+
Finds the list of all dataElements (row names in tables) and categoryOptionCombos (column names in tables) within a DHIS2 form
112+
:param json data containing hierarchical information about tabs, tables, non-tabular fields within a organisation, dataset, period combination in DHIS2.
113+
:return List of row names found, List of column names found
114+
"""
115+
url = f'{DHIS2_SERVER_URL}/api/dataElements?paging=false&fields=id,formName'
116+
data = getResponse(url)
117+
allDataElements = {item['id']:item['formName'] for item in data['dataElements'] if 'formName' in item and 'id' in item}
118+
119+
url = f'{DHIS2_SERVER_URL}/api/categoryOptionCombos?paging=false&fields=id,name'
120+
data = getResponse(url)
121+
allCategory = {item['id']:item['name'] for item in data['categoryOptionCombos'] if 'name' in item and 'id' in item}
122+
123+
# Form tabs found in DHIS2
124+
tabs = form['groups']
125+
dataElement_list = {}
126+
categoryOptionCombo_list = {}
127+
for tab in tabs:
128+
for field in tab['fields']:
129+
DE_ID = field['dataElement']
130+
COC_ID = field['categoryOptionCombo']
131+
if allDataElements[DE_ID] not in dataElement_list:
132+
dataElement_list[allDataElements[DE_ID]] = 1
133+
if allCategory[COC_ID] not in categoryOptionCombo_list:
134+
categoryOptionCombo_list[allCategory[COC_ID]] = 1
135+
return list(dataElement_list.keys()), list(categoryOptionCombo_list.keys())
140136

141-
return dataElement_to_id, dataElement_to_categoryCombo, categoryCombos, category_list, dataElement_list
142-
143137

‎src/msfocr/doctr/ocr_functions.py

+21-25
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010
import pandas as pd
1111

12-
from msfocr.data import dhis2
1312

1413

1514
def letter_by_letter_similarity(text1, text2):
@@ -152,8 +151,7 @@ def get_sheet_type(res):
152151
period.append(date)
153152
return [dataSet, orgUnit, sorted(period)]
154153

155-
156-
def generate_key_value_pairs(table, dataSet_uid):
154+
def generate_key_value_pairs(table, form):
157155
"""
158156
Generates key-value pairs in the format required to upload data to DHIS2.
159157
{'dataElement': data_element_id,
@@ -162,36 +160,35 @@ def generate_key_value_pairs(table, dataSet_uid):
162160
UIDs like data_element_id, category_id are obtained by querying the DHIS2 metadata.
163161
:param table: DataFrame generated from table detection
164162
:return: List of key value pairs as shown above.
165-
"""
166-
# Save UIDs found in a dictionary to avoid repeated UID querying
167-
id_found = {}
168-
169-
# Get dataElement to UID map for all dataElements in the dataset
170-
dataElement_to_id, dataElement_to_categoryCombo, categoryCombos_to_name_to_id,_,_ = dhis2.getCategoryUIDs(dataSet_uid)
171-
163+
"""
172164
data_element_pairs = []
165+
173166
# Iterate over each cell in the DataFrame
174167
table_array = table.values
175168
columns = table.columns
176169
for row_index in range(table_array.shape[0]):
170+
# Row name in tally sheet
177171
data_element = table_array[row_index][0]
178172
for col_index in range(1, table_array.shape[1]):
173+
# Column name in tally sheet
179174
category = columns[col_index]
180175
cell_value = table_array[row_index][col_index]
181-
if cell_value is not None:
182-
if data_element not in id_found:
183-
# Retrive UIDs for dataElement and categoryOption
184-
data_element_id = dataElement_to_id[data_element]
185-
id_found[data_element] = data_element_id
186-
print(data_element, data_element_id)
187-
else:
188-
data_element_id = id_found[data_element]
189-
190-
# Get category_UID for each dataElement
191-
categoryCombo = dataElement_to_categoryCombo[data_element_id]
192-
categoryOptionCombos = categoryCombos_to_name_to_id[categoryCombo]
193-
category_id = categoryOptionCombos[category]
194-
176+
if cell_value is not None and cell_value!="-" and cell_value!="":
177+
data_element_id = None
178+
category_id = None
179+
# Search for the string in the "label" field of form information
180+
string_search = data_element + " " + category
181+
for group in form['groups']:
182+
for field in group['fields']:
183+
if field['label']==string_search:
184+
data_element_id = field['dataElement']
185+
category_id = field['categoryOptionCombo']
186+
187+
# The following exceptions will be raised if the row or column name in the tally sheet is different from the names used in metadata
188+
# For eg. Pop1: Resident is called Population 1 in metadata
189+
# If this exception is raised the only way forward is for the user to manually change the row/column name to the one used in metadata
190+
if data_element_id is None or category_id is None:
191+
raise Exception(f"Unable to find {string_search} in DHIS2 metadata")
195192
# Append to the list of data elements to be push to DHIS2
196193
data_element_pairs.append(
197194
{"dataElement": data_element_id,
@@ -201,7 +198,6 @@ def generate_key_value_pairs(table, dataSet_uid):
201198

202199
return data_element_pairs
203200

204-
205201
# ocr_model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
206202
# document = DocumentFile.from_images("IMG_20240514_090947.jpg")
207203
# result = get_word_level_content(ocr_model, document)

‎src/msfocr/llm/ocr_functions.py

+49-19
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import base64
66
import json
77
from concurrent.futures.thread import ThreadPoolExecutor
8+
from io import BytesIO
89

910
import pandas as pd
1011

1112
from openai import OpenAI
12-
from PIL import Image, ExifTags
13-
13+
from PIL import Image, ExifTags, ImageOps
1414

1515
def get_results(uploaded_image_paths):
1616
"""
@@ -59,6 +59,30 @@ def parse_table_data(result):
5959
return table_names, dataframes
6060

6161

62+
def rescale_image(img, limit, maxi=True):
63+
"""Rescales an image file to GPT's proportions (Max 2048 x 768).
64+
65+
Args:
66+
img (_Image_): The image file that needs to be rescaled.
67+
limit (_int_): The maximum size of the dimension in pixels.
68+
maxi (bool, optional): True for resizing the largest dimension, false for smallest. Defaults to True.
69+
70+
Returns:
71+
_Image_: Resized image file.
72+
"""
73+
width, height = img.size
74+
if maxi:
75+
max_dim = max(width, height)
76+
else:
77+
max_dim = min(width, height)
78+
if max_dim > limit:
79+
scale_factor = limit / max_dim
80+
new_width = int(width * scale_factor)
81+
new_height = int(height * scale_factor)
82+
img = img.resize((new_width, new_height))
83+
return img
84+
85+
6286
def encode_image(image_path):
6387
"""
6488
Encodes an image file to base64 string.
@@ -70,7 +94,13 @@ def encode_image(image_path):
7094
:return: Base64 encoded string of the image.
7195
"""
7296
image_path.seek(0)
73-
return base64.b64encode(image_path.read()).decode("utf-8")
97+
with Image.open(image_path) as img:
98+
img = ImageOps.exif_transpose(img)
99+
img = rescale_image(img, 2048, True)
100+
img = rescale_image(img, 768, False)
101+
buffered = BytesIO()
102+
img.save(buffered, format="PNG")
103+
return base64.b64encode(buffered.getvalue()).decode("utf-8")
74104

75105

76106
def extract_text_from_image(image_path):
@@ -139,21 +169,21 @@ def correct_image_orientation(image_path):
139169
:param image_path: The path to the image file.
140170
:return: PIL.Image.Image: The image with corrected orientation.
141171
"""
142-
image = Image.open(image_path)
143-
orientation = None
144-
try:
145-
for orientation in ExifTags.TAGS.keys():
146-
if ExifTags.TAGS[orientation] == 'Orientation':
147-
break
148-
exif = dict(image.getexif().items())
149-
if exif.get(orientation) == 3:
150-
image = image.rotate(180, expand=True)
151-
elif exif.get(orientation) == 6:
152-
image = image.rotate(270, expand=True)
153-
elif exif.get(orientation) == 8:
154-
image = image.rotate(90, expand=True)
155-
except (AttributeError, KeyError, IndexError):
156-
pass
157-
return image
172+
with Image.open(image_path) as image:
173+
orientation = None
174+
try:
175+
for orientation in ExifTags.TAGS.keys():
176+
if ExifTags.TAGS[orientation] == 'Orientation':
177+
break
178+
exif = dict(image.getexif().items())
179+
if exif.get(orientation) == 3:
180+
image = image.rotate(180, expand=True)
181+
elif exif.get(orientation) == 6:
182+
image = image.rotate(270, expand=True)
183+
elif exif.get(orientation) == 8:
184+
image = image.rotate(90, expand=True)
185+
except (AttributeError, KeyError, IndexError):
186+
pass
187+
return image
158188

159189

‎tests/test_doctr_ocr_functions.py

+22-17
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,10 @@ def test_generate_key_value_pairs(test_server_config, requests_mock):
3333
'5-14y': [None]
3434
})
3535

36-
requests_mock.get("http://test.com/api/dataSets/10?fields=dataSetElements", json={'dataSetElements' : [{"categoryCombo": {"id": 5},"dataElement": {"id": 1},"dataSet": {"id": 10}}, {"categoryCombo": {"id": 5},"dataElement": {"id": 3},"dataSet": {"id": 10}}]})
37-
requests_mock.get("http://test.com/api/categoryCombos/5?fields=categoryOptionCombos", json={"categoryOptionCombos": [{"id": 8}, {"id": 9}]})
38-
requests_mock.get("http://test.com/api/categoryOptionCombos/8?fields=name", json={"name": "0-11m"})
39-
requests_mock.get("http://test.com/api/categoryOptionCombos/9?fields=name", json={"name": "12-59m"})
40-
requests_mock.get("http://test.com/api/dataElements?filter=dataSetElements.dataSet.id:eq:10&fields=id,formName", json={'dataElements': [{'formName': 'BCG','id':1},{'formName': 'Polio (OPV) 1 (from 6 wks)','id':3}]})
41-
42-
assert len(ocr_functions.generate_key_value_pairs(df, 10)) == 0
36+
assert len(ocr_functions.generate_key_value_pairs(df, {'groups': [{'fields':[{"label": "Paed (0-59m) vacc target population 0-11m",
37+
"dataElement": "paedid",
38+
"categoryOptionCombo": "0to11mid",
39+
"type": "INTEGER_POSITIVE"}]}]})) == 0
4340

4441
df = pd.DataFrame({
4542
'0': ['BCG', 'Polio (OPV) 0 (birth dose)', 'Polio (OPV) 1 (from 6 wks)'],
@@ -48,16 +45,24 @@ def test_generate_key_value_pairs(test_server_config, requests_mock):
4845
'5-14y': [None, None, None]
4946
})
5047

51-
requests_mock.get("http://test.com/api/dataElements?filter=formName:ilike:BCG", json={"dataElements":[{"id": 1, "displayName": "AVAC_002 BCG"}]})
52-
requests_mock.get("http://test.com/api/categoryOptions?filter=name:ilike:0-11m", json={'categoryOptions': [{'id': 2, 'displayName': '0-11m'}]})
53-
requests_mock.get("http://test.com/api/dataElements?filter=formName:ilike:Polio (OPV) 1 (from 6 wks)", json={'dataElements': [{'id': 3, 'displayName': 'AVAC_006 Polio (OPV) 1 (from 6 wks)'}]})
54-
requests_mock.get("http://test.com/api/categoryOptions?filter=name:ilike:12-59m", json={'categoryOptions': [{'id': 4, 'displayName': '12-59m'}]})
55-
56-
answer = [{'dataElement': '', 'categoryOptions': '', 'value': '45+29'},
57-
{'dataElement': '', 'categoryOptions': '', 'value': '30+18'},
58-
{'dataElement': '', 'categoryOptions': '', 'value': '55+29'}]
59-
60-
data_element_pairs = ocr_functions.generate_key_value_pairs(df, 10)
48+
answer = [{'dataElement': 'bcgid', 'categoryOptions': '0to11mid', 'value': '45+29'},
49+
{'dataElement': 'polioid', 'categoryOptions': '0to11mid', 'value': '30+18'},
50+
{'dataElement': 'polioid', 'categoryOptions': '5to14yid', 'value': '55+29'}]
51+
52+
data_element_pairs = ocr_functions.generate_key_value_pairs(df,
53+
{'groups': [{'fields':[{"label": "BCG 0-11m",
54+
"dataElement": "bcgid",
55+
"categoryOptionCombo": "0to11mid",
56+
"type": "INTEGER_POSITIVE"}]},
57+
{'fields':[{"label": "Polio (OPV) 1 (from 6 wks) 0-11m",
58+
"dataElement": "polioid",
59+
"categoryOptionCombo": "0to11mid",
60+
"type": "INTEGER_POSITIVE"}]},
61+
{'fields':[{"label": "Polio (OPV) 1 (from 6 wks) 12-59m",
62+
"dataElement": "polioid",
63+
"categoryOptionCombo": "5to14yid",
64+
"type": "INTEGER_POSITIVE"}]}]})
65+
6166
assert len(data_element_pairs) == len(answer)
6267

6368
for i in range(len(data_element_pairs)):

0 commit comments

Comments
 (0)
Please sign in to comment.