Skip to content

Commit b2f4b4f

Browse files
authored
[Fix] fix inferencer ut (#3117)
1 parent c30d506 commit b2f4b4f

File tree

1 file changed

+8
-27
lines changed

1 file changed

+8
-27
lines changed

tests/test_apis/test_inferencer.py

+8-27
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
import torch.nn as nn
77
from mmengine import ConfigDict
8-
from torch.utils.data import DataLoader, Dataset
98

109
from mmseg.apis import MMSegInferencer
1110
from mmseg.models import EncoderDecoder
@@ -46,33 +45,8 @@ def __init__(self, **kwargs):
4645
super().__init__(**kwargs)
4746

4847

49-
class ExampleDataset(Dataset):
50-
51-
def __init__(self) -> None:
52-
super().__init__()
53-
self.pipeline = [
54-
dict(type='LoadImageFromFile'),
55-
dict(type='LoadAnnotations'),
56-
dict(type='PackSegInputs')
57-
]
58-
59-
def __getitem__(self, idx):
60-
return dict(img=torch.tensor([1]), img_metas=dict())
61-
62-
def __len__(self):
63-
return 1
64-
65-
6648
def test_inferencer():
6749
register_all_modules()
68-
test_dataset = ExampleDataset()
69-
data_loader = DataLoader(
70-
test_dataset,
71-
batch_size=1,
72-
sampler=None,
73-
num_workers=0,
74-
shuffle=False,
75-
)
7650

7751
visualizer = dict(
7852
type='SegLocalVisualizer',
@@ -87,7 +61,14 @@ def test_inferencer():
8761
decode_head=dict(type='InferExampleHead'),
8862
test_cfg=dict(mode='whole')),
8963
visualizer=visualizer,
90-
test_dataloader=data_loader)
64+
test_dataloader=dict(
65+
dataset=dict(
66+
type='ExampleDataset',
67+
pipeline=[
68+
dict(type='LoadImageFromFile'),
69+
dict(type='LoadAnnotations'),
70+
dict(type='PackSegInputs')
71+
]), ))
9172
cfg = ConfigDict(cfg_dict)
9273
model = MODELS.build(cfg.model)
9374

0 commit comments

Comments
 (0)