Skip to content

Commit 9f2f7ea

Browse files
committed
Add inference tests
1 parent d935e2e commit 9f2f7ea

File tree

5 files changed

+71
-0
lines changed

5 files changed

+71
-0
lines changed
Loading
Loading
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"7944775fc9": {
3+
"ground_truth_text": "\\alpha _ { 1 } ^ { r } \\gamma _ { 1 } + \\dots + \\alpha _ { N } ^ { r } \\gamma _ { N } = 0 \\quad ( r = 1 , . . . , R ) \\; ,",
4+
"predicted_text": "\\alpha _ { 1 } ^ { \\gamma } \\gamma _ { 1 } + . . . + \\alpha _ { N } ^ { \\gamma } \\gamma _ { N } = 0 \\quad ( r = 1 , . . , R ) \\, ,",
5+
"character_error_rate": 0.111
6+
},
7+
"566cf0c6f5": {
8+
"ground_truth_text": "\\dot { z } _ { 1 } = - N ^ { z } ( z _ { 1 } ) = - g ( z _ { 1 } ) = - \\frac { z _ { 1 } } { P _ { z } ( z _ { 2 } - z _ { 1 } ) } ; ~ ~ ~ \\dot { z } _ { 2 } = - \\frac { z _ { 2 } } { P _ { z } ( z _ { 2 } - z _ { 1 } ) }",
9+
"predicted_text": "\\dot { z } _ { 1 } = - N ^ { z } ( z _ { 1 } ) = - g ( z _ { 1 } ) = - \\frac { z _ { 1 } } { z _ { 2 } ( z _ { 2 } - z _ { 1 } ) } ; \\quad \\dot { z } _ { 2 } = - \\frac { z _ { 2 } } { \\bar { z } _ { z } ( z _ { 2 } - z _ { 1 } ) }",
10+
"character_error_rate": 0.074
11+
},
12+
"4c0185889d": {
13+
"ground_truth_text": "{ \\cal L } ( J ) = \\frac { 1 } { 2 } \\partial _ { \\mu } \\phi \\partial ^ { \\mu } \\phi + \\frac { J } { 2 } \\phi ^ { 2 } + \\frac { \\lambda \\mu ^ { 2 \\varepsilon } } { 4 ! } \\phi ^ { 4 } + { \\cal L } _ { \\mathrm { C T } } ( J ) - \\mu ^ { - 2 \\varepsilon } \\frac { \\zeta } { 2 } \\; J ^ { 2 } .",
14+
"predicted_text": "{ \\cal L } ( J ) = \\frac { 1 } { 2 } \\partial _ { \\mu } \\phi \\partial ^ { \\mu } \\phi + \\frac { 1 } { 2 } \\phi ^ { 2 } + \\frac { \\lambda \\mu ^ { 2 } } { 4 ! } \\phi ^ { 4 } + { \\cal L } _ { \\mathrm { C T } } ( J ) - \\mu ^ { 2 } \\frac { \\xi } { 2 } \\, J ^ { 2 } .",
15+
"character_error_rate": 0.154
16+
}
17+
}
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Test for im2latex_inference module."""
2+
import json
3+
import os
4+
import time
5+
from pathlib import Path
6+
7+
import editdistance
8+
9+
from im2latex.im2latex_inference import Im2LatexInference
10+
11+
# Set CUDA_VISIBLE_DEVICES to an empty string for this process.
# NOTE(review): presumably this forces inference onto the CPU; it runs after
# the im2latex import above -- confirm that is early enough to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
12+
13+
14+
_FILE_DIRNAME = Path(__file__).parents[0].resolve()
15+
_SUPPORT_DIRNAME = _FILE_DIRNAME / "support" / "im2latex_100k"
16+
17+
# Restrict the number of samples to prevent CircleCI from running out of time.
18+
_NUM_MAX_SAMPLES = 2 if os.environ.get("CIRCLECI", False) else 100
19+
20+
21+
def test_im2latex_inference():
    """Run Im2LatexInference over the support images and check its predictions.

    Every ``*.png`` file in the support directory is looked up by its file
    stem in ``data_by_file_id.json``. The prediction for each image must equal
    the stored ``predicted_text``; the character error rate against
    ``ground_truth_text`` and the timings are only printed.
    At most ``_NUM_MAX_SAMPLES`` images are processed.
    """
    # Sort the matches: glob order is not specified, so without sorting the
    # capped subset could differ from one run (or machine) to the next.
    support_filenames = sorted(_SUPPORT_DIRNAME.glob("*.png"))
    # Fail loudly instead of passing without having checked anything.
    assert support_filenames, f"no support images found in {_SUPPORT_DIRNAME}"
    with open(_SUPPORT_DIRNAME / "data_by_file_id.json", "r", encoding="utf-8") as f:
        support_data_by_file_id = json.load(f)

    # perf_counter is monotonic, so measured durations cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    reasoner = Im2LatexInference()
    end_time = time.perf_counter()
    print(f"Time taken to initialize Im2LatexInference: {round(end_time - start_time, 2)}s")

    for support_filename in support_filenames[:_NUM_MAX_SAMPLES]:
        support_data = support_data_by_file_id[support_filename.stem]
        expected_text = support_data["predicted_text"]
        start_time = time.perf_counter()
        predicted_text = _test_im2latex_inference(support_filename, expected_text, reasoner)
        end_time = time.perf_counter()
        time_taken = round(end_time - start_time, 2)

        cer = _character_error_rate(support_data["ground_truth_text"], predicted_text)
        print(f"Character error rate is {round(cer, 3)} for file {support_filename.name} (time taken: {time_taken}s)")
43+
44+
45+
def _test_im2latex_inference(image_filename: Path, expected_text: str, reasoner: Im2LatexInference):
    """Predict the text for one image and assert it equals ``expected_text``.

    Returns the predicted text so the caller can compute further metrics.
    """
    actual_text = reasoner.predict(image_filename)
    assert (
        actual_text == expected_text
    ), f"predicted text does not match expected for {image_filename.name}"
    return actual_text
50+
51+
52+
def _character_error_rate(str_a: str, str_b: str) -> float:
53+
"""Return character error rate."""
54+
return editdistance.eval(str_a, str_b) / max(len(str_a), len(str_b))

0 commit comments

Comments
 (0)