Search before asking
- I have searched existing issues and found no related answer.
Please ask your question
Hello,
I want to use picodet_l as the teacher to distill picodet_xs. Here are my steps:
- Train a picodet_l model.
- Write the slim config file with the following content:
_BASE_: [
'../../picodet/picodet_l_320_coco_lcnet.yml',
]
pretrain_weights: output_picodet_l_320_teacher/115.pdparams
slim: Distill
distill_loss: KnowledgeDistillationKLDivLoss
- Start distillation with this training command:
python tools/train.py -c configs/picodet/picodet_xs_320_coco_lcnet.yml --slim_config configs/slim/distill/picode_xs_distill.yml --eval
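The slim config above inherits everything from picodet_l_320_coco_lcnet.yml through `_BASE_`, so the teacher model is built from a config that carries picodet_l's backbone and neck settings. The sketch below shows, in a simplified way, how this kind of `_BASE_` inheritance resolves into one nested dict. It is not PaddleDetection's own loader: it skips relative-path handling, the helper names `merge_config` and `resolve_config` are hypothetical, and the config values are placeholders rather than the real picodet_l settings.

```python
from copy import deepcopy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`; `override` wins on conflicts."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def resolve_config(name, load):
    """Resolve a config and its _BASE_ chain; `load(name)` returns that file's parsed dict."""
    raw = load(name)
    resolved = {}
    # Base files are merged in order, then the file's own keys are applied on top.
    for base_name in raw.get("_BASE_", []):
        resolved = merge_config(resolved, resolve_config(base_name, load))
    own = {k: v for k, v in raw.items() if k != "_BASE_"}
    return merge_config(resolved, own)


# Hypothetical in-memory "files"; the values are placeholders for illustration only.
files = {
    "picodet_l.yml": {"LCNet": {"scale": 2.0}, "LCPAN": {"out_channels": 160}},
    "distill.yml": {
        "_BASE_": ["picodet_l.yml"],
        "pretrain_weights": "output_picodet_l_320_teacher/115.pdparams",
        "slim": "Distill",
    },
}
teacher_cfg = resolve_config("distill.yml", files.__getitem__)
print(teacher_cfg)
```

Every key from the base file ends up in the resolved teacher config; how PaddleDetection then combines that config with the student's is not modeled here.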
Running the training command fails with the following error:
Traceback (most recent call last):
File "tools/train.py", line 212, in <module>
main()
File "tools/train.py", line 208, in main
run(FLAGS, cfg)
File "tools/train.py", line 161, in run
trainer.train(FLAGS.eval)
File "/ossfs/workspace/code/paddledetection/ppdet/engine/trainer.py", line 623, in train
outputs = model(data)
File "/opt/conda/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1429, in __call__
return self.forward(*inputs, **kwargs)
File "/ossfs/workspace/code/paddledetection/ppdet/slim/distill_model.py", line 97, in forward
teacher_loss = self.teacher_model(inputs)
File "/opt/conda/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1429, in __call__
return self.forward(*inputs, **kwargs)
File "/ossfs/workspace/code/paddledetection/ppdet/modeling/architectures/meta_arch.py", line 60, in forward
out = self.get_loss()
File "/ossfs/workspace/code/paddledetection/ppdet/modeling/architectures/picodet.py", line 82, in get_loss
head_outs, _ = self._forward()
File "/ossfs/workspace/code/paddledetection/ppdet/modeling/architectures/picodet.py", line 66, in _forward
fpn_feats = self.neck(body_feats)
File "/opt/conda/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1429, in __call__
return self.forward(*inputs, **kwargs)
File "/ossfs/workspace/code/paddledetection/ppdet/modeling/necks/lc_pan.py", line 126, in forward
assert len(inputs) == len(self.in_channels)
AssertionError
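Per the traceback, the failing check is `assert len(inputs) == len(self.in_channels)` in lc_pan.py, where `inputs` is `body_feats`, the feature maps returned by the teacher's backbone. The assertion fires when the backbone returns a different number of feature maps than the neck has `in_channels` entries. The sketch below mimics that check with a descriptive message, using plain lists in place of tensors; the function name `check_neck_inputs` and the channel counts are hypothetical.

```python
def check_neck_inputs(inputs, in_channels):
    """Hypothetical stand-in for the length check in lc_pan.py, with a readable message."""
    if len(inputs) != len(in_channels):
        raise AssertionError(
            f"neck expects {len(in_channels)} feature maps "
            f"(in_channels={list(in_channels)}), "
            f"but the backbone returned {len(inputs)}"
        )
    return True


# Placeholder "feature maps" and channel counts, for illustration only.
try:
    check_neck_inputs(inputs=[0, 0, 0, 0], in_channels=[96, 192, 384])
except AssertionError as err:
    print(err)
```

Printing `len(body_feats)` and the neck's `in_channels` inside the teacher's forward pass would show which side disagrees.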
Is there a way to fix this? Training picodet_l and picodet_xs on their own both work fine.