|
| 1 | +"""Run validation test for paragraph_text_recognizer module.""" |
| 2 | +import argparse |
| 3 | +import time |
| 4 | +import unittest |
| 5 | + |
| 6 | +import pytorch_lightning as pl |
| 7 | +import torch |
| 8 | + |
| 9 | +from im2latex.data import Im2Latex100K |
| 10 | +from im2latex.im2latex_inference import Im2LatexInference |
| 11 | + |
| 12 | +_TEST_BLEU_SCORE = 0.5951 |
| 13 | +_TEST_CHARACTER_ERROR_RATE = 0.2506 |
| 14 | +_TEST_EDIT_DISTANCE = 0.6595 |
| 15 | + |
| 16 | + |
| 17 | +class TestEvaluateIm2LatexInference(unittest.TestCase): |
| 18 | + """Evaluate Im2LatexInference on the Im2Latex100K test dataset.""" |
| 19 | + |
| 20 | + @torch.no_grad() |
| 21 | + def test_evaluate(self): |
| 22 | + dataset = Im2Latex100K(argparse.Namespace(batch_size=16, num_workers=2)) |
| 23 | + dataset.prepare_data() |
| 24 | + dataset.setup() |
| 25 | + |
| 26 | + inference = Im2LatexInference() |
| 27 | + trainer = pl.Trainer(gpus=1) |
| 28 | + |
| 29 | + start_time = time.time() |
| 30 | + metrics = trainer.test(inference.lit_model, datamodule=dataset) |
| 31 | + end_time = time.time() |
| 32 | + |
| 33 | + test_bleu = round(metrics[0]["test_bleu"], 4) |
| 34 | + test_cer = round(metrics[0]["test_cer"], 4) |
| 35 | + test_edit = round(metrics[0]["test_edit"], 4) |
| 36 | + time_taken = round((end_time - start_time) / 60, 2) |
| 37 | + |
| 38 | + print(f"Character error rate: {test_cer}, time_taken: {time_taken} m") |
| 39 | + self.assertGreaterEqual(test_bleu, _TEST_BLEU_SCORE) |
| 40 | + self.assertLessEqual(test_cer, _TEST_CHARACTER_ERROR_RATE) |
| 41 | + self.assertGreaterEqual(test_edit, _TEST_EDIT_DISTANCE) |
| 42 | + self.assertLess(time_taken, 20) |
0 commit comments