1
1
"""IM2LATEX100K DataModule"""
2
- from itertools import compress
2
+ import argparse
3
+ import json
4
+ import pickle
5
+ import shutil
6
+ import tarfile
3
7
from collections import Counter , OrderedDict
4
8
from concurrent .futures import ThreadPoolExecutor
9
+ from itertools import compress
5
10
from pathlib import Path
6
- import pickle
7
11
from random import shuffle
8
- from typing import Callable , List , Sequence , Tuple , Union
9
- import json
10
- import argparse
11
- import shutil
12
- import tarfile
12
+ from typing import Callable , List , MutableMapping , Sequence , Union
13
13
14
14
import numpy as np
15
+ import toml
15
16
from PIL import Image
16
- from torch .utils .data import Sampler
17
- from torch .utils .data import DataLoader
17
+ from torch .utils .data import DataLoader , Sampler
18
18
from torchvision import transforms
19
- import toml
20
19
21
- from im2latex .data .base_data_module import _download_raw_dataset , BaseDataModule , load_and_print_info
20
+ from im2latex .data .base_data_module import BaseDataModule , _download_raw_dataset , load_and_print_info
22
21
from im2latex .data .util import BaseDataset , SequenceOrTensor , convert_strings_to_labels
23
22
24
23
IMAGE_HEIGHT = None
25
24
IMAGE_WIDTH = None
26
25
MAX_LABEL_LENGTH = 150
27
26
MIN_COUNT = 10
27
+ IMAGE_HIGHT_WIDTH_GROUP = [
28
+ (32 , 128 ),
29
+ (64 , 128 ),
30
+ (32 , 160 ),
31
+ (64 , 160 ),
32
+ (32 , 192 ),
33
+ (64 , 192 ),
34
+ (32 , 224 ),
35
+ (64 , 224 ),
36
+ (32 , 256 ),
37
+ (64 , 256 ),
38
+ (32 , 320 ),
39
+ (64 , 320 ),
40
+ (32 , 384 ),
41
+ (64 , 384 ),
42
+ (96 , 384 ),
43
+ (32 , 480 ),
44
+ (64 , 480 ),
45
+ (128 , 480 ),
46
+ (160 , 480 ),
47
+ ]
48
+
28
49
29
50
RAW_DATA_DIRNAME = BaseDataModule .data_dirname () / "raw" / "im2latex_100k"
30
51
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
36
57
# 1. implement __repr__
37
58
# 2. add parameters to select normalization or raw latex
38
59
# 3. rename function
60
+
61
+
39
62
class Im2Latex100K (BaseDataModule ):
40
63
"""
41
64
Im2Latex100K DataModule.
@@ -56,7 +79,7 @@ def __init__(self, args: argparse.Namespace = None) -> None:
56
79
self .mapping = list (vocab_dict ["vocab" ]) # label to string
57
80
self .inverse_mapping = {v : k for k , v in enumerate (self .mapping )} # string to label
58
81
59
- self .dims = (1 , IMAGE_HEIGHT , IMAGE_WIDTH )
82
+ self .dims = list ( map ( lambda x : ( 1 , * x ), IMAGE_HIGHT_WIDTH_GROUP )) # (1, IMAGE_HEIGHT, IMAGE_WIDTH)
60
83
assert self .max_label_length <= MAX_LABEL_LENGTH
61
84
self .output_dims = (MAX_LABEL_LENGTH , 1 )
62
85
@@ -88,7 +111,11 @@ def setup(self, stage=None) -> None:
88
111
89
112
def _load_dataset (data_dict , split : str , augment : bool , max_label_length : int ) -> BaseDataset :
90
113
# https://stackoverflow.com/questions/13397385/python-filter-and-list-and-apply-filtered-indices-to-another-list
91
- selectors = list (map (lambda y : not len (y ) > max_label_length , data_dict [f"y_{ split } " ]))
114
+ selectors_length = list (map (lambda y : not len (y ) > max_label_length , data_dict [f"y_{ split } " ]))
115
+ selectors_shape = list (
116
+ map (lambda hight_width : hight_width in IMAGE_HIGHT_WIDTH_GROUP , data_dict [f"x_shape_{ split } " ])
117
+ )
118
+ selectors = np .logical_and (selectors_length , selectors_shape )
92
119
x = list (compress (data_dict [f"x_{ split } " ], selectors ))
93
120
y = list (compress (data_dict [f"y_{ split } " ], selectors ))
94
121
x_shape = list (compress (data_dict [f"x_shape_{ split } " ], selectors ))
@@ -113,7 +140,7 @@ def train_dataloader(self):
113
140
batch_sampler = BucketBatchSampler (
114
141
list ((i , data_shape ) for i , data_shape in enumerate (self .data_train .data_shape )),
115
142
self .batch_size ,
116
- shuffle = True ,
143
+ do_shuffle = True ,
117
144
),
118
145
num_workers = self .num_workers ,
119
146
pin_memory = self .on_gpu ,
@@ -126,7 +153,7 @@ def val_dataloader(self):
126
153
batch_sampler = BucketBatchSampler (
127
154
list ((i , data_shape ) for i , data_shape in enumerate (self .data_val .data_shape )),
128
155
self .batch_size ,
129
- shuffle = False ,
156
+ do_shuffle = False ,
130
157
),
131
158
num_workers = self .num_workers ,
132
159
pin_memory = self .on_gpu ,
@@ -139,7 +166,7 @@ def test_dataloader(self):
139
166
batch_sampler = BucketBatchSampler (
140
167
list ((i , data_shape ) for i , data_shape in enumerate (self .data_test .data_shape )),
141
168
self .batch_size ,
142
- shuffle = False ,
169
+ do_shuffle = False ,
143
170
),
144
171
num_workers = self .num_workers ,
145
172
pin_memory = self .on_gpu ,
@@ -154,7 +181,7 @@ def _download_and_process_im2latex(vocab_filename, min_count: int = 10):
154
181
_process_raw_dataset (metadata , vocab_filename , min_count = min_count )
155
182
156
183
157
- def _process_raw_dataset (metadata : dict , vocab_filename : Union [Path , str ], min_count : int = 10 ):
184
+ def _process_raw_dataset (metadata : MutableMapping , vocab_filename : Union [Path , str ], min_count : int = 10 ):
158
185
# unzip tar file
159
186
img_tarfile = DL_DATA_DIRNAME / metadata ["formula_images_processed" ]["filename" ]
160
187
if not (PROCESSED_DATA_DIRNAME / "formula_images_processed" ).is_dir ():
@@ -176,7 +203,7 @@ def _process_raw_dataset(metadata: dict, vocab_filename: Union[Path, str], min_c
176
203
"Save `array (from image)`, `latex`, `image size` and `image file name` to the dictionary "
177
204
+ "with the corresponding keys `x_{split}`, `y_{split}`, `x_shape_{split}` and `img_filename_{split}`..."
178
205
)
179
- data_dict = {
206
+ data_dict : dict = {
180
207
"x_train" : [],
181
208
"x_validate" : [],
182
209
"x_test" : [],
@@ -284,10 +311,10 @@ def __init__(
284
311
# https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284/13
285
312
class BucketBatchSampler (Sampler ):
286
313
# want inputs to be an array
287
- def __init__ (self , idx_imgsize , batch_size , shuffle = False ):
314
+ def __init__ (self , idx_imgsize , batch_size , do_shuffle = False ):
288
315
self .idx_imgsize = idx_imgsize
289
316
self .batch_size = batch_size
290
- self .shuffle = shuffle
317
+ self .shuffle = do_shuffle
291
318
self .batch_list = self ._generate_batch_map ()
292
319
self .num_batches = len (self .batch_list )
293
320
@@ -325,7 +352,7 @@ def __iter__(self):
325
352
def build_vocab (min_count : int = 10 ) -> Sequence [str ]:
326
353
"""Add the mapping with special symbols."""
327
354
# listdir = Path(listdir)
328
- counter = Counter ()
355
+ counter : Counter = Counter ()
329
356
vocab = []
330
357
331
358
formulas = get_all_formulas (split_it = True )
0 commit comments