Skip to content

Commit

Permalink
Fix type annotation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
EhsanKia committed Nov 11, 2021
1 parent 356b1ab commit 6099755
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 16 deletions.
7 changes: 5 additions & 2 deletions catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import numpy
import pytesseract
import random
import typing
import unicodedata

from PIL import Image
from typing import Dict, Iterator, List, Set
from typing import Dict, Iterator, List, Set, Tuple

# The expected color for the video background.
TOP_COLOR = (110, 233, 238)
Expand Down Expand Up @@ -111,6 +112,7 @@ def run_ocr(item_rows: List[numpy.ndarray], lang: str = 'eng') -> Set[str]:
parsed_text = pytesseract.image_to_string(
Image.fromarray(cv2.vconcat(item_rows)),
lang=lang, config=_get_tesseract_config(lang))
assert isinstance(parsed_text, str), 'Tesseract returned bytes'

# Split the results and remove empty lines.
clean_names = {_cleanup_name(item, lang)
Expand Down Expand Up @@ -294,7 +296,8 @@ def _detect_locale(item_rows: List[numpy.ndarray], locale: str) -> str:
image = Image.fromarray(cv2.vconcat(item_rows))

try:
osd_data = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT)
osd_data = typing.cast(Dict[str, str], pytesseract.image_to_osd(
image, output_type=pytesseract.Output.DICT))
except pytesseract.TesseractError:
return 'en-us'

Expand Down
10 changes: 5 additions & 5 deletions critters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def scan(video_file: str, locale: str = 'en-us') -> ScanResult:
)


def parse_video(filename: str) -> List[CritterImage]:
def parse_video(filename: str) -> List[CritterIcon]:
"""Parses a whole video and returns icons for all critters found."""
all_icons: List[CritterImage] = []
all_icons: List[CritterIcon] = []
section_count: Dict[CritterType, int] = collections.defaultdict(int)
for critter_type, frame in _read_frames(filename):
section_count[critter_type] += 1
Expand All @@ -83,7 +83,7 @@ def parse_video(filename: str) -> List[CritterImage]:
return _remove_blanks(all_icons)


def match_critters(critter_icons: List[CritterImage]) -> List[str]:
def match_critters(critter_icons: List[CritterIcon]) -> List[str]:
"""Matches icons against database of critter images, finding best matches."""
matched_critters = set()
critter_db = _get_critter_db()
Expand Down Expand Up @@ -209,7 +209,7 @@ def _parse_frame(frame: numpy.ndarray) -> Iterator[numpy.ndarray]:
yield frame[y+8:y+88, x+16:x+96]


def _remove_blanks(all_icons: List[numpy.ndarray]) -> List[numpy.ndarray]:
def _remove_blanks(all_icons: List[CritterIcon]) -> List[CritterIcon]:
"""Remove all icons that show empty critter boxes."""
filtered_icons = []
for icon in all_icons:
Expand Down Expand Up @@ -256,5 +256,5 @@ def slow_similarity_metric(critter):


if __name__ == "__main__":
results = scan('examples/extra/critters_badpage.mp4')
results = scan('examples/critters.mp4')
print('\n'.join(results.items))
2 changes: 1 addition & 1 deletion music.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _parse_frame(frame: numpy.ndarray) -> Iterator[List[numpy.ndarray]]:
# then it averages the frame across the Y-axis to find the area rows.
# Lastly, it finds the y-positions marking the start/end of each row.
thresh = cv2.inRange(frame[:410], bg_color - 30, bg_color + 30)
separators = numpy.diff(thresh.mean(axis=1) > 100).nonzero()[0]
separators = numpy.nonzero(numpy.diff(thresh.mean(axis=1) > 100))[0]
if len(separators) < 2:
return

Expand Down
8 changes: 4 additions & 4 deletions reactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def scan(image_file: str, locale: str = 'en-us') -> ScanResult:
)


def parse_image(filename: str) -> List[ReactionImage]:
def parse_image(filename: str) -> List[numpy.ndarray]:
"""Parses a screenshot and returns icons for all reactions found."""
icon_pages: Dict[int, List[ReactionImage]] = {}
icon_pages: Dict[int, List[numpy.ndarray]] = {}
assertion_error: Optional[AssertionError] = None

cap = cv2.VideoCapture(filename)
Expand All @@ -84,10 +84,10 @@ def parse_image(filename: str) -> List[ReactionImage]:
if assertion_error and (filename.endswith('.jpg') or not icon_pages):
raise assertion_error

return itertools.chain.from_iterable(icon_pages.values())
return [icon for page in icon_pages.values() for icon in page]


def match_reactions(reaction_icons: List[ReactionImage]) -> List[str]:
def match_reactions(reaction_icons: List[numpy.ndarray]) -> List[str]:
"""Matches icons against database of reactions images, finding best matches."""
matched_reactions = set()
reaction_db = _get_reaction_db()
Expand Down
8 changes: 4 additions & 4 deletions recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _parse_frame(frame: numpy.ndarray) -> Iterable[List[numpy.ndarray]]:
# then it averages the frame across the Y-axis to find the area rows.
# Lastly, it finds the y-positions marking the start/end of each row.
thresh = cv2.inRange(frame, (185, 215, 218), (210, 230, 237))
separators = numpy.diff(thresh.mean(axis=1) > 195).nonzero()[0]
separators = numpy.nonzero(numpy.diff(thresh.mean(axis=1) > 195))[0]

# We do a first pass finding all sensible y positions.
y_positions = []
Expand Down Expand Up @@ -197,8 +197,8 @@ def _get_color_db() -> Dict[int, Tuple[int, int, int]]:
"""Fetches the item database for a given locale, with caching."""
with open(os.path.join('recipes', 'colors.json')) as fp:
colors_data = json.load(fp)
return {int(color_id): tuple(reversed(rgb))
for color_id, rgb in colors_data.items()}
return {int(color_id): (b, g, r)
for color_id, (r, g, b) in colors_data.items()}


def _get_candidate_recipes(card: numpy.ndarray) -> Iterable[RecipeCard]:
Expand Down Expand Up @@ -245,5 +245,5 @@ def slow_similarity_metric(recipe):


if __name__ == "__main__":
results = scan('examples/extra/recipes_img.jpg')
results = scan('examples/recipes.mp4')
print('\n'.join(results.items))

0 comments on commit 6099755

Please sign in to comment.