Skip to content

Commit 6dcbc22

Browse files
committed
Keep as no_grad
1 parent d705d67 commit 6dcbc22

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/kd/distillation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def apply_kd_loss(
128128
prob_s = torch.nn.functional.log_softmax(student_output, dim=-1)
129129

130130
# Teacher probability calculation
131-
with torch.inference_mode():
131+
with torch.no_grad():
132132
input_kd = teacher_model.normalize_input(input, student_model)
133133
out_t = teacher_model.model(input_kd.detach())
134134
prob_t = torch.nn.functional.softmax(out_t, dim=-1)

0 commit comments

Comments
 (0)