File tree Expand file tree Collapse file tree 1 file changed +15
-3
lines changed
PyTorch/Classification/ConvNets/image_classification Expand file tree Collapse file tree 1 file changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -123,6 +123,16 @@ def forward(self):
123
123
self ._forward = self .optimize (self ._forward_fn )
124
124
return self ._forward
125
125
126
+ def train (self ):
127
+ self .model .train ()
128
+ if self .loss is not None :
129
+ self .loss .train ()
130
+
131
+ def eval (self ):
132
+ self .model .eval ()
133
+ if self .loss is not None :
134
+ self .loss .eval ()
135
+
126
136
127
137
class Trainer :
128
138
def __init__ (
@@ -145,12 +155,14 @@ def __init__(
145
155
self .steps_since_update = 0
146
156
147
157
def train (self ):
148
- self .executor .model .train ()
158
+ self .executor .train ()
159
+ if self .use_ema :
160
+ self .ema_executor .train ()
149
161
150
162
def eval (self ):
151
- self .executor .model . eval ()
163
+ self .executor .eval ()
152
164
if self .use_ema :
153
- self .executor . model .eval ()
165
+ self .ema_executor .eval ()
154
166
155
167
def train_step (self , input , target , step = None ):
156
168
loss = self .executor .forward_backward (input , target )
You can’t perform that action at this time.
0 commit comments