Skip to content

Commit 2f62684

Browse files
committed
fix bug:预测图像没有在cuda上
1 parent 5b672d3 commit 2f62684

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

rest.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,13 @@ def inference(image, h, w):
5858
:return: text
5959
"""
6060
image = torch.FloatTensor(image)
61+
image = image.to(device)
6162

6263
if h > w:
6364
predict = v_net(image)[0].detach().cpu().numpy() # [W,num_classes]
6465
else:
6566
predict = h_net(image)[0].detach().cpu().numpy() # [W,num_classes]
6667

67-
image.to(device)
68-
6968
label = np.argmax(predict[:], axis=1)
7069
label = [alpha[class_id] for class_id in label]
7170
label = [k for k, g in itertools.groupby(list(label))]

0 commit comments

Comments
 (0)