Skip to content

Commit 1dd1771

Browse files
committed
去除Lossy conversion from float32 to uint8 警告
1 parent 227b7c9 commit 1dd1771

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

main.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77
入口类
88
99
"""
10-
import re
10+
import argparse
1111
import os
12+
import re
13+
14+
import cv2
1215
import numpy as np
1316
import torch
17+
from skimage import io
1418
from torch import nn
1519
from torchvision import models
16-
import argparse
17-
from skimage import io
18-
import cv2
20+
1921
from interpretability.grad_cam import GradCAM, GradCamPlusPlus
2022
from interpretability.guided_back_propagation import GuidedBackPropagation
2123

@@ -105,7 +107,7 @@ def gen_cam(image, mask):
105107

106108
# 合并heatmap到原始图像
107109
cam = heatmap + np.float32(image)
108-
return norm_image(cam), heatmap
110+
return norm_image(cam), (heatmap * 255).astype(np.uint8)
109111

110112

111113
def norm_image(image):
@@ -166,7 +168,7 @@ def main(args):
166168
grad = gbp(inputs)
167169

168170
gb = gen_gb(grad)
169-
image_dict['gb'] = gb
171+
image_dict['gb'] = norm_image(gb)
170172
# 生成Guided Grad-CAM
171173
cam_gb = gb * mask[..., np.newaxis]
172174
image_dict['cam_gb'] = norm_image(cam_gb)

0 commit comments

Comments
 (0)