Commit 44fd98e

[DORT] Enable aten::full by implementing extra logic to select EP (microsoft#16699)
DORT only selects devices from input arguments (type: torch.Tensor), so it errors out when a graph has no inputs at all (e.g., a graph consisting of a single aten::full). This PR addresses the problem by changing EP selection to:

1. First, inspect graph inputs. If there are valid devices, use them plus a default one (`OrtBackend.ep: str`).
2. Otherwise, inspect the graph outputs carried by `torch.fx.GraphModule` and use all valid devices plus the default `OrtBackend.ep`.
3. When both (1) and (2) fail, use the default EP specified by `OrtBackend.ep`.
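For readers skimming the diff, a minimal standalone sketch of that three-way fallback follows; the helper name `select_eps` and its signature are illustrative only (the PR implements this logic inline in `_ort_acclerated_call`):

```python
from typing import Tuple


def select_eps(
    default_ep: str,
    eps_from_args: Tuple[str, ...],
    eps_from_graph_outputs: Tuple[str, ...],
) -> Tuple[str, ...]:
    # (1) Devices found on input tensors have the highest priority.
    if eps_from_args:
        return (*eps_from_args, default_ep)
    # (2) Otherwise use devices recorded on the graph's outputs.
    if eps_from_graph_outputs:
        return (*eps_from_graph_outputs, default_ep)
    # (3) No device anywhere (e.g., a lone aten::full): default EP only.
    return (default_ep,)


# A graph without inputs or device-annotated outputs hits case (3).
assert select_eps("CPUExecutionProvider", (), ()) == ("CPUExecutionProvider",)
```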
1 parent f236768 commit 44fd98e

File tree: 2 files changed (+126 −13 lines)


orttraining/orttraining/python/training/torchdynamo/ort_backend.py

+74 −13
@@ -30,6 +30,7 @@
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupport
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
+from torch.utils import _pytree

 import onnxruntime  # type: ignore
 from onnxruntime.capi import _pybind_state as ORTC
@@ -199,16 +200,59 @@ def _create_onnx_model(onnx_proto):
     return onnx.ModelProto.FromString(onnx_proto)


-def _create_onnx_session(onnx_proto, ep: str, session_options):
+def _create_onnx_session(onnx_proto, eps: Tuple[str, ...], session_options):
     # TODO(wechi): Add more EPs per PyTorch device types.
     # TODO(wechi): enable external allocators.
-    return onnxruntime.InferenceSession(onnx_proto, providers=[ep], sess_options=session_options)
-
-
-def _infer_ep_from_device(device):
-    if device.type == "cuda":
-        return "CUDAExecutionProvider"
-    return "CPUExecutionProvider"
+    return onnxruntime.InferenceSession(onnx_proto, providers=eps, sess_options=session_options)
+
+
+def _infer_ep_from_device(*args) -> Tuple[str, ...]:
+    """Return the first valid device (i.e., GPU or CPU) in argument list."""
+    eps = []
+    for arg in args:
+        if hasattr(arg, "device"):
+            device = arg.device
+            if device.type == "cuda":
+                eps.append("CUDAExecutionProvider")
+            elif device.type == "cpu":
+                eps.append("CPUExecutionProvider")
+    return tuple(eps)
+
+
+def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
+    """Return the first valid device (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule."""
+    for node in graph_module.graph.nodes:
+        if node.op == "output":
+            # Output node is unique. Let's retrieve output values from
+            # this node's input list. And then just return.
+            flattened_output_args, _ = _pytree.tree_flatten(node.args)
+            output_args = []
+            for output_arg in flattened_output_args:
+                if hasattr(output_arg, "meta") and "val" in output_arg.meta:
+                    # Select outputs with "val" information. Without "val",
+                    # it's not possible to access output_arg.meta["val"].device.
+                    output_args.append(output_arg.meta["val"])
+            return _infer_ep_from_device(*output_args)
+    graph_module_str = graph_module.print_readable(print_output=False)
+    raise ValueError(f"No output node is found in graph_module: {graph_module_str}")
+
+
+def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
+    """Sort execution providers in eps based on pre-set priority."""
+
+    def get_execution_provider_priority(ep: str) -> int:
+        if ep == "CPUExecutionProvider":
+            # Lowest priority.
+            return 2
+        if ep == "CUDAExecutionProvider":
+            # Higher priority than CPU but lower than
+            # other specialized EPs.
+            return 1
+        # Highest priority.
+        return 0
+
+    unique_eps = set(eps)
+    return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))


 def _get_onnx_devices(values: Tuple[torch.Tensor, ...]) -> Tuple[ORTC.OrtDevice, ...]:  # type: ignore
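To make the new helpers concrete, a small usage sketch; the import path is inferred from this file's location and may differ per install, and CUDA tensors map to "CUDAExecutionProvider" the same way:

```python
import torch

from onnxruntime.training.torchdynamo.ort_backend import (
    _infer_ep_from_device,
    _sort_eps,
)

# Tensor arguments map device -> EP; non-tensor arguments are skipped.
assert _infer_ep_from_device(torch.zeros(2), 1.5) == ("CPUExecutionProvider",)

# With no device-carrying arguments the result is empty; this is exactly
# the case that used to break a no-input graph before this commit.
assert _infer_ep_from_device() == ()

# _sort_eps deduplicates before ordering by its preset priority.
assert len(_sort_eps(("CPUExecutionProvider", "CPUExecutionProvider"))) == 1
```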
@@ -346,7 +390,7 @@ class OrtBackend:
     3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
     """

-    def __init__(self, ep: str = "", preallocate_output: bool = False, session_options=None):
+    def __init__(self, ep: str = "CPUExecutionProvider", preallocate_output: bool = False, session_options=None):
         self._supported_ops = OrtOperatorSupport()
         # TODO: this is a naive implementation of cache without proper guard
         self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
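With the new default, a backend constructed without arguments always has a usable fallback EP, even for device-less graphs. A minimal sketch, assuming the same import path as above:

```python
from onnxruntime.training.torchdynamo.ort_backend import OrtBackend

# The default is now "CPUExecutionProvider" instead of an empty string,
# so EP selection always has at least one valid provider to fall back on.
backend = OrtBackend()
assert backend.ep == "CPUExecutionProvider"

# An explicit provider still takes effect as the appended default EP
# (running it requires a GPU-enabled onnxruntime build).
cuda_backend = OrtBackend(ep="CUDAExecutionProvider")
```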
@@ -418,11 +462,28 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
             ).SerializeToString()

             # Initialize a ORT session to execute this ONNX model.
-            # TorchDynamo assumes all inputs/outputs are on the same device,
-            # so we add execution provider only based on the first input's device.
-            ep = self.ep or _infer_ep_from_device(args[0].device)
+            # Note that TorchDynamo assumes all inputs/outputs are on the
+            # same device, but it's subject to change (very likely with
+            # dynamic shape support), so we add execution providers
+            # based on all the inputs/outputs plus a default OrtBackend.ep.
+            eps_from_args = _infer_ep_from_device(args)
+            eps_from_graph_module = _infer_ep_from_graph_module(graph_module)
+            if eps_from_args:
+                # If the user feeds a CUDA tensor as input argument,
+                # we want to use CUDA EP.
+                # Thus, `eps_from_args` (deduced from input arguments)
+                # has highest priority.
+                selected_eps = _sort_eps((*eps_from_args, self.ep))
+            elif eps_from_graph_module:
+                # If there is no EP in input arguments, we deduce EP from
+                # graph_module's outputs. Those outputs may come from
+                # FakeTensorProp or Dynamo's built-in symbolic shape inference.
+                selected_eps = _sort_eps((*eps_from_graph_module, self.ep))
+            else:
+                # No EP found in inputs and outputs, let's use default.
+                selected_eps = (self.ep,)

-            onnx_session = _create_onnx_session(onnx_proto, ep, self.session_options)
+            onnx_session = _create_onnx_session(onnx_proto, selected_eps, self.session_options)
             # Cache ORT session. It's reused for the same "graph_module".
             self._ort_execution_info.sessions[graph_module] = onnx_session
             # Generate ONNX model and extract its input and output names.
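The selected tuple feeds straight into onnxruntime.InferenceSession, which tries providers in list order, so the first entry has the highest precedence and CPU remains a guaranteed fallback. A self-contained illustration with a trivial hand-built ONNX model (any serialized model would do):

```python
import onnx
import onnxruntime
from onnx import TensorProto, helper

# Build a tiny Identity model so the snippet runs without external files.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2])
graph = helper.make_graph([helper.make_node("Identity", ["x"], ["y"])], "g", [x], [y])
model = helper.make_model(graph)

# Listing CUDA first (on a GPU build) would prefer it while keeping CPU
# available; here we use the CPU-only ordering.
selected_eps = ("CPUExecutionProvider",)
session = onnxruntime.InferenceSession(
    model.SerializeToString(), providers=list(selected_eps)
)
print(session.get_providers())  # ['CPUExecutionProvider']
```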

orttraining/orttraining/test/python/orttraining_test_dort.py

+52
@@ -118,6 +118,58 @@ def run(fun, list_x):

         run_to_copy()

+    def test_aten_full(self):
+        torch._dynamo.reset()
+
+        def run_no_input_model():
+            # A function to test.
+            def no_input_model():
+                return torch.ops.aten.full([2, 3], 1.5)
+
+            @torch._dynamo.optimize(aot_ort)
+            def optimized_no_input_model():
+                return no_input_model()
+
+            def run(fun):
+                tensor_x = fun()
+                return tensor_x
+
+            # Baseline.
+            tensor_x = run(no_input_model)
+            # ORT result.
+            tensor_x_new = run(optimized_no_input_model)
+
+            torch.testing.assert_close(tensor_x, tensor_x_new)
+
+        for _ in range(5):
+            run_no_input_model()
+
+    def test_aten_full_with_device(self):
+        torch._dynamo.reset()
+
+        def run_no_input_model():
+            # A function to test.
+            def no_input_model():
+                return torch.ops.aten.full([2, 3], 1.5, device="cpu")
+
+            @torch._dynamo.optimize(aot_ort)
+            def optimized_no_input_model():
+                return no_input_model()
+
+            def run(fun):
+                tensor_x = fun()
+                return tensor_x
+
+            # Baseline.
+            tensor_x = run(no_input_model)
+            # ORT result.
+            tensor_x_new = run(optimized_no_input_model)
+
+            torch.testing.assert_close(tensor_x, tensor_x_new)
+
+        for _ in range(5):
+            run_no_input_model()
+
     def test_mnist_model(self):
         torch._dynamo.reset()
         """Test DORT with a simple nn.Module."""

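Outside the test harness, the user-visible effect is that a Dynamo-compiled function with no tensor inputs now runs under DORT. A minimal repro, assuming the usual `aot_ort` backend import used by the tests above:

```python
import torch
import torch._dynamo
from onnxruntime.training.torchdynamo.register_backend import aot_ort


@torch._dynamo.optimize(aot_ort)
def make_constant():
    # No tensor inputs: before this commit, EP selection had no device
    # to inspect here and DORT errored out.
    return torch.full([2, 3], 1.5)


print(make_constant())  # 2x3 tensor filled with 1.5
```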