Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions clang/lib/DPCT/RulesInclude/InclusionHeaders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "InclusionHeaders.h"
#include "PreProcessor.h"
#include <optional>

namespace clang {
namespace dpct {
Expand All @@ -33,7 +34,7 @@ class LastInclusionLocationUpdater {
bool UpdateNeeded;
};

std::string applyUserDefinedHeader(const std::string &FileName) {
std::optional<std::string> applyUserDefinedHeader(const std::string &FileName) {
// Apply user-defined rule if needed
auto It = MapNames::HeaderRuleMap.find(FileName);
if (It != MapNames::HeaderRuleMap.end() &&
Expand All @@ -54,11 +55,12 @@ std::string applyUserDefinedHeader(const std::string &FileName) {
for (auto &Header : Rule.Includes) {
PrintHeader(Header);
}
PrintHeader(Rule.Out);
if (!Rule.Out.empty())
PrintHeader(Rule.Out);
OS << Rule.Postfix;
return ReplHeaderStr;
}
return "";
return std::nullopt;
}

void insertHeaders(std::shared_ptr<DpctFileInfo> File,
Expand Down Expand Up @@ -150,6 +152,12 @@ void IncludesCallbacks::InclusionDirective(
Updater.give_up();
};

// Apply user-defined rule if needed
if (auto ReplacedStr = applyUserDefinedHeader(FileName.str()); ReplacedStr) {
EmplaceReplacement(std::move(ReplacedStr.value()));
return;
}

if (Global.isInAnalysisScope(IncludedFile)) {
IncludeFileMap[IncludedFile] = false;
Global.getIncludingFileSet().insert(IncludedFile);
Expand Down Expand Up @@ -208,14 +216,6 @@ void IncludesCallbacks::InclusionDirective(
!Global.getSourceManager().isWrittenInMainFile(HashLoc))
return;


// Apply user-defined rule if needed
if (auto ReplacedStr = applyUserDefinedHeader(FileName.str());
!ReplacedStr.empty()) {
EmplaceReplacement(std::move(ReplacedStr));
return;
}

do {
auto InfoPtr = DpctInclusionHeadersMap::findHeaderInfo(FileName);
if (!InfoPtr)
Expand Down
73 changes: 71 additions & 2 deletions clang/lib/DPCT/UserDefinedRules/UserDefinedRules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,19 @@ void registerMacroRule(MetaRuleObject &R) {
}
}

static std::map<std::string, OutputBuilder> APIRulesMap;

void registerAPIRule(MetaRuleObject &R) {
using namespace clang::dpct;
// register rule
// register all rules for CXXConstructExpr
OutputBuilder OB;
OB.Kind = OutputBuilder::Kind::Top;
OB.RuleName = R.RuleId;
OB.RuleFile = R.RuleFile;
APIRulesMap[R.In] = OB;
APIRulesMap[R.In].parse(R.Out);

// register all rules for CallExpr
registerMigrationRule(
R.RuleId, [In = R.In, HET = R.RuleAttributes.HasExplicitTemplateArgs] {
return std::make_unique<clang::dpct::UserDefinedAPIRule>(In, HET);
Expand Down Expand Up @@ -673,7 +683,7 @@ int OutputBuilder::consumeArgIndex(std::string &OutStr, size_t &Idx,
}
Idx = i;

if (Idx >= OutStr.size()) {
if (Idx > OutStr.size()) {
llvm::errs() << RuleFile << ":Error: in rule " << RuleName
<< ", a positive integer is expected after " << Keyword
<< "\n";
Expand Down Expand Up @@ -827,6 +837,39 @@ void clang::dpct::UserDefinedAPIRule::registerMatcher(
APIName, HasExplicitTemplateArgs)))))))
.bind("call"),
this);
MF.addMatcher(
cxxConstructExpr(hasType(namedDecl(hasAnyName(APIName)))).bind("ctor"),
this);
}

// Render the replacement text for a user-defined API rule applied to a
// constructor call by walking the OutputBuilder tree rooted at OB and
// streaming each piece into OS.
static void buildRewriterStrForCtor(const CXXConstructExpr *Ctor,
                                    llvm::raw_string_ostream &OS,
                                    const OutputBuilder &OB) {
  switch (OB.Kind) {
  case OutputBuilder::Kind::Top:
    // Composite node: emit every child builder in declaration order.
    for (const auto &Sub : OB.SubBuilders)
      buildRewriterStrForCtor(Ctor, OS, *Sub);
    break;
  case OutputBuilder::Kind::String:
    // Verbatim text taken from the rule's output pattern.
    OS << OB.Str;
    break;
  case OutputBuilder::Kind::Arg:
    // Argument placeholder: splice in the rewritten constructor argument.
    // An out-of-range index expands to nothing.
    if (OB.ArgIndex < Ctor->getNumArgs()) {
      ArgumentAnalysis AA;
      AA.setCallSpelling(Ctor);
      AA.analyze(Ctor->getArg(OB.ArgIndex));
      OS << AA.getRewriteString();
    }
    break;
  default:
    DpctDebugs() << "[OutputBuilder::Kind] Unexpected value: " << OB.Kind
                 << "\n";
    assert(0);
  }
}

void clang::dpct::UserDefinedAPIRule::runRule(
Expand All @@ -836,6 +879,32 @@ void clang::dpct::UserDefinedAPIRule::runRule(
EA.analyze(CE);
emplaceTransformation(EA.getReplacement());
EA.applyAllSubExprRepl();
} else if (const CXXConstructExpr *CCE =
getAssistNodeAsType<CXXConstructExpr>(Result, "ctor")) {
if (CCE->getNumArgs()) {
PrintingPolicy PP = DpctGlobalInfo::getContext().getPrintingPolicy();
PP.SuppressTagKeyword = true;
auto Iter =
APIRulesMap.find(CCE->getType().getCanonicalType().getAsString(PP));
if (Iter == APIRulesMap.end())
return;
std::string Repl;
{
llvm::raw_string_ostream OS(Repl);
buildRewriterStrForCtor(CCE, OS, Iter->second);
}
auto Range =
getDefinitionRange(CCE->getArg(0)->getBeginLoc(),
CCE->getArg(CCE->getNumArgs() - 1)->getEndLoc());
auto BeginLoc = Range.getBegin();
auto EndLoc = Range.getEnd();
const auto &SM = DpctGlobalInfo::getContext().getSourceManager();
auto Len = Lexer::MeasureTokenLength(
EndLoc, SM, DpctGlobalInfo::getContext().getLangOpts());
Len += SM.getDecomposedLoc(EndLoc).second -
SM.getDecomposedLoc(BeginLoc).second;
emplaceTransformation(new ReplaceText(BeginLoc, Len, std::move(Repl)));
}
}
}

Expand Down
43 changes: 30 additions & 13 deletions clang/test/dpct/pytorch/ATen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// RUN: cp -r %S/pytorch_inc %T/pytorch/ATen/
// RUN: cd %T/pytorch/ATen
// RUN: mkdir dpct_out
// RUN: dpct --out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_inc" --cuda-include-path="%cuda-path/include" --rule-file=%S/../../../tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml --analysis-scope-path %T/pytorch/ATen/pytorch_inc --analysis-scope-path %T/pytorch/ATen/src --in-root %T/pytorch/ATen/src
// RUN: dpct --format-range=none --out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_inc" --cuda-include-path="%cuda-path/include" --rule-file=%S/../../../tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml --analysis-scope-path %T/pytorch/ATen/pytorch_inc --analysis-scope-path %T/pytorch/ATen/src --in-root %T/pytorch/ATen/src
// RUN: FileCheck --input-file %T/pytorch/ATen/dpct_out/ATen.dp.cpp --match-full-lines %T/pytorch/ATen/src/ATen.cu

// CHECK: #include <c10/xpu/XPUStream.h>
Expand All @@ -18,6 +18,18 @@
// CHECK-NEXT: #include <c10/util/Half.h>
#include <ATen/cuda/CUDATensorMethods.cuh>

// CHECK: // BEGIN_1
// CHECK-EMPTY:
// CHECK-EMPTY:
// CHECK-NEXT: // END_1
// BEGIN_1
#include <ATen/cuda/Exceptions.h>
#include <THC/THCAtomics.cuh>
// END_1

// CHECK: #include <c10/xpu/XPUMacros.h>
#include <c10/cuda/CUDAMacros.h>

#define AT_CUDA_CHECK(stmt) (stmt)

// CHECK: #define BE_AT_CHECK
Expand All @@ -31,20 +43,19 @@ void test_CUDAStream_as_arg() {
dim3 blockSize(8, 8, 1);
void *args[] = {nullptr};

// CHECK: ([&]() {
// CHECK-NEXT: ((sycl::queue *)(c10::xpu::getCurrentXPUStream()))
// CHECK-NEXT: ->parallel_for(sycl::nd_range<3>(gridSize * blockSize, blockSize),
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
// CHECK-NEXT: kernel();
// CHECK-NEXT: });
// CHECK: ([&](){
// CHECK-NEXT: ((sycl::queue*)(c10::xpu::getCurrentXPUStream()))->parallel_for(
// CHECK-NEXT: sycl::nd_range<3>(gridSize * blockSize, blockSize),
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
// CHECK-NEXT: kernel();
// CHECK-NEXT: });
// CHECK-NEXT: return 0;
// CHECK-NEXT: }());
AT_CUDA_CHECK(cudaLaunchKernel((const void *)kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
}

int main() {
// CHECK: dpct::queue_ptr st =
// CHECK-NEXT: &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream());
// CHECK: dpct::queue_ptr st = &static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream());
cudaStream_t st = 0;

// stream APIs
Expand All @@ -55,14 +66,20 @@ int main() {
// CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(devInd);
auto deviceStream = at::cuda::getCurrentCUDAStream(devInd);

// CHECK: dpct::queue_ptr curr_cuda_st =
// CHECK-NEXT: &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream().queue());
// CHECK: dpct::queue_ptr curr_cuda_st = &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream(). queue());
cudaStream_t curr_cuda_st = at::cuda::getCurrentCUDAStream().stream();
// CHECK: dpct::queue_ptr dev_cuda_st = &static_cast<sycl::queue &>(
// CHECK-NEXT: c10::xpu::getCurrentXPUStream(devInd).queue());
// CHECK: dpct::queue_ptr dev_cuda_st = &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream(devInd). queue());
cudaStream_t dev_cuda_st = at::cuda::getCurrentCUDAStream(devInd).stream();

test_CUDAStream_as_arg();

return 0;
}

void foo2(const at::Tensor &x) {
float *f;
// CHECK: (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()))));
// CHECK-NEXT: c10::DeviceGuard device_guard{torch::kXPU, (char)x.get_device()};
C10_CUDA_CHECK(cudaMalloc(&f, 4));
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
}
6 changes: 6 additions & 0 deletions clang/test/dpct/pytorch/pytorch_inc/ATen/Tensor.h
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
#pragma once
// Minimal stub of at::Tensor for the dpct lit tests: only the members the
// tests exercise are declared.
namespace at {
class Tensor {
public:
// Stub implementation — always reports device index 0.
int get_device() const { return 0; }
};
} // namespace at
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

namespace at {
using namespace c10;
Expand Down
2 changes: 2 additions & 0 deletions clang/test/dpct/pytorch/pytorch_inc/ATen/cuda/Exceptions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#pragma once

1 change: 1 addition & 0 deletions clang/test/dpct/pytorch/pytorch_inc/THC/THCAtomics.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// RUN: echo "empty command"
11 changes: 11 additions & 0 deletions clang/test/dpct/pytorch/pytorch_inc/c10/cuda/CUDAGuard.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class optional {
} // namespace std

namespace c10 {
using DeviceIndex = int8_t;
class Device {
public:
Device(std::string str) {}
Expand All @@ -19,5 +20,15 @@ class OptionalCUDAGuard {
public:
OptionalCUDAGuard(std::optional<c10::Device> device) {}
};
// Test stub mirroring the shape of c10::cuda::CUDAGuard: constructible only
// from a device (index or Device), and neither copyable nor movable.
struct CUDAGuard {
// No default construction — a guard must be given a device.
explicit CUDAGuard() = delete;
explicit CUDAGuard(DeviceIndex device_index) {}
explicit CUDAGuard(Device device) {}
// Non-copyable and non-movable, matching the deleted members below.
CUDAGuard(const CUDAGuard&) = delete;
CUDAGuard& operator=(const CUDAGuard&) = delete;
CUDAGuard(CUDAGuard&& other) = delete;
CUDAGuard& operator=(CUDAGuard&& other) = delete;
~CUDAGuard() = default;
};
} // namespace cuda
} // namespace c10
4 changes: 4 additions & 0 deletions clang/test/dpct/pytorch/pytorch_inc/c10/cuda/CUDAMacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
#define C10_CUDA_IMPORT
#define C10_CUDA_API
#define C10_CUDA_BUILD_MAIN_LIB
// Test stub: evaluates EXPR exactly once and discards the resulting
// cudaError_t; the do/while(0) wrapper keeps the macro statement-safe
// (e.g. after an unbraced `if`). No comments inside — backslash
// continuations must stay unbroken.
#define C10_CUDA_CHECK(EXPR) \
do { \
const cudaError_t __err = EXPR; \
} while (0)
31 changes: 31 additions & 0 deletions clang/tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,34 @@
In: get_in_order_queue
Out: static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream())
Includes: [<c10/xpu/XPUStream.h>]

- Rule: rule_THC_THCAtomics_cuh
Kind: Header
Priority: Takeover
In: THC/THCAtomics.cuh
Out: |

- Rule: rule_ATen_cuda_Exceptions_h
Kind: Header
Priority: Takeover
In: ATen/cuda/Exceptions.h
Out: |

- Rule: rule_remove_C10_CUDA_CHECK
Kind: Macro
Priority: Takeover
In: C10_CUDA_CHECK
Out: |

- Rule: rule_at_cuda_CUDAGuard
Kind: Type
Priority: Takeover
In: c10::cuda::CUDAGuard
Out: c10::DeviceGuard
Includes: [<c10/core/DeviceGuard.h>]

- Rule: rule_at_cuda_CUDAGuard_ctor
Kind: API
Priority: Takeover
In: c10::cuda::CUDAGuard
Out: torch::kXPU, $1