 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupport
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
+from torch.utils import _pytree
 
 import onnxruntime  # type: ignore
 from onnxruntime.capi import _pybind_state as ORTC
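The newly imported torch.utils._pytree is used further down to flatten the possibly nested argument structure of the FX graph's output node. A minimal sketch of that behavior, with hypothetical tensors, assuming only the usual tree_flatten/tree_unflatten semantics:

import torch
from torch.utils import _pytree

# A nested container similar in shape to an FX output node's args:
# a list of tensors plus a single tensor inside a tuple.
nested = ([torch.zeros(2), torch.ones(3)], torch.ones(4))
leaves, spec = _pytree.tree_flatten(nested)
print(len(leaves))  # 3 leaves, regardless of nesting depth
restored = _pytree.tree_unflatten(leaves, spec)  # round-trips to the original structure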
@@ -199,16 +200,59 @@ def _create_onnx_model(onnx_proto):
     return onnx.ModelProto.FromString(onnx_proto)
 
 
-def _create_onnx_session(onnx_proto, ep: str, session_options):
+def _create_onnx_session(onnx_proto, eps: Tuple[str, ...], session_options):
     # TODO(wechi): Add more EPs per PyTorch device types.
     # TODO(wechi): enable external allocators.
-    return onnxruntime.InferenceSession(onnx_proto, providers=[ep], sess_options=session_options)
-
-
-def _infer_ep_from_device(device):
-    if device.type == "cuda":
-        return "CUDAExecutionProvider"
-    return "CPUExecutionProvider"
+    return onnxruntime.InferenceSession(onnx_proto, providers=eps, sess_options=session_options)
+
+
+def _infer_ep_from_device(*args) -> Tuple[str, ...]:
+    """Return the execution providers inferred from the devices of the given arguments."""
+    eps = []
+    for arg in args:
+        if hasattr(arg, "device"):
+            device = arg.device
+            if device.type == "cuda":
+                eps.append("CUDAExecutionProvider")
+            elif device.type == "cpu":
+                eps.append("CPUExecutionProvider")
+    return tuple(eps)
+
+
+def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
+    """Return the execution providers inferred from the devices of this torch.fx.GraphModule's outputs."""
+    for node in graph_module.graph.nodes:
+        if node.op == "output":
+            # The output node is unique. Retrieve the output values from
+            # this node's input list, then return.
+            flattened_output_args, _ = _pytree.tree_flatten(node.args)
+            output_args = []
+            for output_arg in flattened_output_args:
+                if hasattr(output_arg, "meta") and "val" in output_arg.meta:
+                    # Select outputs with "val" information. Without "val",
+                    # it's not possible to access output_arg.meta["val"].device.
+                    output_args.append(output_arg.meta["val"])
+            return _infer_ep_from_device(*output_args)
+    graph_module_str = graph_module.print_readable(print_output=False)
+    raise ValueError(f"No output node is found in graph_module: {graph_module_str}")
+
+
+def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
+    """Sort execution providers in eps based on pre-set priority."""
+
+    def get_execution_provider_priority(ep: str) -> int:
+        if ep == "CPUExecutionProvider":
+            # Lowest priority.
+            return 2
+        if ep == "CUDAExecutionProvider":
+            # Higher priority than CPU but lower than
+            # other specialized EPs.
+            return 1
+        # Highest priority.
+        return 0
+
+    unique_eps = set(eps)
+    return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))
 
 
 def _get_onnx_devices(values: Tuple[torch.Tensor, ...]) -> Tuple[ORTC.OrtDevice, ...]:  # type: ignore
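A small usage sketch of the helpers added above, assuming they are in scope at module level; the tensors are hypothetical and the expected values assume a CPU-only build:

import torch

# _infer_ep_from_device only looks at values that carry a .device attribute,
# so plain Python scalars are silently skipped.
print(_infer_ep_from_device(torch.zeros(2), 3, torch.ones(2)))
# -> ("CPUExecutionProvider", "CPUExecutionProvider")

# _sort_eps de-duplicates before ordering, so repeated providers collapse to one entry.
print(_sort_eps(("CPUExecutionProvider", "CPUExecutionProvider")))
# -> ("CPUExecutionProvider",)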
@@ -346,7 +390,7 @@ class OrtBackend:
     3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
     """
 
-    def __init__(self, ep: str = "", preallocate_output: bool = False, session_options=None):
+    def __init__(self, ep: str = "CPUExecutionProvider", preallocate_output: bool = False, session_options=None):
         self._supported_ops = OrtOperatorSupport()
         # TODO: this is a naive implementation of cache without proper guard
         self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
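With "CPUExecutionProvider" as the new default, the backend can be constructed without an explicit EP. A hedged usage sketch; my_model is a placeholder, and it assumes an OrtBackend instance is usable as a callable TorchDynamo backend, which the surrounding file is expected to arrange:

import onnxruntime
import torch

ort_backend = OrtBackend(
    ep="CUDAExecutionProvider",  # override the CPU default when targeting GPU
    session_options=onnxruntime.SessionOptions(),
)
# TorchDynamo accepts any callable of the form (GraphModule, example_inputs) -> callable.
compiled_model = torch.compile(my_model, backend=ort_backend)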
@@ -418,11 +462,28 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar
         ).SerializeToString()
 
         # Initialize a ORT session to execute this ONNX model.
-        # TorchDynamo assumes all inputs/outputs are on the same device,
-        # so we add execution provider only based on the first input's device.
-        ep = self.ep or _infer_ep_from_device(args[0].device)
+        # Note that TorchDynamo assumes all inputs/outputs are on the
+        # same device, but that is subject to change (very likely with
+        # dynamic shape support), so we add execution providers
+        # based on all inputs/outputs plus a default OrtBackend.ep.
+        eps_from_args = _infer_ep_from_device(*args)
+        eps_from_graph_module = _infer_ep_from_graph_module(graph_module)
+        if eps_from_args:
+            # If the user feeds a CUDA tensor as an input argument,
+            # we want to use the CUDA EP.
+            # Thus, `eps_from_args` (deduced from input arguments)
+            # has the highest priority.
+            selected_eps = _sort_eps((*eps_from_args, self.ep))
+        elif eps_from_graph_module:
+            # If no EP can be deduced from the input arguments, we deduce it from
+            # graph_module's outputs. Those outputs may come from
+            # FakeTensorProp or Dynamo's built-in symbolic shape inference.
+            selected_eps = _sort_eps((*eps_from_graph_module, self.ep))
+        else:
+            # No EP found in inputs or outputs; fall back to the default.
+            selected_eps = (self.ep,)
 
-        onnx_session = _create_onnx_session(onnx_proto, ep, self.session_options)
+        onnx_session = _create_onnx_session(onnx_proto, selected_eps, self.session_options)
         # Cache ORT session. It's reused for the same "graph_module".
         self._ort_execution_info.sessions[graph_module] = onnx_session
         # Generate ONNX model and extract its input and output names.
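For reference, the provider tuple selected above eventually reaches onnxruntime.InferenceSession through _create_onnx_session. A standalone sketch of that call; "model.onnx" and the provider list are illustrative only:

import onnxruntime

selected_eps = ("CUDAExecutionProvider", "CPUExecutionProvider")
session = onnxruntime.InferenceSession(
    "model.onnx",                    # placeholder path; the backend passes serialized proto bytes instead
    providers=list(selected_eps),    # earlier entries are preferred when ORT assigns nodes to providers
    sess_options=onnxruntime.SessionOptions(),
)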