1 parent d705d67 commit 6dcbc22
timm/kd/distillation.py
@@ -128,7 +128,7 @@ def apply_kd_loss(
     prob_s = torch.nn.functional.log_softmax(student_output, dim=-1)
 
     # Teacher probability calculation
-    with torch.inference_mode():
+    with torch.no_grad():
         input_kd = teacher_model.normalize_input(input, student_model)
         out_t = teacher_model.model(input_kd.detach())
         prob_t = torch.nn.functional.softmax(out_t, dim=-1)
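For context, the switch is likely motivated by how `torch.inference_mode()` differs from `torch.no_grad()`: tensors created under `inference_mode()` are "inference tensors" and cannot be saved for backward, but the teacher probabilities `prob_t` feed into a distillation loss that the student backpropagates through. Below is a minimal standalone sketch (not from this repo; `x`, `w`, and the softmax shapes are hypothetical stand-ins) illustrating the difference:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)  # stands in for student features
w = torch.randn(8, 8)                      # stands in for a frozen teacher layer

# Under no_grad(), the result is an ordinary tensor with requires_grad=False,
# so it can safely appear as a constant target in the student's loss graph.
with torch.no_grad():
    prob_t = F.softmax(x.detach() @ w, dim=-1)

log_prob_s = F.log_softmax(x @ w, dim=-1)
loss = F.kl_div(log_prob_s, prob_t, reduction="batchmean")
loss.backward()  # works: gradients flow through the student path only

# Under inference_mode(), the result is an "inference tensor". kl_div saves
# the target for backward, so using the tensor in an autograd-recorded op
# raises a RuntimeError, which the no_grad() version avoids.
with torch.inference_mode():
    prob_t_inf = F.softmax(x.detach() @ w, dim=-1)
try:
    F.kl_div(log_prob_s, prob_t_inf, reduction="batchmean")
except RuntimeError as e:
    print("inference tensor rejected:", e)
```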