3
3
import inspect
4
4
from typing import Callable
5
5
6
+ from torch import Tensor
7
+
6
8
from lmdeploy .utils import get_logger
7
9
8
10
from ..devices import DeviceContext , get_device_manager
@@ -64,6 +66,8 @@ def __init__(self, func_name: str):
64
66
self .func_name = func_name
65
67
self .dispatched_func = self .load_and_call
66
68
self .device_manager .register_context_callback (self .device_callback )
69
+ self .device_type = None
70
+ self .device_map = {'cuda' : 'cuda' , 'npu' : 'dlinfer' , 'maca' : 'dlinfer' , 'camb' : 'dlinfer' }
67
71
68
72
def device_callback (self , context : DeviceContext ):
69
73
"""device context callback."""
@@ -88,7 +92,11 @@ def load_func(self, device: str):
88
92
89
93
def load_and_call (self , * args , ** kwargs ):
90
94
"""load and call."""
91
- device = self .device_manager .current_context ().device_type
95
+ if self .device_type is None :
96
+ device_type = self .device_manager .current_context ().device_type
97
+ self .device_type = next (
98
+ (arg .device .type for arg in args if isinstance (arg , Tensor ) and arg .device .type != 'cpu' ), device_type )
99
+ device = self .device_map [self .device_type ]
92
100
if device not in self .impl_map :
93
101
self .load_func (device )
94
102
self .dispatched_func = self .impl_map [device ]
0 commit comments