I tried to use paddle.summary() to display the parameters of the ppyolo_r50vd_dcn_1x_visdrone model (configs/sniper/ppyolo_r50vd_dcn_1x_visdrone.yml), but it failed with the following error:
---------------print params-----------------
Traceback (most recent call last):
  File "tools/train.py", line 171, in <module>
    main()
  File "tools/train.py", line 167, in main
    run(FLAGS, cfg)
  File "tools/train.py", line 118, in run
    trainer = Trainer(cfg, mode='train')
  File "/PaddleDetectionNew/ppdet/engine/trainer.py", line 90, in __init__
    self.model = create(cfg.architecture)
  File "/PaddleDetectionNew/ppdet/core/workspace.py", line 275, in create
    return cls(**cls_kwargs)
  File "/PaddleDetectionNew/ppdet/modeling/architectures/yolo.py", line 61, in __init__
    params_info = paddle.summary(model, (1, 3, 320, 320))
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/hapi/model_summary.py", line 223, in summary
    result, params_info = summary_string(net, _input_size, dtypes, input)
  File "<decorator-gen-278>", line 2, in summary_string
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/fluid/dygraph/base.py", line 331, in _decorate_function
    return func(*args, **kwargs)
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/hapi/model_summary.py", line 353, in summary_string
    model(*x)
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 914, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/PaddleDetectionNew/ppdet/modeling/backbones/resnet.py", line 581, in forward
    x = inputs['image']
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py", line 598, in __getitem__
    return self._getitem_index_not_tensor(item)
ValueError: (InvalidArgument) Currently, Tensor.__indices__() only allows indexing by Integers, Slices, Ellipsis, None, tuples of these types and list of Bool and Integers, but received str in 1th slice item (at /paddle/paddle/fluid/pybind/imperative.cc:637)

The code I added is:
@register
class YOLOv3(BaseArch):
    __category__ = 'architecture'
    __shared__ = ['data_format']
    __inject__ = ['post_process']

    def __init__(self,
                 backbone='DarkNet',
                 neck='YOLOv3FPN',
                 yolo_head='YOLOv3Head',
                 post_process='BBoxPostProcess',
                 data_format='NCHW',
                 for_mot=False):
        """
        YOLOv3 network, see https://arxiv.org/abs/1804.02767
        Args:
            backbone (nn.Layer): backbone instance
            neck (nn.Layer): neck instance
            yolo_head (nn.Layer): anchor_head instance
            bbox_post_process (object): `BBoxPostProcess` instance
            data_format (str): data format, NCHW or NHWC
            for_mot (bool): whether return other features for multi-object tracking
                models, default False in pure object detection models.
        """
        super(YOLOv3, self).__init__(data_format=data_format)
        self.backbone = backbone
        self.neck = neck
        self.yolo_head = yolo_head
        self.post_process = post_process
        self.for_mot = for_mot
        self.return_idx = isinstance(post_process, JDEBBoxPostProcess)
        print("---------------print params-----------------")
        model = self.backbone
        params_info = paddle.summary(model, (1, 3, 320, 320))
        print(params_info)

The lines I added are the last few, from the print() call to the end. I also searched the existing issues in this repository about printing the model structure, but still could not find a solution.
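API documentation I referred to: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/summary_cn.html
Question: how should this function be used to display the model structure, taking the ppyolo_r50vd_dcn_1x_visdrone model above as an example?
Versions: paddlepaddle 2.2.0; PaddleDetection 2.3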