Skip to content

Commit bdb608e

Browse files
committed
[PyTorch/ConvNets] Fixing error for MixUp
1 parent 7f67aa4 commit bdb608e

File tree

1 file changed

+15
-3
lines changed
  • PyTorch/Classification/ConvNets/image_classification

1 file changed

+15
-3
lines changed

PyTorch/Classification/ConvNets/image_classification/training.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,16 @@ def forward(self):
123123
self._forward = self.optimize(self._forward_fn)
124124
return self._forward
125125

126+
def train(self):
127+
self.model.train()
128+
if self.loss is not None:
129+
self.loss.train()
130+
131+
def eval(self):
132+
self.model.eval()
133+
if self.loss is not None:
134+
self.loss.eval()
135+
126136

127137
class Trainer:
128138
def __init__(
@@ -145,12 +155,14 @@ def __init__(
145155
self.steps_since_update = 0
146156

147157
def train(self):
148-
self.executor.model.train()
158+
self.executor.train()
159+
if self.use_ema:
160+
self.ema_executor.train()
149161

150162
def eval(self):
151-
self.executor.model.eval()
163+
self.executor.eval()
152164
if self.use_ema:
153-
self.executor.model.eval()
165+
self.ema_executor.eval()
154166

155167
def train_step(self, input, target, step=None):
156168
loss = self.executor.forward_backward(input, target)

0 commit comments

Comments
 (0)