Skip to content

Commit 8796f2e

Browse files
committed
Add evaluation tests
1 parent 0636052 commit 8796f2e

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

Diff for: im2latex/evaluation/evaluate_im2latex_inference.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Run validation test for paragraph_text_recognizer module."""
2+
import argparse
3+
import time
4+
import unittest
5+
6+
import pytorch_lightning as pl
7+
import torch
8+
9+
from im2latex.data import Im2Latex100K
10+
from im2latex.im2latex_inference import Im2LatexInference
11+
12+
_TEST_BLEU_SCORE = 0.5951
13+
_TEST_CHARACTER_ERROR_RATE = 0.2506
14+
_TEST_EDIT_DISTANCE = 0.6595
15+
16+
17+
class TestEvaluateIm2LatexInference(unittest.TestCase):
18+
"""Evaluate Im2LatexInference on the Im2Latex100K test dataset."""
19+
20+
@torch.no_grad()
21+
def test_evaluate(self):
22+
dataset = Im2Latex100K(argparse.Namespace(batch_size=16, num_workers=2))
23+
dataset.prepare_data()
24+
dataset.setup()
25+
26+
inference = Im2LatexInference()
27+
trainer = pl.Trainer(gpus=1)
28+
29+
start_time = time.time()
30+
metrics = trainer.test(inference.lit_model, datamodule=dataset)
31+
end_time = time.time()
32+
33+
test_bleu = round(metrics[0]["test_bleu"], 4)
34+
test_cer = round(metrics[0]["test_cer"], 4)
35+
test_edit = round(metrics[0]["test_edit"], 4)
36+
time_taken = round((end_time - start_time) / 60, 2)
37+
38+
print(f"Character error rate: {test_cer}, time_taken: {time_taken} m")
39+
self.assertGreaterEqual(test_bleu, _TEST_BLEU_SCORE)
40+
self.assertLessEqual(test_cer, _TEST_CHARACTER_ERROR_RATE)
41+
self.assertGreaterEqual(test_edit, _TEST_EDIT_DISTANCE)
42+
self.assertLess(time_taken, 20)

0 commit comments

Comments
 (0)