5
5
import torch
6
6
import torch .nn as nn
7
7
from mmengine import ConfigDict
8
- from torch .utils .data import DataLoader , Dataset
9
8
10
9
from mmseg .apis import MMSegInferencer
11
10
from mmseg .models import EncoderDecoder
@@ -46,33 +45,8 @@ def __init__(self, **kwargs):
46
45
super ().__init__ (** kwargs )
47
46
48
47
49
- class ExampleDataset (Dataset ):
50
-
51
- def __init__ (self ) -> None :
52
- super ().__init__ ()
53
- self .pipeline = [
54
- dict (type = 'LoadImageFromFile' ),
55
- dict (type = 'LoadAnnotations' ),
56
- dict (type = 'PackSegInputs' )
57
- ]
58
-
59
- def __getitem__ (self , idx ):
60
- return dict (img = torch .tensor ([1 ]), img_metas = dict ())
61
-
62
- def __len__ (self ):
63
- return 1
64
-
65
-
66
48
def test_inferencer ():
67
49
register_all_modules ()
68
- test_dataset = ExampleDataset ()
69
- data_loader = DataLoader (
70
- test_dataset ,
71
- batch_size = 1 ,
72
- sampler = None ,
73
- num_workers = 0 ,
74
- shuffle = False ,
75
- )
76
50
77
51
visualizer = dict (
78
52
type = 'SegLocalVisualizer' ,
@@ -87,7 +61,14 @@ def test_inferencer():
87
61
decode_head = dict (type = 'InferExampleHead' ),
88
62
test_cfg = dict (mode = 'whole' )),
89
63
visualizer = visualizer ,
90
- test_dataloader = data_loader )
64
+ test_dataloader = dict (
65
+ dataset = dict (
66
+ type = 'ExampleDataset' ,
67
+ pipeline = [
68
+ dict (type = 'LoadImageFromFile' ),
69
+ dict (type = 'LoadAnnotations' ),
70
+ dict (type = 'PackSegInputs' )
71
+ ]), ))
91
72
cfg = ConfigDict (cfg_dict )
92
73
model = MODELS .build (cfg .model )
93
74
0 commit comments