
Error when distilling picodet_l into picodet_xs #9252

@qiuzhewei

Description


Search before asking

  • I have searched the question and found no related answer.

Please ask your question

Hello,
I want to distill picodet_l into picodet_xs. Here are my steps:

  1. Train a picodet_l model.
  2. Write the slim config file with the following content:
    _BASE_: [
      '../../picodet/picodet_l_320_coco_lcnet.yml',
    ]

    pretrain_weights: output_picodet_l_320_teacher/115.pdparams

    slim: Distill
    distill_loss: KnowledgeDistillationKLDivLoss
  3. Start distillation with this training command:
    python tools/train.py -c configs/picodet/picodet_xs_320_coco_lcnet.yml --slim_config configs/slim/distill/picode_xs_distill.yml --eval
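The sketch below shows how layered YAML configs like the ones above are typically combined. It assumes a recursive dict merge in which later files override earlier ones. That is an assumption about how `_BASE_` and `--slim_config` get layered, not something verified against this PaddleDetection version. Both config files are `_lcnet` variants, so teacher and student probably configure sections with the same names (for example `LCNet`). If both end up merged into a single config, one section can silently mix teacher and student values. `deep_merge` and the placeholder values are hypothetical.

```python
from copy import deepcopy


def deep_merge(base: dict, override: dict) -> dict:
    """Hypothetical helper: recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; for any other value, `override` wins.
    Neither input is mutated.
    """
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Placeholder stand-ins for the two configs: the string values only mark
# which file a setting came from, they are not real PicoDet settings.
student_cfg = {
    "LCNet": {"scale": "student", "feature_maps": "student"},
    "LCPAN": {"out_channels": "student"},
}
teacher_cfg = {"LCNet": {"scale": "teacher"}}

merged = deep_merge(student_cfg, teacher_cfg)
print(merged["LCNet"])  # {'scale': 'teacher', 'feature_maps': 'student'}
print(merged["LCPAN"])  # {'out_channels': 'student'}
```

The `LCNet` section ends up half teacher, half student. Whether the real loader behaves this way for a teacher/student pair is a hypothesis to check, not a confirmed cause of the error below.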

It fails with the following error:

Traceback (most recent call last):
  File "tools/train.py", line 212, in <module>
    main()
  File "tools/train.py", line 208, in main
    run(FLAGS, cfg)
  File "tools/train.py", line 161, in run
    trainer.train(FLAGS.eval)
  File "/ossfs/workspace/code/paddledetection/ppdet/engine/trainer.py", line 623, in train
    outputs = model(data)
  File "/opt/conda/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/ossfs/workspace/code/paddledetection/ppdet/slim/distill_model.py", line 97, in forward
    teacher_loss = self.teacher_model(inputs)
  File "/opt/conda/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/ossfs/workspace/code/paddledetection/ppdet/modeling/architectures/meta_arch.py", line 60, in forward
    out = self.get_loss()
  File "/ossfs/workspace/code/paddledetection/ppdet/modeling/architectures/picodet.py", line 82, in get_loss
    head_outs, _ = self._forward()
  File "/ossfs/workspace/code/paddledetection/ppdet/modeling/architectures/picodet.py", line 66, in _forward
    fpn_feats = self.neck(body_feats)
  File "/opt/conda/lib/python3.8/site-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/ossfs/workspace/code/paddledetection/ppdet/modeling/necks/lc_pan.py", line 126, in forward
    assert len(inputs) == len(self.in_channels)
AssertionError

Is there any way to fix this? Training picodet_l and picodet_xs on their own both work fine.
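The failing check in `lc_pan.py` (`assert len(inputs) == len(self.in_channels)`) fires when the neck receives a different number of feature maps than the number of input channels it was built with. Here is a toy illustration of that check. `ToyNeck` and `toy_backbone` are hypothetical stand-ins, not the real ppdet classes, and the channel counts are placeholders.

```python
class ToyNeck:
    """Hypothetical stand-in for a PAN-style neck (not ppdet's LCPAN)."""

    def __init__(self, in_channels):
        # One entry per backbone feature map the neck expects to receive.
        self.in_channels = list(in_channels)

    def forward(self, inputs):
        # Same style of guard as the line that fails in lc_pan.py.
        assert len(inputs) == len(self.in_channels)
        return [f"fused({x})" for x in inputs]


def toy_backbone(num_outputs):
    """Hypothetical backbone that returns `num_outputs` named feature maps."""
    return [f"C{i}" for i in range(3, 3 + num_outputs)]


neck = ToyNeck(in_channels=[64, 128, 256])  # placeholder channel counts
print(neck.forward(toy_backbone(3)))  # ['fused(C3)', 'fused(C4)', 'fused(C5)']

feats = toy_backbone(4)
try:
    neck.forward(feats)
except AssertionError:
    print(f"AssertionError: neck built for {len(neck.in_channels)} inputs, got {len(feats)}")
```

In the real model the neck's `in_channels` presumably comes from the backbone's output shapes when the model is built. If so, this assertion suggests that the teacher's backbone and neck were built from inconsistent settings. That is a hypothesis, not a confirmed diagnosis.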
