Skip to content

Commit 262bd64

Browse files
smessmer authored and facebook-github-bot committed
Show old kernel location when there are mismatches (pytorch#46850)
Summary:
Pull Request resolved: pytorch#46850

So far, in the error messages when kernel signatures mismatched, we showed the location where the second kernel came from, but we didn't show the location of the first kernel. This PR now shows the location of both.

ghstack-source-id: 115468616

Test Plan: waitforsandcastle

Reviewed By: ezyang

Differential Revision: D24540368

fbshipit-source-id: 3b4474062879d17f9bb7870ad3814343edc1b755
1 parent dfdc1db commit 262bd64

File tree

4 files changed

+36
-20
lines changed

4 files changed

+36
-20
lines changed

aten/src/ATen/BatchingRegistrations.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,7 @@ Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
222222
VmapDimVector all_dims_physical;
223223
all_dims_physical.reserve(self_physical.tensor().dim());
224224
for (int64_t bdim = 0; bdim < self_physical.numBatchDims(); bdim++) {
225-
all_dims_physical.push_back(bdim);
225+
all_dims_physical.push_back(bdim);
226226
}
227227
all_dims_physical.insert(
228228
all_dims_physical.end(),

aten/src/ATen/core/dispatch/OperatorEntry.cpp

+13-6
Original file line number | Diff line number | Diff line change
@@ -83,13 +83,15 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
8383
// that would also invalidate the old TypedOperatorHandles.
8484
if (cpp_signature.has_value()) {
8585
if (cpp_signature_.has_value()) {
86-
TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_,
87-
"Tried to register a kernel (", debug, ") for operator ", name_," for dispatch key ", toString(dispatch_key),
88-
", but the C++ function signature ", cpp_signature->name(), " mismatched with a previous kernel that had the signature ",
89-
cpp_signature_->name()
86+
TORCH_INTERNAL_ASSERT(*cpp_signature == cpp_signature_->signature,
87+
"Tried to register a kernel (", debug, ") for operator ", name_," (",
88+
(this->schema_.has_value() ? this->schema_->debug : "no debug info"),
89+
") for dispatch key ", toString(dispatch_key), ", but the C++ function signature ",
90+
cpp_signature->name(), " mismatched with a previous kernel (", cpp_signature_->debug,
91+
") that had the signature ", cpp_signature_->signature.name()
9092
);
9193
} else {
92-
cpp_signature_ = *cpp_signature;
94+
cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug };
9395
}
9496
}
9597

@@ -103,7 +105,12 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
103105
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math];
104106

105107
if (k.size() > 0) {
106-
TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
108+
TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " (",
109+
(this->schema_.has_value() ? this->schema_->debug : "no debug info"),
110+
") for dispatch key ", toString(dispatch_key),
111+
" that overwrote a previously registered kernel (",
112+
(cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"),
113+
") with the same dispatch key for the same operator.");
107114
}
108115

109116
if (manuallyBoxedKernel_.has_value()) {

aten/src/ATen/core/dispatch/OperatorEntry.h

+13-6
Original file line number | Diff line number | Diff line change
@@ -157,13 +157,15 @@ class CAFFE2_API OperatorEntry final {
157157
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
158158
template<class FuncType>
159159
void assertSignatureIsCorrect() {
160-
TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make<FuncType>() == *cpp_signature_),
160+
TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make<FuncType>() == cpp_signature_->signature),
161161
"Tried to access operator ", name_, " with a wrong signature. Accessed with ",
162162
CppSignature::make<FuncType>().name(),
163163
" but the operator was registered with ",
164-
cpp_signature_->name(),
165-
" (",
164+
cpp_signature_->signature.name(),
165+
" (schema: ",
166166
(schema_.has_value() ? schema_->debug : "unknown debug info"),
167+
", kernel: ",
168+
cpp_signature_->debug,
167169
") This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). Please make sure that the function signature matches the signature in the operator registration call."
168170
);
169171
}
@@ -230,12 +232,17 @@ class CAFFE2_API OperatorEntry final {
230232
AnnotatedKernel missingKernel_;
231233
static const AnnotatedKernel ambiguousAutogradOtherKernel_;
232234

233-
// signature_hash_ is set to the hash of the function signature if any of
235+
// cpp_signature_ stores function signature if any of
234236
// the kernels was created in a way that allowed us to know the function
235237
// signature (i.e. by supplying an unboxed C++ kernel function).
236-
// If this is set, it will be used in unboxed function calls
238+
// If this is set, it will be used to check that future kernel
239+
// registrations match and it will be used in unboxed function calls
237240
// to verify their arguments against the known function signature.
238-
c10::optional<CppSignature> cpp_signature_;
241+
struct CppSignatureWithDebug {
242+
CppSignature signature;
243+
std::string debug;
244+
};
245+
c10::optional<CppSignatureWithDebug> cpp_signature_;
239246

240247
// Whether this operator needs to be observed with RecordFunction
241248
const bool is_observed_;

aten/src/ATen/core/op_registration/op_registration_test.cpp

+9-7
Original file line number | Diff line number | Diff line change
@@ -310,7 +310,8 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegis
310310
std::string output = testing::internal::GetCapturedStderr();
311311
EXPECT_THAT(output, testing::HasSubstr("_test::dummy"));
312312
EXPECT_THAT(output, testing::HasSubstr("CPU"));
313-
EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator"));
313+
EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel "));
314+
EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator"));
314315
}
315316

316317
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) {
@@ -348,7 +349,8 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegistering_then
348349
std::string output = testing::internal::GetCapturedStderr();
349350
EXPECT_THAT(output, testing::HasSubstr("_test::dummy"));
350351
EXPECT_THAT(output, testing::HasSubstr("catch all"));
351-
EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator"));
352+
EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel "));
353+
EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator"));
352354
}
353355

354356
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) {
@@ -701,7 +703,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the
701703
auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
702704
.kernel<DummyKernelWithIntParam>(c10::DispatchKey::CPU)
703705
.kernel<MockKernel>(c10::DispatchKey::CUDA, &called_kernel));
704-
}, "mismatched with a previous kernel that had the signature");
706+
}, "mismatched with a previous kernel");
705707
}
706708

707709
void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) {
@@ -944,7 +946,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC
944946
expectThrows<c10::Error>([] {
945947
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
946948
.kernel(DispatchKey::CPU, [] (const int64_t&) {}));
947-
}, "mismatched with a previous kernel that had the signature");
949+
}, "mismatched with a previous kernel");
948950
}
949951

950952
TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) {
@@ -953,7 +955,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBacke
953955
expectThrows<c10::Error>([] {
954956
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
955957
.kernel(DispatchKey::CPU, [] (const int64_t&) {}));
956-
}, "mismatched with a previous kernel that had the signature");
958+
}, "mismatched with a previous kernel");
957959
}
958960

959961
TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) {
@@ -962,7 +964,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchA
962964
expectThrows<c10::Error>([] {
963965
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
964966
.catchAllKernel([] (const int64_t&) {}));
965-
}, "mismatched with a previous kernel that had the signature");
967+
}, "mismatched with a previous kernel");
966968
}
967969

968970
TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) {
@@ -989,7 +991,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingC
989991
m.impl("dummy", DispatchKey::CPU, [] (int64_t) {});
990992
expectThrows<c10::Error>([&] {
991993
m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {});
992-
}, "mismatched with a previous kernel that had the signature");
994+
}, "mismatched with a previous kernel");
993995
}
994996

995997
TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) {

0 commit comments

Comments
 (0)