Commit 2768bd4

Filter out large images
1 parent a69433f commit 2768bd4

File tree

1 file changed: +48 -21 lines changed

im2latex/data/im2latex_100k.py

@@ -1,30 +1,51 @@
 """IM2LATEX100K DataModule"""
-from itertools import compress
+import argparse
+import json
+import pickle
+import shutil
+import tarfile
 from collections import Counter, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
+from itertools import compress
 from pathlib import Path
-import pickle
 from random import shuffle
-from typing import Callable, List, Sequence, Tuple, Union
-import json
-import argparse
-import shutil
-import tarfile
+from typing import Callable, List, MutableMapping, Sequence, Union
 
 import numpy as np
+import toml
 from PIL import Image
-from torch.utils.data import Sampler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Sampler
 from torchvision import transforms
-import toml
 
-from im2latex.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
+from im2latex.data.base_data_module import BaseDataModule, _download_raw_dataset, load_and_print_info
 from im2latex.data.util import BaseDataset, SequenceOrTensor, convert_strings_to_labels
 
 IMAGE_HEIGHT = None
 IMAGE_WIDTH = None
 MAX_LABEL_LENGTH = 150
 MIN_COUNT = 10
+IMAGE_HIGHT_WIDTH_GROUP = [
+    (32, 128),
+    (64, 128),
+    (32, 160),
+    (64, 160),
+    (32, 192),
+    (64, 192),
+    (32, 224),
+    (64, 224),
+    (32, 256),
+    (64, 256),
+    (32, 320),
+    (64, 320),
+    (32, 384),
+    (64, 384),
+    (96, 384),
+    (32, 480),
+    (64, 480),
+    (128, 480),
+    (160, 480),
+]
+
 
 RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "im2latex_100k"
 METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
@@ -36,6 +57,8 @@
 # 1. implement __repr__
 # 2. add parameters to select normalization or raw latex
 # 3. rename function
+
+
 class Im2Latex100K(BaseDataModule):
     """
     Im2Latex100K DataModule.
@@ -56,7 +79,7 @@ def __init__(self, args: argparse.Namespace = None) -> None:
         self.mapping = list(vocab_dict["vocab"])  # label to string
         self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}  # string to label
 
-        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+        self.dims = list(map(lambda x: (1, *x), IMAGE_HIGHT_WIDTH_GROUP))  # (1, IMAGE_HEIGHT, IMAGE_WIDTH)
         assert self.max_label_length <= MAX_LABEL_LENGTH
         self.output_dims = (MAX_LABEL_LENGTH, 1)
 
@@ -88,7 +111,11 @@ def setup(self, stage=None) -> None:
 
         def _load_dataset(data_dict, split: str, augment: bool, max_label_length: int) -> BaseDataset:
             # https://stackoverflow.com/questions/13397385/python-filter-and-list-and-apply-filtered-indices-to-another-list
-            selectors = list(map(lambda y: not len(y) > max_label_length, data_dict[f"y_{split}"]))
+            selectors_length = list(map(lambda y: not len(y) > max_label_length, data_dict[f"y_{split}"]))
+            selectors_shape = list(
+                map(lambda hight_width: hight_width in IMAGE_HIGHT_WIDTH_GROUP, data_dict[f"x_shape_{split}"])
+            )
+            selectors = np.logical_and(selectors_length, selectors_shape)
             x = list(compress(data_dict[f"x_{split}"], selectors))
             y = list(compress(data_dict[f"y_{split}"], selectors))
             x_shape = list(compress(data_dict[f"x_shape_{split}"], selectors))
@@ -113,7 +140,7 @@ def train_dataloader(self):
             batch_sampler=BucketBatchSampler(
                 list((i, data_shape) for i, data_shape in enumerate(self.data_train.data_shape)),
                 self.batch_size,
-                shuffle=True,
+                do_shuffle=True,
             ),
             num_workers=self.num_workers,
             pin_memory=self.on_gpu,
@@ -126,7 +153,7 @@ def val_dataloader(self):
             batch_sampler=BucketBatchSampler(
                 list((i, data_shape) for i, data_shape in enumerate(self.data_val.data_shape)),
                 self.batch_size,
-                shuffle=False,
+                do_shuffle=False,
             ),
             num_workers=self.num_workers,
             pin_memory=self.on_gpu,
@@ -139,7 +166,7 @@ def test_dataloader(self):
             batch_sampler=BucketBatchSampler(
                 list((i, data_shape) for i, data_shape in enumerate(self.data_test.data_shape)),
                 self.batch_size,
-                shuffle=False,
+                do_shuffle=False,
             ),
             num_workers=self.num_workers,
             pin_memory=self.on_gpu,
@@ -154,7 +181,7 @@ def _download_and_process_im2latex(vocab_filename, min_count: int = 10):
     _process_raw_dataset(metadata, vocab_filename, min_count=min_count)
 
 
-def _process_raw_dataset(metadata: dict, vocab_filename: Union[Path, str], min_count: int = 10):
+def _process_raw_dataset(metadata: MutableMapping, vocab_filename: Union[Path, str], min_count: int = 10):
     # unzip tar file
     img_tarfile = DL_DATA_DIRNAME / metadata["formula_images_processed"]["filename"]
     if not (PROCESSED_DATA_DIRNAME / "formula_images_processed").is_dir():
@@ -176,7 +203,7 @@ def _process_raw_dataset(metadata: dict, vocab_filename: Union[Path, str], min_c
         "Save `array (from image)`, `latex`, `image size` and `image file name` to the dictionary "
         + "with the corresponding keys `x_{split}`, `y_{split}`, `x_shape_{split}` and `img_filename_{split}`..."
     )
-    data_dict = {
+    data_dict: dict = {
         "x_train": [],
         "x_validate": [],
         "x_test": [],
@@ -284,10 +311,10 @@ def __init__(
 # https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284/13
 class BucketBatchSampler(Sampler):
     # want inputs to be an array
-    def __init__(self, idx_imgsize, batch_size, shuffle=False):
+    def __init__(self, idx_imgsize, batch_size, do_shuffle=False):
         self.idx_imgsize = idx_imgsize
         self.batch_size = batch_size
-        self.shuffle = shuffle
+        self.shuffle = do_shuffle
         self.batch_list = self._generate_batch_map()
         self.num_batches = len(self.batch_list)
 
@@ -325,7 +352,7 @@ def __iter__(self):
 def build_vocab(min_count: int = 10) -> Sequence[str]:
     """Add the mapping with special symbols."""
     # listdir = Path(listdir)
-    counter = Counter()
+    counter: Counter = Counter()
     vocab = []
 
     formulas = get_all_formulas(split_it=True)
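
Notes (illustration only, not part of the commit)

A minimal standalone sketch of the filtering logic this commit adds to _load_dataset. The toy lists (labels, shapes, images) and the two-entry ALLOWED_SHAPES stand-in for IMAGE_HIGHT_WIDTH_GROUP are made up for the example:

from itertools import compress

import numpy as np

ALLOWED_SHAPES = [(32, 128), (64, 128)]  # hypothetical subset of IMAGE_HIGHT_WIDTH_GROUP
MAX_LABEL_LENGTH = 150

labels = [["x", "+", "y"], ["a"] * 200, ["z"]]
shapes = [(32, 128), (64, 128), (512, 512)]  # last image is too large
images = ["img0", "img1", "img2"]

# Keep a sample only if its label is short enough AND its (height, width)
# falls in an allowed bucket -- this is how large images get filtered out.
selectors_length = [len(y) <= MAX_LABEL_LENGTH for y in labels]
selectors_shape = [hw in ALLOWED_SHAPES for hw in shapes]
selectors = np.logical_and(selectors_length, selectors_shape)

print(list(compress(images, selectors)))  # ['img0']: img1 label too long, img2 too large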
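
The renamed do_shuffle flag is consumed by BucketBatchSampler, which batches samples of identical image size together so a batch never mixes shapes. A rough sketch of that bucketing idea; bucket_batches is a hypothetical stand-in, not the real class (whose _generate_batch_map lies outside this diff):

from collections import defaultdict
from random import shuffle


def bucket_batches(idx_imgsize, batch_size, do_shuffle=False):
    # Group sample indices by their (height, width) image size.
    buckets = defaultdict(list)
    for idx, size in idx_imgsize:
        buckets[size].append(idx)
    # Chunk each same-size bucket into batches of batch_size.
    batches = []
    for indices in buckets.values():
        for start in range(0, len(indices), batch_size):
            batches.append(indices[start : start + batch_size])
    if do_shuffle:
        shuffle(batches)  # shuffle the batch order; each batch stays same-shaped
    return batches


idx_imgsize = [(0, (32, 128)), (1, (64, 128)), (2, (32, 128)), (3, (64, 128))]
print(bucket_batches(idx_imgsize, batch_size=2))  # [[0, 2], [1, 3]]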
