
Commit d8930fc

adding args of datapath
1 parent a03ff9b commit d8930fc

8 files changed: 30 additions & 36 deletions

README.md

Lines changed: 4 additions & 4 deletions
@@ -52,11 +52,11 @@ bash setup.sh
 ### 2) Template Generation
 **Original template set on MVTec AD:**
 ```bash
-python run.py --mode temp --ttype ALL --dataset MVTec_AD
+python run.py --mode temp --ttype ALL --dataset MVTec_AD --datapath <data_path>
 ```
 **Tiny set formed by PTS (60 sheets) on MVTec AD:**
 ```bash
-python run.py --mode temp --ttype PTS --tsize 60 --dataset MVTec_AD
+python run.py --mode temp --ttype PTS --tsize 60 --dataset MVTec_AD --datapath <data_path>
 ```
 Since generating pixel-level OPTICS clusters is time-consuming, you can download the "*template*" folder from [Google Drive](https://drive.google.com/drive/folders/1c4XvmugX-ryP168bDMFcScdiYWgYktlu?usp=drive_link) / [Baidu Cloud](https://pan.baidu.com/s/1HH_3FQo1K72HbUvZpfylxw?pwd=eeg9) and copy it into our main folder as:
 ```
@@ -71,11 +71,11 @@ HETMM/
 ### 3) Anomaly Prediction
 **Original template set on MVTec AD:**
 ```bash
-python run.py --mode test --ttype ALL --dataset MVTec_AD
+python run.py --mode test --ttype ALL --dataset MVTec_AD --datapath <data_path>
 ```
 **Tiny set formed by PTS (60 sheets) on MVTec AD:**
 ```bash
-python run.py --mode test --ttype PTS --tsize 60 --dataset MVTec_AD
+python run.py --mode test --ttype PTS --tsize 60 --dataset MVTec_AD --datapath <data_path>
 ```
 Please see "*run.sh*" and "*run.py*" for more details.
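Note: `<data_path>` is a placeholder for your local MVTec AD root; the commands above simply forward it to the dataset loader. A minimal, hedged pre-flight check (illustrative only, not part of the repository), assuming the standard MVTec AD category folders listed in configs/MVTec_AD.yaml:

```python
import os
import sys

# illustrative sanity check before running run.py; pass your MVTec AD root as the first argument
data_path = sys.argv[1]
expected = ['carpet', 'grid', 'toothbrush', 'transistor', 'zipper']  # a subset of the categories
missing = [c for c in expected if not os.path.isdir(os.path.join(data_path, c))]
if missing:
    print(f'{data_path} does not look like an MVTec AD root; missing category folders: {missing}')
else:
    print('MVTec AD root looks OK')
```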

configs/MVTec_AD.yaml

Lines changed: 9 additions & 10 deletions
@@ -1,4 +1,12 @@
+dataset: MVTec_AD
 backbone: Wide_ResNet101_2
+temp_size: None
+temp_type: ALL
+test_batch_size: 32
+metric_functions :
+- img_AUC
+- pix_AUC
+- PRO
 categories:
 - carpet
 - grid
@@ -15,7 +23,6 @@ categories:
 - toothbrush
 - transistor
 - zipper
-dataset: MVTec_AD
 mparams:
 bottle: L1:5x3_a=0.3_l=1|L2:5x3_a=0.3_l=1|L3:7x3_a=0.3_l=0.8;img_AUC:clean|pix_AUC:blur_s=4.1_k=39|PRO:clean
 cable: L1:13x7_a=0.7_l=1|L2:5x5_a=0.7_l=0.8|L3:7x1_a=0.3_l=0.8;img_AUC:blur_s=5.5_k=39|pix_AUC:blur_s=8.600000000000001_k=61|PRO:clean
@@ -31,12 +38,4 @@ mparams:
 toothbrush: L1:13x3_a=0.3_l=1|L2:9x5_a=0.3_l=0.8|L3:7x3_a=0.3_l=0.5;img_AUC:clean|pix_AUC:blur_s=3.0_k=11|PRO:clean
 transistor: L1:13x5_a=0.3_l=1|L2:5x3_a=0.5_l=0.5|L3:5x1_a=0.3_l=0.8;img_AUC:blur_s=5.1_k=31|pix_AUC:blur_s=15.0_k=61|PRO:blur_s=6.6_k=39
 wood: L1:7x3_a=0.7_l=1|L2:5x3_a=0.7_l=0.8|L3:7x1_a=0.7_l=0.3;img_AUC:clean|pix_AUC:clean|PRO:clean
-zipper: L1:13x5_a=0.3_l=1|L2:7x5_a=0.7_l=1|L3:7x1_a=0.7_l=0.3;img_AUC:mp_k=11|pix_AUC:blur_s=10.4_k=61|PRO:blur_s=6.0_k=31
-root: /media/data4/chenzx/dataset/anomaly/mvtec_anomaly_detection
-temp_size: None
-temp_type: ALL
-test_batch_size: 32
-metric_functions :
-- img_AUC
-- pix_AUC
-- PRO
+zipper: L1:13x5_a=0.3_l=1|L2:7x5_a=0.7_l=1|L3:7x1_a=0.7_l=0.3;img_AUC:mp_k=11|pix_AUC:blur_s=10.4_k=61|PRO:blur_s=6.0_k=31

configs/base.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 data :
 name : $dataset
-root : $root
+datapath : $datapath
 out_size : $out_size
 normalize:
 mean : [0.485, 0.456, 0.406]
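Note: `$datapath` in base.yaml is presumably resolved from the new `--datapath` argument when the config is assembled; the substitution logic lives in `src/config.py` and is not shown in this diff. A minimal sketch of that kind of $-placeholder filling, under that assumption (the `load_config` helper is hypothetical):

```python
import string

import yaml  # PyYAML, assumed to be installed


def load_config(path, **cli_args):
    """Hypothetical helper: read a YAML file and fill $-placeholders from CLI arguments."""
    with open(path) as f:
        text = f.read()
    # replaces e.g. "$datapath" with the value passed via --datapath; unknown keys are left intact
    text = string.Template(text).safe_substitute(**cli_args)
    return yaml.safe_load(text)


# cfg = load_config('configs/base.yaml', dataset='MVTec_AD',
#                   datapath='/path/to/mvtec_anomaly_detection', out_size=256)
```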

run.py

Lines changed: 3 additions & 7 deletions
@@ -16,6 +16,7 @@ def argParse():
 parser.add_argument('--method', default='ATMM')
 parser.add_argument('--ttype', choices=['ALL', 'PTS'], default='ALL')
 parser.add_argument('--tsize', type=int, default=0)
+parser.add_argument('--datapath', help='your own data path')
 parser.add_argument('--dataset', type=str, default='MVTec_AD')
 parser.add_argument('--categories', type=str, nargs='+', default=None)
 parser.add_argument('--half', action='store_true')
@@ -55,7 +56,6 @@ def temp(cfg):
 def get_ALL(cfg, tpath):
 try:
 tdict = cfg.model.load_template(os.path.join(tpath, f'{cfg.model.backbone.lower()}_ALL.pkl'))
-print ('Found the original template!')

 except:
 tdict = tl.gen_by_ALL(cfg.model, cfg.temploader, tpath, cfg.model.backbone.lower(), cfg.half, save=True)
@@ -68,7 +68,6 @@ def get_ALL(cfg, tpath):
 else:
 try:
 tdict = cfg.model.load_template(os.path.join(tpath, tname))
-print (f'Found the {cfg.ttype}x{cfg.tsize}!')

 except:
 tdict = getattr(tl, f'gen_by_{cfg.ttype}')(get_ALL(cfg, tpath), cfg.tsize, tpath, cfg.model.backbone.lower(), num_workers=cfg.num_workers, save=True)
@@ -78,12 +77,9 @@ def get_ALL(cfg, tpath):
 if __name__ == '__main__':
 args = argParse()
 cfg = Cfg(args)
-categories = tqdm(cfg.categories) if args.mode == 'test' else cfg.categories
+categories = tqdm(cfg.categories)
 for category in categories:
-if args.mode == 'test':
-categories.set_description(category)
-else:
-print (category)
+categories.set_description(category)
 cfg.update(category)
 globals()[args.mode](cfg)
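Note: the new `--datapath` flag has no default and no type, so `args.datapath` is a plain string, and `None` when the flag is omitted. A minimal, self-contained sketch of just the relevant arguments (the real `argParse()` defines more flags, and the exact `--mode` definition is not visible in this diff):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mode')                                  # 'temp' or 'test', as used in run.sh
parser.add_argument('--datapath', help='your own data path')   # added by this commit, no default
parser.add_argument('--dataset', type=str, default='MVTec_AD')

args = parser.parse_args(['--mode', 'temp', '--datapath', '/path/to/mvtec_anomaly_detection'])
print(args.datapath)                    # /path/to/mvtec_anomaly_detection
print(parser.parse_args([]).datapath)   # None when --datapath is omitted
```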

run.sh

Lines changed: 4 additions & 4 deletions
@@ -4,12 +4,12 @@ ts=60

 # Generating template
 # Original template
-python run.py --mode temp --ttype ALL --dataset $dataset
+python run.py --mode temp --ttype ALL --dataset $dataset --datapath <data_path>
 # PTS template
-python run.py --mode temp --ttype $tt --tsize $ts --dataset $dataset
+python run.py --mode temp --ttype $tt --tsize $ts --dataset $dataset --datapath <data_path>

 # Anomaly detection and localization
 # Original template
-python run.py --mode test --ttype ALL --dataset $dataset
+python run.py --mode test --ttype ALL --dataset $dataset --datapath <data_path>
 # PTS template
-python run.py --mode test --ttype $tt --tsize $ts --dataset $dataset
+python run.py --mode test --ttype $tt --tsize $ts --dataset $dataset --datapath <data_path>

src/config.py

Lines changed: 3 additions & 4 deletions
@@ -38,10 +38,9 @@ def update(self, category):
 os.makedirs(os.path.join(getattr(self, key), self.dataset, self.category), exist_ok=True)

 for func in ['data', 'template']:
-try:
-getattr(self, f'load_{func}')()
-except:
-pass
+getattr(self, f'load_{func}')()
+# except:
+# pass

 self.model.update(category)

src/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -58,9 +58,9 @@ def __init__(self, **params):
 super(MVTec_AD, self).__init__(params)

 def load_data(self):
-self.img_path = os.path.join(self.root, self.category, self.mode if self.mode == 'test' else 'train')
+self.img_path = os.path.join(self.datapath, self.category, self.mode if self.mode == 'test' else 'train')
 if self.mode == 'test':
-self.gt_path = os.path.join(self.root, self.category, 'ground_truth')
+self.gt_path = os.path.join(self.datapath, self.category, 'ground_truth')
 self.img_paths, self.gt_paths, self.labels, self.types = [], [], [], []
 for defect_type in filter(lambda x : os.path.isdir(os.path.join(self.img_path, x)), os.listdir(self.img_path)):
 img_paths = [os.path.join(self.img_path, defect_type, x) for x in sorted(filter(lambda x : x.endswith('.png'), \
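Note: with this change `--datapath` must point at the MVTec AD root, i.e. the folder that directly contains the per-category folders. A hedged sketch of the layout `load_data` assumes and of the resulting paths (standard MVTec AD structure; the concrete file names are illustrative):

```python
import os

# assumed layout under --datapath (standard MVTec AD release):
#   <datapath>/<category>/train/good/000.png
#   <datapath>/<category>/test/<defect_type>/000.png
#   <datapath>/<category>/ground_truth/<defect_type>/000_mask.png

def mvtec_paths(datapath, category, mode):
    """Rebuild the same paths as MVTec_AD.load_data, for a quick sanity check."""
    img_path = os.path.join(datapath, category, mode if mode == 'test' else 'train')
    gt_path = os.path.join(datapath, category, 'ground_truth') if mode == 'test' else None
    return img_path, gt_path

print(mvtec_paths('/path/to/mvtec_anomaly_detection', 'bottle', 'test'))
# ('/path/to/mvtec_anomaly_detection/bottle/test',
#  '/path/to/mvtec_anomaly_detection/bottle/ground_truth')
```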

src/template.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@

 def gen_by_ALL(model, temploader, tpath, backbone, half=False, save=True):
 Out_dict = {}
-print (f'Generating the original template')
+# print (f'Generating the original template')
 with torch.no_grad():
 for batch in tqdm(temploader):
 x = batch[0].cuda().half() if half else batch[0].cuda()
@@ -29,7 +29,7 @@ def gen_by_ALL(model, temploader, tpath, backbone, half=False, save=True):
 def gen_by_PC_OPTICS(tdict, tsize, tpath, backbone, num_workers, save=True, **kwargs):
 clu_path = os.path.join(tpath, f'{backbone}_PC_OPTICSx{tsize}.pkl')
 if os.path.exists(clu_path):
-print ('Found pretrained OPTICS clusters!')
+# print ('Found pretrained OPTICS clusters!')
 clu_dict = torch.load(clu_path, weights_only=True, map_location='cpu')

 else:
@@ -45,7 +45,7 @@ def gen_by_PTS(tdict, tsize, tpath, backbone, num_workers, save=True, **kwargs):
 clu_dict = gen_by_PC_OPTICS(tdict, tsize, tpath, backbone, num_workers, save=True, **kwargs)

 Out_dict = {}
-print ('Generating PTS...')
+# print ('Generating PTS...')
 for key, Ts in tqdm(tdict.items()):
 Out_dict[key] = PTS(Ts, clu_dict[key].to(Ts.device).long(), int(tsize))

@@ -62,7 +62,7 @@ def get_optics_clusters_unit(pixel, min_samples, metric):

 clu_dict = {}

-print ('Generating OPTICS clusters (very slow)...')
+# print ('Generating OPTICS clusters (very slow)...')
 for key, Ts in tqdm(tdict.items()):
 N, C, H, W = Ts.shape
 pixels = F.unfold(Ts, 1).cpu().numpy()
