Skip to content

Commit d299c50

Browse files
nektor211emedvedev
authored andcommitted
Python 3 fixes (#12)
* aocr.util.data_get fixed python3 StringIO import * util.dataset.generate: fixed line parsing * .gitignore: added *~ * util.dataset.generate: use enumerate * util.dataset.generate: fix bytes handling * model.model: use xrange, pass range as list * utils.data_gen: python3 IO handling fixed
1 parent 2a48d82 commit d299c50

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,5 @@ misc/
116116
data/evaluation_data
117117
.DS_Store
118118
.venv
119+
120+
*~

aocr/model/model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def __init__(self,
172172
)
173173

174174
insert = table.insert(
175-
tf.constant(range(len(DataGen.CHARMAP)), dtype=tf.int64),
175+
tf.constant(list(range(len(DataGen.CHARMAP))), dtype=tf.int64),
176176
tf.constant(DataGen.CHARMAP),
177177
)
178178

@@ -425,17 +425,17 @@ def visualize_attention(self, filename, attentions, output, label, flag_incorrec
425425
(mw, h),
426426
Image.ANTIALIAS)
427427
img_data = np.asarray(img, dtype=np.uint8)
428-
for idx in range(len(output)):
428+
for idx in xrange(len(output)):
429429
output_filename = os.path.join(output_dir, 'image_%d.jpg' % (idx))
430430
attention = attentions[idx][:(int(mw/4)-1)]
431431
attention_orig = np.zeros(mw)
432-
for i in range(mw):
432+
for i in xrange(mw):
433433
if i/4-1 > 0 and i/4-1 < len(attention):
434434
attention_orig[i] = attention[int(i/4)-1]
435435
attention_orig = np.convolve(attention_orig, [0.199547, 0.200226, 0.200454, 0.200226, 0.199547], mode='same')
436436
attention_orig = np.maximum(attention_orig, 0.3)
437437
attention_out = np.zeros((h, mw))
438-
for i in range(mw):
438+
for i in xrange(mw):
439439
attention_out[:, i] = attention_orig[i]
440440
if len(img_data.shape) == 3:
441441
attention_out = attention_out[:, :, np.newaxis]

aocr/util/data_gen.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
from .bucketdata import BucketData
55
from PIL import Image
6-
from StringIO import StringIO
7-
6+
try:
7+
from StringIO import StringIO as IO
8+
except ImportError:
9+
from io import BytesIO as IO # to handle py2 vs 3
810

911
class DataGen(object):
1012
GO_ID = 1
@@ -54,7 +56,7 @@ def gen(self, batch_size):
5456
raw_images, raw_labels = sess.run([images, labels])
5557
for img, lex in zip(raw_images, raw_labels):
5658

57-
if self.max_width and (Image.open(StringIO(img)).size[0] <= self.max_width):
59+
if self.max_width and (Image.open(IO(img)).size[0] <= self.max_width):
5860

5961
word = self.convert_lex(lex)
6062

@@ -71,6 +73,8 @@ def gen(self, batch_size):
7173
self.clear()
7274

7375
def convert_lex(self, lex):
76+
if isinstance(lex, bytes):
77+
lex = lex.decode()
7478
assert lex and len(lex) < self.bucket_specs[-1][1]
7579

7680
return np.array(

aocr/util/dataset.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import tensorflow as tf
22
import logging
33

4+
import sys
5+
6+
if sys.version_info[0] < 3:
7+
text_type = unicode
8+
else:
9+
text_type = str
410

511
def _bytes_feature(value):
612
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
@@ -17,21 +23,21 @@ def generate(annotations_path, output_path, log_step=5000):
1723
writer = tf.python_io.TFRecordWriter(output_path)
1824
count = 0
1925

20-
with open(annotations_path, 'r') as file:
21-
for (img_path, label) in file.readlines():
22-
idx += 1
26+
with open(annotations_path, 'r') as f:
27+
for idx, line in enumerate(f):
28+
(img_path, label) = line.rstrip('\n').split('\t', 1)
2329
with open(img_path, 'rb') as img_file:
2430
img = img_file.read()
2531

2632
example = tf.train.Example(features=tf.train.Features(feature={
2733
'image': _bytes_feature(img),
28-
'label': _bytes_feature(label)}))
34+
'label': _bytes_feature(text_type.encode(label))}))
2935

3036
writer.write(example.SerializeToString())
3137

3238
if idx % log_step == 0:
33-
logging.info('Processed %s pairs.', idx)
39+
logging.info('Processed %s pairs.', idx+1)
3440

35-
logging.info('Dataset is ready: %i pairs.', idx)
41+
logging.info('Dataset is ready: %i pairs.', idx+1)
3642

3743
writer.close()

0 commit comments

Comments
 (0)