fix batch_size in transformer_main.py (tensorflow#4897)
* fix batch_size in transformer_main.py

Fix batch_size in transformer_main.py, which caused ResourceExhaustedError (OOM) when training Transformer models with models/official/transformer.

* small format change

Reflow the expression from one line onto multiple lines to pass lint checks.

* remove trailing space and add comment
Jiang Yu authored and Taylor Robie committed Jul 26, 2018
1 parent c1588f0 commit 2d7a0d6
1 changed file: official/transformer/transformer_main.py (6 additions, 3 deletions)
```diff
@@ -555,9 +555,12 @@ def run_transformer(flags_obj):
 
   params["use_synthetic_data"] = flags_obj.use_synthetic_data
 
-  # Set batch size parameter, which depends on TPU and distribution settings.
-  params["batch_size"] = (
-      flags_obj.batch_size or params["default_batch_size_tpu"])
+  # Set batch size parameter, which depends on the availability of
+  # TPU and GPU, and distribution settings.
+  params["batch_size"] = (flags_obj.batch_size or (
+      params["default_batch_size_tpu"] if params["use_tpu"]
+      else params["default_batch_size"]))
 
   if not params["use_tpu"]:
     params["batch_size"] = distribution_utils.per_device_batch_size(
         params["batch_size"], num_gpus)
```
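The core of the fix is a fallback chain: an explicit `--batch_size` flag always wins; otherwise the TPU default is used only when actually running on TPU, falling back to the (much smaller) GPU/CPU default. A minimal standalone sketch of that logic (the function name and the default values here are illustrative, not the model's real hyperparameters):

```python
def select_batch_size(flag_batch_size, use_tpu,
                      default_batch_size_tpu=32768,
                      default_batch_size=2048):
    """Pick a batch size the way the patched transformer_main.py does.

    An explicit flag value takes precedence; otherwise fall back to the
    TPU default only when running on TPU, and to the smaller GPU/CPU
    default otherwise.
    """
    return flag_batch_size or (
        default_batch_size_tpu if use_tpu else default_batch_size)

# Before the fix, the TPU default was used unconditionally, which could
# exhaust GPU memory (ResourceExhaustedError: OOM).
print(select_batch_size(None, use_tpu=False))   # falls back to GPU default
print(select_batch_size(None, use_tpu=True))    # falls back to TPU default
print(select_batch_size(4096, use_tpu=False))   # explicit flag wins
```

Note that `or` treats an unset flag (`None`) and a flag of `0` identically, which is fine here since a batch size of 0 is never meaningful.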

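The unchanged tail of the hunk then divides the global batch size across GPUs via `distribution_utils.per_device_batch_size`. A sketch of what that helper plausibly does, assuming it requires the global batch to split evenly across replicas (the exact signature and error handling in the real utility are an assumption):

```python
def per_device_batch_size(batch_size, num_gpus):
    """Split a global batch size evenly across GPUs (sketch).

    With zero or one GPU the global batch size is used as-is; otherwise
    it must divide evenly so every replica gets the same share.
    """
    if num_gpus <= 1:
        return batch_size
    if batch_size % num_gpus:
        raise ValueError(
            "batch_size %d is not divisible by num_gpus %d"
            % (batch_size, num_gpus))
    return batch_size // num_gpus

print(per_device_batch_size(2048, 4))  # 512 per GPU
```

This division step is why the earlier fallback matters: a TPU-sized global batch divided across a handful of GPUs still leaves each device with far more samples than its memory can hold.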