Using paddle.summary() to display model structure, parameters, and other information #4970

@edwinlong95

Description

I tried to use paddle.summary() to display the parameters of the ppyolo_r50vd_dcn_1x_visdrone model (configs/sniper/ppyolo_r50vd_dcn_1x_visdrone.yml), but it failed with the following error:

---------------print params-----------------
Traceback (most recent call last):
  File "tools/train.py", line 171, in <module>
    main()
  File "tools/train.py", line 167, in main
    run(FLAGS, cfg)
  File "tools/train.py", line 118, in run
    trainer = Trainer(cfg, mode='train')
  File "/PaddleDetectionNew/ppdet/engine/trainer.py", line 90, in __init__
    self.model = create(cfg.architecture)
  File "/PaddleDetectionNew/ppdet/core/workspace.py", line 275, in create
    return cls(**cls_kwargs)
  File "/PaddleDetectionNew/ppdet/modeling/architectures/yolo.py", line 61, in __init__
    params_info = paddle.summary(model, (1, 3, 320, 320))
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/hapi/model_summary.py", line 223, in summary
    result, params_info = summary_string(net, _input_size, dtypes, input)
  File "<decorator-gen-278>", line 2, in summary_string
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/fluid/dygraph/base.py", line 331, in _decorate_function
    return func(*args, **kwargs)
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/hapi/model_summary.py", line 353, in summary_string
    model(*x)
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py", line 914, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/PaddleDetectionNew/ppdet/modeling/backbones/resnet.py", line 581, in forward
    x = inputs['image']
  File "/anaconda3/envs/dxl/lib/python3.7/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py", line 598, in __getitem__
    return self._getitem_index_not_tensor(item)
ValueError: (InvalidArgument) Currently, Tensor.__indices__() only allows indexing by Integers, Slices, Ellipsis, None, tuples of these types and list of Bool and Integers, but received str in 1th slice item (at /paddle/paddle/fluid/pybind/imperative.cc:637)

The code I added is:

@register
class YOLOv3(BaseArch):
    __category__ = 'architecture'
    __shared__ = ['data_format']
    __inject__ = ['post_process']

    def __init__(self,
                 backbone='DarkNet',
                 neck='YOLOv3FPN',
                 yolo_head='YOLOv3Head',
                 post_process='BBoxPostProcess',
                 data_format='NCHW',
                 for_mot=False):
        """
        YOLOv3 network, see https://arxiv.org/abs/1804.02767

        Args:
            backbone (nn.Layer): backbone instance
            neck (nn.Layer): neck instance
            yolo_head (nn.Layer): anchor_head instance
            bbox_post_process (object): `BBoxPostProcess` instance
            data_format (str): data format, NCHW or NHWC
            for_mot (bool): whether return other features for multi-object tracking
                models, default False in pure object detection models.
        """
        super(YOLOv3, self).__init__(data_format=data_format)
        self.backbone = backbone
        self.neck = neck
        self.yolo_head = yolo_head
        self.post_process = post_process
        self.for_mot = for_mot
        self.return_idx = isinstance(post_process, JDEBBoxPostProcess)
        print("---------------print params-----------------")
        model = self.backbone
        params_info = paddle.summary(model, (1, 3, 320, 320))
        print(params_info)

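The lines I added are the last few, from the print() call to the end. I also searched your existing issues about printing the model structure, but still could not find a solution.
API documentation I consulted: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/summary_cn.html
Question: how should this function be used to display the model structure, taking the ppyolo_r50vd_dcn_1x_visdrone model above as an example?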

Versions: paddlepaddle 2.2.0; PaddleDetection 2.3
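
Reading the traceback, paddle.summary() appears to build a plain Tensor from the given input size and call `model(*x)`, while the backbone's `forward` does `inputs['image']`. The Tensor therefore ends up being indexed with a string. Below is a minimal sketch of two possible workarounds, not a confirmed fix. `count_parameters` and `summarize_dict_input_layer` are hypothetical helper names that are not part of Paddle or PaddleDetection, and the adapter approach has not been verified against Paddle 2.2.0.

```python
def count_parameters(layer):
    """Count parameters without running a forward pass.

    Works on any object exposing ``parameters()`` whose items have ``shape``
    and ``stop_gradient``, which is the interface of ``paddle.nn.Layer``.
    Returns a ``(total, trainable)`` tuple.
    """
    total = trainable = 0
    for param in layer.parameters():
        numel = 1
        for dim in param.shape:
            numel *= int(dim)
        total += numel
        if not param.stop_gradient:  # frozen parameters have stop_gradient=True
            trainable += numel
    return total, trainable


def summarize_dict_input_layer(layer, input_size, key="image"):
    """Run paddle.summary on a layer whose forward() expects ``{key: tensor}``.

    Hypothetical helper: it wraps ``layer`` in an adapter that receives the
    plain tensor built by paddle.summary and repackages it as a dict.
    """
    import paddle  # deferred so count_parameters stays usable without Paddle

    class TensorToDictAdapter(paddle.nn.Layer):
        def __init__(self, inner):
            super().__init__()
            self.inner = inner  # registered as a sublayer, so it is summarized

        def forward(self, x):
            # Assumption: the wrapped layer reads its input as inputs[key].
            return self.inner({key: x})

    return paddle.summary(TensorToDictAdapter(layer), input_size)
```

In `YOLOv3.__init__`, the failing call could then be replaced by `summarize_dict_input_layer(self.backbone, (1, 3, 320, 320))`. If the per-layer table is not needed, `count_parameters(self.backbone)` gives the totals and `print(self.backbone)` shows the layer structure, neither of which requires a forward pass.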
