Skip to content

Commit e58b36a

Browse files
committed
Removed confidence value calculation (currently unused) to reduce run time
1 parent 45e73ea commit e58b36a

File tree

2 files changed

+36
-41
lines changed

2 files changed

+36
-41
lines changed

app_doctr.py

+16-40
Original file line numberDiff line numberDiff line change
@@ -44,33 +44,14 @@
4444
@st.cache_resource
4545
def create_ocr():
4646
"""
47-
Load docTR ocr model and img2table docTR model
47+
Load img2table docTR model
4848
"""
49-
ocr_model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
5049
doctr_ocr = DocTR(detect_language=False)
51-
return ocr_model, doctr_ocr
50+
return doctr_ocr
5251

5352
@st.cache_data(show_spinner=False)
54-
def get_uploaded_images(tally_sheet):
55-
"""
56-
List of images uploaded by user as docTR DocumentFiles
57-
:param Files uploaded by user
58-
:return List of images uploaded by user as docTR DocumentFiles
59-
"""
60-
res = []
61-
for sheet in tally_sheet:
62-
sheet.seek(0, 0)
63-
image = sheet.read()
64-
res.append(DocumentFile.from_images(image))
65-
return res
66-
67-
@st.cache_data(show_spinner=False)
68-
def get_results(uploaded_images):
69-
return [doctr_ocr_functions.get_word_level_content(ocr_model, doc) for doc in uploaded_images]
70-
71-
@st.cache_data
72-
def get_tabular_content_wrapper(_doctr_ocr, img, confidence_lookup_dict):
73-
return doctr_ocr_functions.get_tabular_content(_doctr_ocr, img, confidence_lookup_dict)
53+
def get_tabular_content_wrapper(_doctr_ocr, img):
54+
return doctr_ocr_functions.get_tabular_content(_doctr_ocr, img)
7455

7556
@st.cache_data
7657
def get_DE_COC_List_wrapper(form):
@@ -288,7 +269,7 @@ def evaluate_cells(table_dfs):
288269
return table_dfs
289270

290271
def clean_up(table_dfs):
291-
"""Uses simple_eval to perform math operations on each cell, defaulting to input if failed.
272+
"""Replaces table values that the OCR model returns as the string "None" with the empty string "".
292273
293274
Args:
294275
table_dfs (_List_): List of table data frames
@@ -367,7 +348,7 @@ def authenticate():
367348
key=st.session_state['upload_key'])
368349

369350
# OCR Model
370-
ocr_model, doctr_ocr = create_ocr()
351+
doctr_ocr = create_ocr()
371352

372353
# Once images are uploaded
373354
if len(tally_sheet_images) > 0:
@@ -451,22 +432,17 @@ def authenticate():
451432
# ***************************************
452433

453434
# Populate streamlit with data recognized from tally sheets
454-
with st.spinner("Running image recognition..."):
455-
uploaded_images = get_uploaded_images(tally_sheet_images)
456-
results = get_results(uploaded_images)
457-
435+
458436
# Spinner for data upload. If it's going to be on screen for long, make it bespoke
459-
table_dfs, page_nums_to_display = [], []
460-
for i, sheet in enumerate(tally_sheet_images):
461-
image = uploaded_images[i]
462-
result = results[i]
463-
confidence_lookup_dict = doctr_ocr_functions.get_confidence_values(result)
464-
img = Image(src=sheet)
465-
table_df, confidence_df = get_tabular_content_wrapper(doctr_ocr, img, confidence_lookup_dict)
466-
table_dfs.extend(table_df)
467-
page_nums_to_display.extend([str(i + 1)] * len(table_df))
468-
table_dfs = clean_up(table_dfs)
469-
table_dfs = evaluate_cells(table_dfs)
437+
with st.spinner("Running image recognition..."):
438+
table_dfs, page_nums_to_display = [], []
439+
for i, sheet in enumerate(tally_sheet_images):
440+
img = Image(src=sheet)
441+
table_df = get_tabular_content_wrapper(doctr_ocr, img)
442+
table_dfs.extend(table_df)
443+
page_nums_to_display.extend([str(i + 1)] * len(table_df))
444+
table_dfs = clean_up(table_dfs)
445+
table_dfs = evaluate_cells(table_dfs)
470446

471447

472448
# Form session state initialization

src/msfocr/doctr/ocr_functions.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get_confidence_values(res):
7474
return confidence_dict
7575

7676

77-
def get_tabular_content(model, image, confidence_dict=None):
77+
def get_tabular_content_with_confidence(model, image, confidence_dict=None):
7878
"""
7979
Runs the input image in the OCR model. Detects all tables and content within tables and stores results as
8080
a list of pandas dataFrames (table_df). Calculates confidence values for all detected values in table_df
@@ -114,6 +114,25 @@ def get_tabular_content(model, image, confidence_dict=None):
114114

115115
return table_df, confidence_df
116116

117+
def get_tabular_content(model, image):
    """
    Runs the input image in the OCR model. Detects all tables and content within tables and stores results as
    a list of pandas dataFrames (table_df).
    :param model: OCR model
    :param image: Image to be tested (Image object from img2table package)
    :return: List table_df holding the ``df`` of every extracted table, in extraction order
             (empty list when no tables are extracted)
    """
    # Implicit-row and borderless-table detection are switched off; min_confidence=50 is
    # passed through to the extractor unchanged.
    extracted_tables = image.extract_tables(ocr=model,
                                            implicit_rows=False,
                                            borderless_tables=False,
                                            min_confidence=50)

    # Only the dataframe of each table is needed, so no index bookkeeping is required.
    return [table.df for table in extracted_tables]
135+
117136
def get_sheet_type(res):
118137
"""
119138
Finds the type of the tally sheet (dataSet, orgUnit, period) from the result of OCR model, where

0 commit comments

Comments
 (0)