Skip to content

Commit

Permalink
Use float32 metrics in mnist_eager
Browse files Browse the repository at this point in the history
float32 should be fine for mnist loss and accuracy metrics and float64
is not available on TPUs.
  • Loading branch information
iganichev committed Jul 21, 2018
1 parent 61ec602 commit dfafba4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions official/mnist/mnist_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):

def test(model, dataset):
"""Perform an evaluation of `model` on the examples from `dataset`."""
avg_loss = tfe.metrics.Mean('loss')
accuracy = tfe.metrics.Accuracy('accuracy')
avg_loss = tfe.metrics.Mean('loss', dtype=tf.float32)
accuracy = tfe.metrics.Accuracy('accuracy', dtype=tf.float32)

for (images, labels) in dataset:
logits = model(images, training=False)
Expand Down

0 comments on commit dfafba4

Please sign in to comment.