Commit dc96c28

Minimal tensorflow 2 port
- Run `tf_upgrade_v2`.
- Manually map some older tf.contrib objects not covered by `tf_upgrade_v2` (`MutableHashTable`, `layers.batch_norm`, `layers.dropout`).
- Port the dropped `rnn_cell_impl._linear` function (a sketch of its semantics follows below).
- Remove `.value` accessors for `get_shape()` values.
1 parent 2ef51cb commit dc96c28
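Only three of the eight changed files are reproduced below; the `rnn_cell_impl._linear` port lands in one of the files not shown. For reference, a minimal sketch of the semantics that private TF1 helper had (concatenate the inputs along the last axis, multiply by a single weight matrix, optionally add a bias); the names and structure here are illustrative, not the repository's actual code:

```python
import tensorflow as tf

# Sketch of the dropped rnn_cell_impl._linear helper, ported to the
# tf.compat.v1 namespace. Graph-mode helper, as in the TF1 original;
# variable names are illustrative.
def _linear(args, output_size, bias, bias_initializer=None, kernel_initializer=None):
    if not isinstance(args, (list, tuple)):
        args = [args]
    # Total width of the concatenated inputs along the last axis.
    total_arg_size = sum(a.get_shape().as_list()[-1] for a in args)
    dtype = args[0].dtype
    weights = tf.compat.v1.get_variable(
        'kernel', [total_arg_size, output_size], dtype=dtype,
        initializer=kernel_initializer)
    res = tf.matmul(args[0] if len(args) == 1 else tf.concat(args, 1), weights)
    if not bias:
        return res
    biases = tf.compat.v1.get_variable(
        'bias', [output_size], dtype=dtype,
        initializer=bias_initializer or tf.compat.v1.zeros_initializer())
    return tf.nn.bias_add(res, biases)
```

The `.value` removals are simpler: with TF2 behavior, `tensor.get_shape()[i]` already yields a plain `int` (or `None`), so TF1 expressions like `shape[0].value` become just `shape[0]`.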

File tree

8 files changed: +133 -79 lines


aocr/__main__.py (+2 -2)

```diff
@@ -17,7 +17,7 @@
 from .util.data_gen import DataGen
 from .util.export import Exporter
 
-tf.logging.set_verbosity(tf.logging.ERROR)
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
 
 def process_args(args, defaults):
@@ -212,7 +212,7 @@ def main(args=None):
     console.setFormatter(formatter)
     logging.getLogger('').addHandler(console)
 
-    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
+    with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True)) as sess:
 
         if parameters.phase == 'dataset':
             dataset.generate(
```
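This is the mechanical core of the port: the graph-and-session workflow survives under TF2 through the `tf.compat.v1` namespace. A standalone sketch of the pattern (demonstration code, not from this repository):

```python
import tensorflow as tf

# Session-based code needs graph mode; whether the port needs this
# explicit call depends on parts of the codebase not shown here.
tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.float32, shape=(None,), name='x')
y = 2.0 * x

with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True)) as sess:
    print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # -> [2. 4.]
```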

aocr/model/cnn.py (+13 -13)

```diff
@@ -15,10 +15,10 @@ def var_random(name, shape, regularizable=False):
     :param regularizable:
     :return:
     '''
-    v = tf.get_variable(name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())
+    v = tf.compat.v1.get_variable(name, shape=shape, initializer=tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
     if regularizable:
-        with tf.name_scope(name + '/Regularizer/'):
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, tf.nn.l2_loss(v))
+        with tf.compat.v1.name_scope(name + '/Regularizer/'):
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, tf.nn.l2_loss(v))
     return v
 
 
@@ -29,8 +29,8 @@ def max_2x2pool(incoming, name):
     :param name:
     :return:
     '''
-    with tf.variable_scope(name):
-        return tf.nn.max_pool(incoming, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
+    with tf.compat.v1.variable_scope(name):
+        return tf.nn.max_pool2d(input=incoming, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
 
 
 def max_2x1pool(incoming, name):
@@ -40,8 +40,8 @@ def max_2x1pool(incoming, name):
     :param name:
     :return:
     '''
-    with tf.variable_scope(name):
-        return tf.nn.max_pool(incoming, ksize=(1, 2, 1, 1), strides=(1, 2, 1, 1), padding='SAME')
+    with tf.compat.v1.variable_scope(name):
+        return tf.nn.max_pool2d(input=incoming, ksize=(1, 2, 1, 1), strides=(1, 2, 1, 1), padding='SAME')
 
 
 def ConvRelu(incoming, num_filters, filter_size, name):
@@ -54,14 +54,14 @@ def ConvRelu(incoming, num_filters, filter_size, name):
     :return:
     '''
     num_filters_from = incoming.get_shape().as_list()[3]
-    with tf.variable_scope(name):
+    with tf.compat.v1.variable_scope(name):
         conv_W = var_random(
             'W',
             tuple(filter_size) + (num_filters_from, num_filters),
             regularizable=True
         )
 
-        after_conv = tf.nn.conv2d(incoming, conv_W, strides=(1, 1, 1, 1), padding='SAME')
+        after_conv = tf.nn.conv2d(incoming, filters=conv_W, strides=(1, 1, 1, 1), padding='SAME')
 
         return tf.nn.relu(after_conv)
 
@@ -73,7 +73,7 @@ def batch_norm(incoming, is_training):
     :param is_training:
     :return:
     '''
-    return tf.contrib.layers.batch_norm(incoming, is_training=is_training, scale=True, decay=0.99)
+    return tf.compat.v1.layers.batch_normalization(incoming, training=is_training, scale=True, momentum=0.99)
 
 
 def ConvReluBN(incoming, num_filters, filter_size, name, is_training):
@@ -87,22 +87,22 @@ def ConvReluBN(incoming, num_filters, filter_size, name, is_training):
     :return:
     '''
     num_filters_from = incoming.get_shape().as_list()[3]
-    with tf.variable_scope(name):
+    with tf.compat.v1.variable_scope(name):
         conv_W = var_random(
             'W',
             tuple(filter_size) + (num_filters_from, num_filters),
             regularizable=True
         )
 
-        after_conv = tf.nn.conv2d(incoming, conv_W, strides=(1, 1, 1, 1), padding='SAME')
+        after_conv = tf.nn.conv2d(incoming, filters=conv_W, strides=(1, 1, 1, 1), padding='SAME')
 
         after_bn = batch_norm(after_conv, is_training)
 
         return tf.nn.relu(after_bn)
 
 
 def dropout(incoming, is_training, keep_prob=0.5):
-    return tf.contrib.layers.dropout(incoming, keep_prob=keep_prob, is_training=is_training)
+    return tf.compat.v1.layers.dropout(incoming, rate=1 - keep_prob, training=is_training)
 
 
 def tf_create_attention_map(incoming):
```
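Note the `dropout` mapping above: it is the one contrib replacement where an argument is inverted rather than renamed. contrib's `keep_prob` is the probability of keeping a unit, while the newer APIs take `rate`, the probability of dropping one, hence `rate=1 - keep_prob`. (`batch_norm`'s `decay` is likewise renamed to `momentum` with identical meaning.) A quick eager-mode sanity check of the same mapping, using `tf.nn.dropout` as a stand-in for the graph-mode call:

```python
import tensorflow as tf

keep_prob = 0.5
x = tf.ones((4, 8))

# rate = 1 - keep_prob: roughly half the units are zeroed, and the
# survivors are scaled by 1/keep_prob so the expected sum is unchanged.
out = tf.nn.dropout(x, rate=1 - keep_prob)
print(out.numpy())  # a mix of 0.0 and 2.0 entries
```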

aocr/model/model.py (+17 -17)

```diff
@@ -113,7 +113,7 @@ def __init__(self,
         self.height = tf.constant(DataGen.IMAGE_HEIGHT, dtype=tf.int32)
         self.height_float = tf.constant(DataGen.IMAGE_HEIGHT, dtype=tf.float64)
 
-        self.img_pl = tf.placeholder(tf.string, name='input_image_as_bytes')
+        self.img_pl = tf.compat.v1.placeholder(tf.string, name='input_image_as_bytes')
         self.img_data = tf.cond(
             tf.less(tf.rank(self.img_pl), 1),
             lambda: tf.expand_dims(self.img_pl, 0),
@@ -156,7 +156,7 @@ def __init__(self,
             forward_only=self.forward_only,
             use_gru=use_gru)
 
-        table = tf.contrib.lookup.MutableHashTable(
+        table = tf.lookup.experimental.MutableHashTable(
             key_dtype=tf.int64,
             value_dtype=tf.string,
             default_value="",
@@ -226,12 +226,12 @@ def __init__(self,
             self.updates = []
             self.summaries_by_bucket = []
 
-            params = tf.trainable_variables()
-            opt = tf.train.AdadeltaOptimizer(learning_rate=initial_learning_rate)
+            params = tf.compat.v1.trainable_variables()
+            opt = tf.compat.v1.train.AdadeltaOptimizer(learning_rate=initial_learning_rate)
             loss_op = self.attention_decoder_model.loss
 
             if self.reg_val > 0:
-                reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+                reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
                 logging.info('Adding %s regularization losses', len(reg_losses))
                 logging.debug('REGULARIZATION_LOSSES: %s', reg_losses)
                 loss_op = self.reg_val * tf.reduce_sum(reg_losses) + loss_op
@@ -242,14 +242,14 @@ def __init__(self,
 
             # Summaries for loss, variables, gradients, gradient norms and total gradient norm.
             summaries = [
-                tf.summary.scalar("loss", loss_op),
-                tf.summary.scalar("total_gradient_norm", tf.global_norm(gradients))
+                tf.compat.v1.summary.scalar("loss", loss_op),
+                tf.compat.v1.summary.scalar("total_gradient_norm", tf.linalg.global_norm(gradients))
             ]
-            all_summaries = tf.summary.merge(summaries)
+            all_summaries = tf.compat.v1.summary.merge(summaries)
             self.summaries_by_bucket.append(all_summaries)
 
             # update op - apply gradients
-            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+            update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
             with tf.control_dependencies(update_ops):
                 self.updates.append(
                     opt.apply_gradients(
@@ -258,7 +258,7 @@ def __init__(self,
                     )
                 )
 
-        self.saver_all = tf.train.Saver(tf.all_variables())
+        self.saver_all = tf.compat.v1.train.Saver(tf.compat.v1.all_variables())
         self.checkpoint_path = os.path.join(self.model_dir, "model.ckpt")
 
         ckpt = tf.train.get_checkpoint_state(model_dir)
@@ -268,7 +268,7 @@ def __init__(self,
             self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
         else:
             logging.info("Created model with fresh parameters.")
-            self.sess.run(tf.initialize_all_variables())
+            self.sess.run(tf.compat.v1.initialize_all_variables())
 
     def predict(self, image_file_data):
         input_feed = {}
@@ -370,7 +370,7 @@ def train(self, data_path, num_epoch):
         loss = 0.0
         current_step = 0
         skipped_counter = 0
-        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)
+        writer = tf.compat.v1.summary.FileWriter(self.model_dir, self.sess.graph)
 
         logging.info('Starting the training process.')
         for batch in s_gen.gen(self.batch_size):
@@ -498,18 +498,18 @@ def _prepare_image(self, image):
         dims = tf.shape(img)
         width = self.max_width
 
-        max_width = tf.to_int32(tf.ceil(tf.truediv(dims[1], dims[0]) * self.height_float))
-        max_height = tf.to_int32(tf.ceil(tf.truediv(width, max_width) * self.height_float))
+        max_width = tf.cast(tf.math.ceil(tf.truediv(dims[1], dims[0]) * self.height_float), dtype=tf.int32)
+        max_height = tf.cast(tf.math.ceil(tf.truediv(width, max_width) * self.height_float), dtype=tf.int32)
 
         resized = tf.cond(
             tf.greater_equal(width, max_width),
             lambda: tf.cond(
                 tf.less_equal(dims[0], self.height),
-                lambda: tf.to_float(img),
-                lambda: tf.image.resize_images(img, [self.height, max_width],
+                lambda: tf.cast(img, dtype=tf.float32),
+                lambda: tf.image.resize(img, [self.height, max_width],
                                                method=tf.image.ResizeMethod.BICUBIC),
             ),
-            lambda: tf.image.resize_images(img, [max_height, width],
+            lambda: tf.image.resize(img, [max_height, width],
                                            method=tf.image.ResizeMethod.BICUBIC)
         )
 
```
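Most of the changes above are one-for-one renames into `tf.compat.v1`, but `MutableHashTable` is a real graduation out of contrib: it now lives at `tf.lookup.experimental.MutableHashTable` with the same constructor arguments. A standalone sketch of the relocated API (illustrative keys and values, run eagerly):

```python
import tensorflow as tf

# Same constructor signature as the old tf.contrib.lookup version.
table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.int64,
    value_dtype=tf.string,
    default_value="")

table.insert(tf.constant([7, 42], dtype=tf.int64),
             tf.constant(["seven", "forty-two"]))

# Missing keys fall back to default_value.
print(table.lookup(tf.constant([42, 0], dtype=tf.int64)).numpy())
# -> [b'forty-two' b'']
```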
