Skip to content

PyTorch DirectML Operator Roadmap

Adele101 edited this page Aug 8, 2022 · 6 revisions

The table below comprises PyTorch operators that have at least one GPU kernel.

The is_implemented column indicates the following DirectML support:

  • ✅ = the op has at least one DirectML kernel registered
  • 🚧 = the op will be implemented in the next milestone (~3 months)

Please note that this is not a comprehensive list of all PyTorch operators, but rather a list of current support and plans for near future support. If you would require support for an operator that is not on this list, please file an issue.

op_name is_implemented
0 aten::abs_
1 aten::abs
2 aten::abs.out
3 aten::_adaptive_avg_pool2d
4 aten::adaptive_avg_pool2d
5 aten::adaptive_avg_pool2d.out
6 aten::_adaptive_avg_pool2d_backward
7 aten::add_
8 aten::add
9 aten::add.out
10 aten::add_.Scalar
11 aten::add.Scalar
12 aten::add_.Tensor
13 aten::add.Tensor
14 aten::addcdiv.out
15 aten::addcmul.out
16 aten::addmm
17 aten::addmm_
18 aten::addmm.out
19 aten::alias
20 aten::all
21 aten::all.dim
22 aten::all.out
23 aten::__and__
24 aten::any
25 aten::any.dim
26 aten::any.out
27 aten::arange
28 aten::arange.start_out
29 aten::as_strided_
30 aten::as_strided
31 aten::atan_
32 aten::atan
33 aten::atan.out
34 aten::avg_pool2d
35 aten::avg_pool2d.out
36 aten::avg_pool2d_backward
37 aten::avg_pool2d_backward.grad_input
38 aten::batch_norm
39 aten::_batch_norm_impl_index
40 aten::bernoulli
41 aten::bernoulli_
42 aten::bernoulli_.float
43 aten::bernoulli.out
44 aten::bernoulli_.Tensor
45 aten::binary_cross_entropy_with_logits
46 aten::binary_cross_entropy_with_logits_backward
47 aten::bitwise_and
48 aten::bitwise_and.Scalar_out
49 aten::bitwise_and.Tensor_out
50 aten::bitwise_or
51 aten::bitwise_or.Scalar_out
52 aten::bitwise_or.Tensor_out
53 aten::broadcast_tensors
54 aten::broadcast_to
55 aten::can_cast
56 aten::cat
57 aten::ceil
58 aten::ceil_
59 aten::ceil.out
60 aten::chunk
61 aten::clamp
62 aten::clamp.out
63 aten::clamp_max.out
64 aten::clamp_min.out
65 aten::clone
66 aten::contiguous
67 aten::conv2d
68 aten::convolution
69 aten::_convolution
70 aten::convolution_backward
71 aten::convolution_backward_overrideable
72 aten::convolution_overrideable
73 aten::conv_transpose2d
74 aten::copy_
75 aten::_copy_from
76 aten::cross_entropy_loss
77 aten::detach
78 aten::detach_
79 aten::div_
80 aten::div
81 aten::div.out
82 aten::div.out_mode
83 aten::div_.Scalar
84 aten::div.Scalar
85 aten::div.Tensor
86 aten::div_.Tensor
87 aten::div.Tensor_mode
88 aten::div_.Tensor_mode
89 aten::dropout
90 aten::dropout_
91 aten::empty
92 aten::empty.memory_format
93 aten::empty_like
94 aten::empty_strided
95 aten::eq
96 aten::eq.Scalar
97 aten::eq_.Scalar
98 aten::eq.Scalar_out
99 aten::eq_.Tensor
100 aten::eq.Tensor
101 aten::eq.Tensor_out
102 aten::exp
103 aten::exp_
104 aten::exp.out
105 aten::exp2.out
106 aten::expand
107 aten::expand_as
108 aten::expm1.out
109 aten::fill_
110 aten::fill_.Scalar
111 aten::flatten
112 aten::flip
113 aten::floor_
114 aten::floor
115 aten::floor.out
116 aten::full
117 aten::full_like
118 aten::ge
119 aten::ge_.Scalar
120 aten::ge.Scalar
121 aten::ge.Scalar_out
122 aten::ge_.Tensor
123 aten::ge.Tensor
124 aten::ge.Tensor_out
125 aten::gt
126 aten::gt_.Scalar
127 aten::gt.Scalar
128 aten::gt.Scalar_out
129 aten::gt_.Tensor
130 aten::gt.Tensor
131 aten::gt.Tensor_out
132 aten::hardsigmoid
133 aten::hardsigmoid_
134 aten::hardsigmoid.out
135 aten::hardsigmoid_backward
136 aten::hardswish
137 aten::hardswish_
138 aten::hardswish.out
139 aten::hardswish_backward
140 aten::hardtanh_
141 aten::hardtanh
142 aten::hardtanh.out
143 aten::hardtanh_backward
144 aten::hardtanh_backward.grad_input
145 aten::index
146 aten::index.Tensor
147 aten::index_put_
148 aten::index_put
149 aten::_index_put_impl_
150 aten::index_select
151 aten::index_select.out
152 aten::is_nonzero
153 aten::item
154 aten::l1_loss
155 aten::l1_loss_backward
156 aten::le
157 aten::le.Scalar
158 aten::le_.Scalar
159 aten::le.Scalar_out
160 aten::le_.Tensor
161 aten::le.Tensor
162 aten::le.Tensor_out
163 aten::leaky_relu
164 aten::leaky_relu_
165 aten::leaky_relu.out
166 aten::leaky_relu_backward
167 aten::lift
168 aten::linear
169 aten::_local_scalar_dense
170 aten::log
171 aten::log_
172 aten::log.out
173 aten::log10
174 aten::log10_
175 aten::log10.out
176 aten::log1p
177 aten::log1p_
178 aten::log1p.out
179 aten::log2
180 aten::log2_
181 aten::log2.out
182 aten::logical_and
183 aten::logical_and_
184 aten::logical_and.out
185 aten::logical_not_
186 aten::logical_not
187 aten::logical_not.out
188 aten::logical_or_
189 aten::logical_or
190 aten::logical_or.out
191 aten::logical_xor
192 aten::logical_xor_
193 aten::logical_xor.out
194 aten::_log_softmax
195 aten::log_softmax
196 aten::_log_softmax_backward_data
197 aten::lt
198 aten::lt_.Scalar
199 aten::lt.Scalar
200 aten::lt.Scalar_out
201 aten::lt.Tensor
202 aten::lt_.Tensor
203 aten::lt.Tensor_out
204 aten::masked_fill_
205 aten::masked_fill_.Scalar
206 aten::masked_fill_.Tensor
207 aten::masked_select
208 aten::masked_select.out
209 aten::max
210 aten::max.dim
211 aten::maximum
212 aten::maximum.out
213 aten::max_pool2d
214 aten::max_pool2d_with_indices
215 aten::max_pool2d_with_indices.out
216 aten::max_pool2d_with_indices_backward
217 aten::max_pool2d_with_indices_backward.grad_input
218 aten::mean
219 aten::mean.dim
220 aten::mean.out
221 aten::meshgrid
222 aten::min
223 aten::minimum
224 aten::minimum.out
225 aten::mm
226 aten::mm.out
227 aten::mul
228 aten::mul_
229 aten::mul.out
230 aten::mul_.Scalar
231 aten::mul.Scalar
232 aten::mul_.Tensor
233 aten::mul.Tensor
234 aten::narrow
235 aten::native_batch_norm
236 aten::native_batch_norm_backward
237 aten::ne
238 aten::ne.Scalar
239 aten::ne_.Scalar
240 aten::ne.Scalar_out
241 aten::ne_.Tensor
242 aten::ne.Tensor
243 aten::ne.Tensor_out
244 aten::neg
245 aten::neg_
246 aten::neg.out
247 aten::new_empty
248 aten::new_full
249 aten::new_zeros
250 aten::nll_loss
251 aten::nll_loss_backward
252 aten::nll_loss_backward.grad_input
253 aten::nll_loss_forward
254 aten::nll_loss_nd
255 aten::nonzero
256 aten::nonzero.out
257 aten::nonzero_numpy
258 aten::ones_like
259 aten::__or__
260 aten::permute
261 aten::pow
262 aten::pow_.Scalar
263 aten::pow_.Tensor
264 aten::pow.Tensor_Scalar
265 aten::pow.Tensor_Scalar_out
266 aten::pow.Tensor_Tensor
267 aten::pow.Tensor_Tensor_out
268 aten::prelu
269 aten::prelu_backward
270 aten::prod
271 aten::prod.dim_int
272 aten::prod.int_out
273 aten::random_
274 aten::random_.from
275 aten::random_.to
276 aten::randperm
277 aten::randperm.generator_out
278 aten::reciprocal_
279 aten::reciprocal
280 aten::reciprocal.out
281 aten::relu_
282 aten::relu
283 aten::remainder.Scalar
284 aten::remainder_.Scalar
285 aten::remainder.Scalar_out
286 aten::remainder_.Tensor
287 aten::remainder.Tensor
288 aten::remainder.Tensor_out
289 aten::repeat
290 aten::reshape
291 aten::_reshape_alias
292 aten::resize_
293 aten::result_type
294 aten::round_
295 aten::round
296 aten::round.out
297 aten::rsqrt_
298 aten::rsqrt
299 aten::rsqrt.out
300 aten::rsub
301 aten::rsub.Scalar
302 aten::rsub.Tensor
303 aten::scalar_tensor
304 aten::scatter_
305 aten::scatter_.src
306 aten::select
307 aten::sgn
308 aten::sgn_
309 aten::sgn.out
310 aten::sigmoid_
311 aten::sigmoid
312 aten::sigmoid.out
313 aten::sigmoid_backward
314 aten::sigmoid_backward.grad_input
315 aten::sign
316 aten::sign_
317 aten::sign.out
318 aten::silu
319 aten::silu_
320 aten::silu.out
321 aten::slice
322 aten::slice_backward
323 aten::smooth_l1_loss
324 aten::smooth_l1_loss_backward
325 aten::_softmax
326 aten::sort
327 aten::split
328 aten::split_with_sizes
329 aten::sqrt
330 aten::sqrt_
331 aten::sqrt.out
332 aten::squeeze
333 aten::squeeze_
334 aten::stack
335 aten::sub
336 aten::sub.out
337 aten::sub_.Tensor
338 aten::sub.Tensor
339 aten::sum
340 aten::sum.dim_IntList
341 aten::sum.IntList_out
342 aten::t
343 aten::threshold_
344 aten::threshold
345 aten::threshold.out
346 aten::threshold_backward
347 aten::to
348 aten::_to_copy
349 aten::topk
350 aten::topk.values
351 aten::transpose
352 aten::unbind
353 aten::unfold
354 aten::uniform_
355 aten::_unique2
356 aten::_unsafe_view
357 aten::unsqueeze_
358 aten::unsqueeze
359 aten::upsample_bilinear2d
360 aten::upsample_bilinear2d.out
361 aten::upsample_bilinear2d.vec
362 aten::upsample_bilinear2d_backward
363 aten::upsample_bilinear2d_backward.grad_input
364 aten::upsample_bilinear2d_backward.vec
365 aten::upsample_nearest2d
366 aten::upsample_nearest2d.out
367 aten::upsample_nearest2d_backward
368 aten::upsample_nearest2d_backward.grad_input
369 aten::value_selecting_reduction_backward
370 aten::view
371 aten::where
372 aten::where.self
373 aten::zero_
374 aten::zeros
375 aten::zeros_like
376 aten::addcdiv_ 🚧
377 aten::addcmul_ 🚧
378 aten::argmax 🚧
379 aten::baddbmm 🚧
380 aten::bitwise_not 🚧
381 aten::bmm 🚧
382 aten::cdist 🚧
383 aten::_cdist_forward 🚧
384 aten::cos 🚧
385 aten::cumsum 🚧
386 aten::diag 🚧
387 aten::diag_backward 🚧
388 aten::embedding 🚧
389 aten::embedding_backward 🚧
390 aten::embedding_dense_backward 🚧
391 aten::floor_divide 🚧
392 aten::gelu 🚧
393 aten::gelu_backward 🚧
394 aten::isinf 🚧
395 aten::isnan 🚧
396 aten::layer_norm 🚧
397 aten::matmul 🚧
398 aten::native_layer_norm 🚧
399 aten::native_layer_norm_backward 🚧
400 aten::nll_loss2d 🚧
401 aten::nll_loss2d_backward 🚧
402 aten::nll_loss2d_forward 🚧
403 aten::norm 🚧
404 aten::select_backward 🚧
405 aten::sin 🚧
406 aten::softmax 🚧
407 aten::_softmax_backward_data 🚧
408 aten::tanh 🚧
409 aten::tanh_backward 🚧
Clone this wiki locally