|
| 1 | +"""Test for im2latex_inference module.""" |
| 2 | +import json |
| 3 | +import os |
| 4 | +import time |
| 5 | +from pathlib import Path |
| 6 | + |
| 7 | +import editdistance |
| 8 | + |
| 9 | +from im2latex.im2latex_inference import Im2LatexInference |
| 10 | + |
# Empty CUDA_VISIBLE_DEVICES: presumably hides every GPU so the test runs on CPU only (confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = ""


# Directory holding this file, and the im2latex_100k support data beneath it.
_FILE_DIRNAME = Path(__file__).parent.resolve()
_SUPPORT_DIRNAME = _FILE_DIRNAME.joinpath("support", "im2latex_100k")

# Restrict the number of samples to prevent CircleCI from running out of time.
_NUM_MAX_SAMPLES = 2 if os.environ.get("CIRCLECI") else 100
| 19 | + |
| 20 | + |
def test_im2latex_inference():
    """Test Im2LatexInference on the support images.

    Loads the stored reference data from ``data_by_file_id.json``, builds one
    ``Im2LatexInference`` instance and, for at most ``_NUM_MAX_SAMPLES`` PNG files
    in the support directory, asserts that the prediction equals the stored
    ``predicted_text``. The initialization time, the per-image prediction time and
    the character error rate against ``ground_truth_text`` are printed for
    information only; they are not asserted on.
    """
    # glob() gives no ordering guarantee; sort so that the subset selected by
    # _NUM_MAX_SAMPLES is the same on every run and every platform.
    support_filenames = sorted(_SUPPORT_DIRNAME.glob("*.png"))
    with open(_SUPPORT_DIRNAME / "data_by_file_id.json", "r", encoding="utf-8") as f:
        support_data_by_file_id = json.load(f)
    # Without this guard the test would pass without checking a single image.
    assert support_filenames, f"no support images found in {_SUPPORT_DIRNAME}"

    # perf_counter() is monotonic, so a system clock adjustment cannot skew the timings.
    start_time = time.perf_counter()
    reasoner = Im2LatexInference()
    end_time = time.perf_counter()
    print(f"Time taken to initialize Im2LatexInference: {round(end_time - start_time, 2)}s")

    for support_filename in support_filenames[:_NUM_MAX_SAMPLES]:
        support_data = support_data_by_file_id[support_filename.stem]
        start_time = time.perf_counter()
        predicted_text = _test_im2latex_inference(support_filename, support_data["predicted_text"], reasoner)
        end_time = time.perf_counter()
        time_taken = round(end_time - start_time, 2)

        cer = _character_error_rate(support_data["ground_truth_text"], predicted_text)
        print(f"Character error rate is {round(cer, 3)} for file {support_filename.name} (time taken: {time_taken}s)")
| 43 | + |
| 44 | + |
def _test_im2latex_inference(image_filename: Path, expected_text: str, reasoner: Im2LatexInference):
    """Run the reasoner on a single image and check its output.

    Calls ``reasoner.predict`` on ``image_filename`` and asserts that the result
    equals ``expected_text``. The prediction is returned so the caller can reuse it.
    """
    actual_text = reasoner.predict(image_filename)
    assert actual_text == expected_text, f"predicted text does not match expected for {image_filename.name}"
    return actual_text
| 50 | + |
| 51 | + |
def _character_error_rate(str_a: str, str_b: str) -> float:
    """Return the character error rate between two strings.

    The rate is the edit distance between ``str_a`` and ``str_b``, as computed by
    ``editdistance.eval``, divided by the length of the longer string.

    Args:
        str_a: First string to compare.
        str_b: Second string to compare.

    Returns:
        The edit distance normalized by the longer length, or ``0.0`` when both
        strings are empty.
    """
    longest = max(len(str_a), len(str_b))
    if longest == 0:
        # Two empty strings are identical; return early instead of dividing by zero.
        return 0.0
    return editdistance.eval(str_a, str_b) / longest
0 commit comments