Skip to content

Commit 02ca225

Browse files
kulinsethpytorchmergebot
authored andcommitted
[MPS] Fixes for Binary ops with casting issues from FP to uint8 (pytorch#94382)
Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#94382 Approved by: https://github.com/razarmehr
1 parent e0e4f1a commit 02ca225

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

aten/src/ATen/native/mps/operations/Copy.mm

+5-11
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,7 @@
11
// Copyright © 2022 Apple Inc.
22

3-
#include <ATen/mps/MPSStream.h>
43
#include <ATen/native/mps/Copy.h>
54
#include <ATen/native/mps/OperationUtils.h>
6-
#include <iostream>
7-
#include <cstring>
8-
#include <ATen/ATen.h>
9-
#include <ATen/Tensor.h>
10-
#include <ATen/Utils.h>
11-
#include <torch/library.h>
12-
#include <ATen/native/Resize.h>
13-
#include <c10/util/Optional.h>
14-
155

166
namespace at::native {
177
namespace mps {
@@ -84,7 +74,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
8474
newCachedGraph = new CachedGraph(mpsGraph);
8575

8676
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);
87-
MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputTensor toType:dstDType name:@"cast"];
77+
MPSGraphTensor* inputCastTensor = inputTensor;
78+
if (isFloatingType(src.scalar_type()) && dstDType == MPSDataTypeUInt8) {
79+
inputCastTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"cast"];
80+
}
81+
MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputCastTensor toType:dstDType name:@"cast"];
8882

8983
newCachedGraph->inputTensor_ = inputTensor;
9084
newCachedGraph->outputTensor_ = outputTensor;

test/test_mps.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -8570,8 +8570,10 @@ class TestConsistency(TestCase):
85708570
'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
85718571
'bmm': ['f32'],
85728572
'broadcast_shapes': ['f32'],
8573+
'byte': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8574+
'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
85738575
'ceil': ['f32', 'int32', 'int64', 'f16'],
8574-
'char': ['b8', 'u8'],
8576+
'char': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
85758577
'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
85768578
'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'],
85778579
'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8607,17 +8609,19 @@ class TestConsistency(TestCase):
86078609
'flip': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86088610
'fliplr': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86098611
'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8610-
'float': ['f32'],
8612+
'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86118613
'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
86128614
'floor_divide': ['f32', 'f16'],
86138615
'frac': ['f16', 'f32'],
86148616
'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86158617
'gradient': ['f16', 'f32', 'i16'],
8616-
'half': ['f16'],
8618+
'ge': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8619+
'gt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8620+
'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86178621
'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86188622
'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86198623
'index_add': ['f16', 'f32', 'i16', 'i32'],
8620-
'int': ['i32'],
8624+
'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86218625
'isclose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86228626
'isfinite': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86238627
'isinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8724,7 +8728,7 @@ class TestConsistency(TestCase):
87248728
'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
87258729
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
87268730
'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8727-
'short': ['i16'],
8731+
'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
87288732
'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
87298733
'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
87308734
'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],

0 commit comments

Comments
 (0)