Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix confidence computation and filtering #70

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions pero_ocr/core/confidence_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pero_ocr.core.force_alignment import align_text

global_confidence_quantile = 0.33

def get_letter_confidence(logits: np.ndarray, alignment: typing.List[int], blank_ind: int) -> typing.List[float]:
"""Function which estimates confidence of characters as the maximal log-prob aligned to them.
Expand Down Expand Up @@ -70,14 +71,40 @@ def squeeze(sequence):
return result


def get_line_confidence(line, labels=None, aligned_letters=None, log_probs=None):
# There is the same number of outputs as labels (probably transformer model was used) --> each letter has only one
# possible frame in logits thus it is not needed to align them
def get_page_confidence_from_transcription_confidences(transcription_confidences, quantile=None):
    """Aggregate per-line transcription confidences into one page-level confidence.

    Args:
        transcription_confidences: sequence of per-line confidence values.
        quantile: quantile used for aggregation; when None, falls back to the
            module-level ``global_confidence_quantile``.

    Returns:
        The chosen quantile of the confidences, or 0 when the input is empty.
    """
    if quantile is None:
        quantile = global_confidence_quantile
    if len(transcription_confidences) == 0:
        return 0
    return np.quantile(transcription_confidences, quantile)


def get_transcription_confidence_from_characters(character_confidences, quantile=None) -> float:
    """Aggregate per-character confidences into one line-level confidence.

    Args:
        character_confidences: sequence of per-character confidence values.
        quantile: quantile used for aggregation; when None, falls back to the
            module-level ``global_confidence_quantile``.

    Returns:
        The chosen quantile of the confidences, or 0 when the input is empty.
    """
    if quantile is None:
        quantile = global_confidence_quantile
    if len(character_confidences) == 0:
        return 0
    return np.quantile(character_confidences, quantile)


def get_transcription_confidence(line, labels=None, aligned_letters=None, log_probs=None):
    """Confidence of a whole line transcription.

    Per-character confidences are computed first and then folded into a single
    value by the character-level aggregator.
    """
    per_character = get_character_confidences(line, labels, aligned_letters, log_probs)
    return get_transcription_confidence_from_characters(per_character)


def get_word_confidence_from_characters(word_character_confidences, quantile=None):
    """Aggregate the character confidences of one word into a word confidence.

    Args:
        word_character_confidences: per-character confidences of the word.
        quantile: quantile used for aggregation; when None, falls back to the
            module-level ``global_confidence_quantile``.

    Returns:
        The chosen quantile of the confidences, or 0 when the input is empty.
        The empty-input guard keeps this consistent with the page- and
        line-level aggregators; ``np.quantile`` would raise on an empty slice.
    """
    if quantile is None:
        quantile = global_confidence_quantile
    if len(word_character_confidences) == 0:
        return 0
    return np.quantile(word_character_confidences, quantile)


def get_character_confidences(line, labels=None, aligned_letters=None, log_probs=None) -> np.ndarray:
if labels is None:
labels = line.get_labels()

if len(labels) == 0:
return np.array([])


# There is the same number of outputs as labels (probably transformer model was used) --> each letter has only one
# possible frame in logits thus it is not needed to align them
if line.logits.shape[0] == len(labels):
return get_line_confidence_transformer(line, labels)
return get_character_confidences_transformer(line, labels)

if log_probs is None:
log_probs = line.get_full_logprobs()
Expand Down Expand Up @@ -107,7 +134,7 @@ def get_line_confidence(line, labels=None, aligned_letters=None, log_probs=None)
return confidences


def get_line_confidence_transformer(line, labels):
def get_character_confidences_transformer(line, labels):
    """Per-character confidences for models with one output frame per label.

    Each label's confidence is the probability its own frame assigns to it, so
    no CTC alignment is needed.
    """
    log_probs = line.get_full_logprobs()
    frame_indices = np.arange(len(labels))
    return np.exp(log_probs)[frame_indices, labels]
3 changes: 3 additions & 0 deletions pero_ocr/core/force_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def viterbi_align(neg_logits: np.ndarray, A: np.ndarray) -> typing.List[int]:


def align_text(neg_logprobs, transcription, blank_symbol):
if neg_logprobs.shape[0] == len(transcription):
return np.array(list(range(neg_logprobs.shape[0])))

logit_characters = force_align(neg_logprobs, transcription, blank_symbol, return_seq_positions=True)

max_probs = (-neg_logprobs).max(axis=-1)
Expand Down
68 changes: 43 additions & 25 deletions pero_ocr/core/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from pero_ocr.core.crop_engine import EngineLineCropper
from pero_ocr.core.force_alignment import align_text
from pero_ocr.core.confidence_estimation import get_line_confidence
from pero_ocr.core.confidence_estimation import (get_character_confidences, get_transcription_confidence,
get_transcription_confidence_from_characters,
get_word_confidence_from_characters,
get_page_confidence_from_transcription_confidences)
from pero_ocr.core.arabic_helper import ArabicHelper

Num = Union[int, float]
Expand Down Expand Up @@ -52,6 +55,7 @@ def __init__(self, id: str = None,
crop: Optional[np.ndarray] = None,
characters: Optional[List[str]] = None,
logit_coords: Optional[Union[List[Tuple[int]], List[Tuple[None]]]] = None,
character_confidences: Optional[List[Num]] = None,
transcription_confidence: Optional[Num] = None,
index: Optional[int] = None,
category: Optional[str] = None):
Expand All @@ -65,6 +69,7 @@ def __init__(self, id: str = None,
self.crop = crop
self.characters = characters
self.logit_coords = logit_coords
self.character_confidences = character_confidences
self.transcription_confidence = transcription_confidence
self.category = category

Expand All @@ -77,6 +82,28 @@ def get_full_logprobs(self, zero_logit_value: int = -80):
dense_logits = self.get_dense_logits(zero_logit_value)
return log_softmax(dense_logits)

def calculate_confidences(self, default_transcription_confidence=None):
    """Compute per-character and whole-line confidences from this line's logits.

    Sets ``self.character_confidences`` and ``self.transcription_confidence``.
    When logits are missing both are cleared; when the character-level
    computation fails, the transcription confidence falls back to
    *default_transcription_confidence*.
    """
    if self.logits is None:
        logger.warning(f'Error: Unable to calculate confidences for line {self.id} due to missing logits.')
        self.character_confidences = None
        self.transcription_confidence = None
        return

    try:
        # logit cropping should not be done for confidence calculation - only for word alignment
        full_log_probs = self.get_full_logprobs()
        self.character_confidences = get_character_confidences(self, log_probs=full_log_probs)
    except KeyboardInterrupt:
        raise
    except Exception as e:
        logger.warning(f'Error: Unable to calculate confidences for line {self.id} due to exception: {e}.')
        self.character_confidences = None

    if self.character_confidences is None:
        self.transcription_confidence = default_transcription_confidence
    else:
        self.transcription_confidence = get_transcription_confidence_from_characters(self.character_confidences)

def to_pagexml(self, region_element: ET.SubElement, fallback_id: int, validate_id: bool = False):
text_line = ET.SubElement(region_element, "TextLine")
text_line.set("id", export_id(self.id, validate_id))
Expand Down Expand Up @@ -177,6 +204,9 @@ def from_pagexml_parse_custom(self, custom_str):
self.heights = heights.tolist()

def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: ALTOVersion):
if self.character_confidences is None or self.transcription_confidence is None:
self.calculate_confidences()

if self.transcription_confidence is not None and self.transcription_confidence < min_line_confidence:
return

Expand Down Expand Up @@ -229,9 +259,6 @@ def to_altoxml_text(self, text_line, arabic_helper,
if arabic_helper.is_arabic_line(self.transcription):
arabic_line = True

logits = None
logprobs = None
aligned_letters = None
try:
label = self.get_labels()
blank_idx = self.logits.shape[1] - 1
Expand All @@ -242,16 +269,6 @@ def to_altoxml_text(self, text_line, arabic_helper,
except (ValueError, IndexError, TypeError) as e:
logger.warning(f'Error: Alto export, unable to align line {self.id} due to exception: {e}.')

if logits is not None and logits.shape[0] > 0:
max_val = np.max(logits, axis=1)
logits = logits - max_val[:, np.newaxis]
probs = np.exp(logits)
probs = probs / np.sum(probs, axis=1, keepdims=True)
probs = np.max(probs, axis=1)
self.transcription_confidence = np.quantile(probs, .50)
else:
self.transcription_confidence = 0.0

average_word_width = (text_line_hpos + text_line_width) / len(self.transcription.split())
for w, word in enumerate(self.transcription.split()):
string = ET.SubElement(text_line, "String")
Expand All @@ -274,9 +291,6 @@ def to_altoxml_text(self, text_line, arabic_helper,
splitted_transcription = self.transcription.split()
lm_const = line_coords.shape[1] / logits.shape[0]
letter_counter = 0
confidences = get_line_confidence(self, np.array(label), aligned_letters, logprobs)
# if self.transcription_confidence is None:
self.transcription_confidence = np.quantile(confidences, .50)
for w, word in enumerate(words):
extension = 2
while line_coords.size > 0 and extension < 40:
Expand All @@ -301,9 +315,9 @@ def to_altoxml_text(self, text_line, arabic_helper,
if self.transcription_confidence == 1:
word_confidence = 1
else:
if confidences.size != 0:
word_confidence = np.quantile(
confidences[letter_counter:letter_counter + len(splitted_transcription[w])], .50)
if self.character_confidences.size != 0:
word_character_confidences = self.character_confidences[letter_counter:letter_counter + len(splitted_transcription[w])]
word_confidence = get_word_confidence_from_characters(word_character_confidences)

string = ET.SubElement(text_line, "String")

Expand Down Expand Up @@ -667,13 +681,19 @@ def __init__(self, id: str = None, page_size: List[Tuple[int]] = (0, 0), file: s
self.page_size = page_size # (height, width)
self.regions: List[RegionLayout] = []
self.reading_order = None
self.confidence = None

if file is not None:
self.from_pagexml(file)

if self.reading_order is not None and len(self.regions) > 0:
self.sort_regions_by_reading_order()

def calculate_confidence(self):
    """Aggregate line transcription confidences into ``self.confidence``.

    Only text lines (category None or 'text') with a computed transcription
    confidence contribute to the page-level value.
    """
    line_confidences = []
    for line in self.lines_iterator([None, 'text']):
        if line.transcription_confidence is not None:
            line_confidences.append(line.transcription_confidence)
    self.confidence = get_page_confidence_from_transcription_confidences(line_confidences)

def from_pagexml_string(self, pagexml_string: str):
self.from_pagexml(BytesIO(pagexml_string.encode('utf-8')))

Expand Down Expand Up @@ -973,7 +993,7 @@ def render_to_image(self, image, thickness: int = 2, circles: bool = True,
def lines_iterator(self, categories: list = None):
    """Yield every line of every region, optionally filtered by category.

    When *categories* is given, lines whose category is in it are yielded;
    lines with no category at all are always yielded.
    """
    for region in self.regions:
        for line in region.lines:
            # Skip only when a filter is active, the line is categorized,
            # and its category is not among the requested ones.
            if categories and line.category and line.category not in categories:
                continue
            yield line

def get_quality(self, x: int = None, y: int = None, width: int = None, height: int = None, power: int = 6):
Expand Down Expand Up @@ -1021,8 +1041,6 @@ def get_quality(self, x: int = None, y: int = None, width: int = None, height: i
counter += 1

lm_const = line_coords.shape[1] / logits.shape[0]
confidences = get_line_confidence(line, np.array(label), aligned_letters, logprobs)
line.transcription_confidence = np.quantile(confidences, .50)
for w, word in enumerate(words):
extension = 2
while True:
Expand All @@ -1038,9 +1056,9 @@ def get_quality(self, x: int = None, y: int = None, width: int = None, height: i
hpos = int(np.min(all_x))
if x and y and height and width:
if vpos >= y and vpos <= (y+height) and hpos >= x and hpos <= (x+width):
bbox_confidences.append(confidences[only_letters[w]])
bbox_confidences.append(line.character_confidences[only_letters[w]])
else:
bbox_confidences.append(confidences[only_letters[w]])
bbox_confidences.append(line.character_confidences[only_letters[w]])

if len(bbox_confidences) != 0:
return (1 / len(bbox_confidences) * (np.power(bbox_confidences, power).sum())) ** (1 / power)
Expand Down
39 changes: 7 additions & 32 deletions pero_ocr/document_ocr/page_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pero_ocr.utils import compose_path, config_get_list
from pero_ocr.core.layout import PageLayout, RegionLayout, TextLine
import pero_ocr.core.crop_engine as cropper
from pero_ocr.core.confidence_estimation import get_line_confidence
from pero_ocr.ocr_engine.pytorch_ocr_engine import PytorchEngineLineOCR
from pero_ocr.ocr_engine.transformer_ocr_engine import TransformerEngineLineOCR
from pero_ocr.layout_engines.simple_region_engine import SimpleThresholdRegion
Expand Down Expand Up @@ -552,7 +551,7 @@ def process_page(self, img, page_layout: PageLayout):
logits=line_logits,
characters=self.ocr_engine.characters,
logit_coords=line_logit_coords)
new_line.transcription_confidence = self.get_line_confidence(new_line)
new_line.calculate_confidences(default_transcription_confidence=self.default_confidence)

if not self.update_transcription_by_confidence:
self.update_line(line, new_line)
Expand Down Expand Up @@ -581,17 +580,6 @@ def substitute_transcriptions(self, lines_to_process: List[TextLine]):
for line, transcription_substituted in zip(lines_to_process, transcriptions_substituted):
line.transcription = transcription_substituted

def get_line_confidence(self, line):
if line.transcription:
try:
log_probs = line.get_full_logprobs()[line.logit_coords[0]:line.logit_coords[1]]
confidences = get_line_confidence(line, log_probs=log_probs)
return np.quantile(confidences, .50)
except (ValueError, IndexError) as e:
logger.warning(f'PageOCR is unable to get confidence of line {line.id} due to exception: {e}.')
return self.default_confidence
return self.default_confidence

@property
def provides_ctc_logits(self):
    """True when the configured OCR engine produces CTC logits."""
    # A single isinstance with a tuple of types replaces the redundant
    # `isinstance(...) or isinstance(...)` chain.
    return isinstance(self.ocr_engine, (PytorchEngineLineOCR, TransformerEngineLineOCR))
Expand All @@ -602,6 +590,7 @@ def update_line(line, new_line):
line.logits = new_line.logits
line.characters = new_line.characters
line.logit_coords = new_line.logit_coords
line.character_confidences = new_line.character_confidences
line.transcription_confidence = new_line.transcription_confidence


Expand Down Expand Up @@ -653,31 +642,13 @@ def __init__(self, config, device=None, config_path='', ):
if self.run_decoder:
self.decoder = page_decoder_factory(config, self.device, config_path=config_path)

@staticmethod
def compute_line_confidence(line, threshold=None):
logits = line.get_dense_logits()
log_probs = logits - np.logaddexp.reduce(logits, axis=1)[:, np.newaxis]
best_ids = np.argmax(log_probs, axis=-1)
best_probs = np.exp(np.max(log_probs, axis=-1))
worst_best_prob = get_prob(best_ids, best_probs)
# print(worst_best_prob, np.sum(np.exp(best_probs) < threshold), best_probs.shape, np.nonzero(np.exp(best_probs) < threshold))
# for i in np.nonzero(np.exp(best_probs) < threshold)[0]:
# print(best_probs[i-1:i+2], best_ids[i-1:i+2])

return worst_best_prob

@property
def provides_ctc_logits(self):
    """True when at least one configured OCR engine provides CTC logits."""
    # Guard first: self.ocrs may be empty (or unset/falsy) when OCR is disabled.
    if not self.ocrs:
        return False
    for ocr in self.ocrs.values():
        if ocr.provides_ctc_logits:
            return True
    return False

def update_confidences(self, page_layout):
for line in page_layout.lines_iterator():
if line.logits is not None:
line.transcription_confidence = self.compute_line_confidence(line)

def filter_confident_lines(self, page_layout):
for region in page_layout.regions:
region.lines = [line for line in region.lines if line.transcription_confidence > self.filter_confident_lines_threshold]
Expand All @@ -694,14 +665,18 @@ def process_page(self, image, page_layout):
page_layout = self.line_croppers[key].process_page(image, page_layout)
if self.run_ocr and key in self.ocrs:
page_layout = self.ocrs[key].process_page(image, page_layout)

if self.run_decoder:
page_layout = self.decoder.process_page(page_layout)

self.update_confidences(page_layout)
for line in page_layout.lines_iterator():
line.calculate_confidences()

if self.filter_confident_lines_threshold > 0:
page_layout = self.filter_confident_lines(page_layout)

page_layout.calculate_confidence()

return page_layout

def init_config_sections(self, config, config_path, section_name, section_factory) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions user_scripts/merge_ocr_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys

from pero_ocr.core.layout import PageLayout
from pero_ocr.core.confidence_estimation import get_line_confidence
from pero_ocr.core.confidence_estimation import get_character_confidences


def parse_arguments():
Expand Down Expand Up @@ -39,7 +39,7 @@ def get_confidences(line):
char_map = dict([(c, i) for i, c in enumerate(line.characters)])
c_idx = np.asarray([char_map[c] for c in line.transcription])
try:
confidences = get_line_confidence(line, c_idx)
confidences = get_character_confidences(line, c_idx)
except ValueError:
print('ERROR: Known error in get_line_confidence() - Please, fix it. Logit slice has zero length.')
confidences = np.ones(len(line.transcription)) * 0.5
Expand Down