Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic][PTX] Support migration of PTX instruction st.cs.global.v2.s16 and st.cs.global.v4.s16 #2748

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 89 additions & 11 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

#include "AsmMigration.h"
#include "AnalysisInfo.h"
#include "Diagnostics/Diagnostics.h"
#include "ErrorHandle/CrashRecovery.h"
#include "RuleInfra/MapNames.h"
#include "RulesAsm/Parser/AsmNodes.h"
#include "RulesAsm/Parser/AsmParser.h"
#include "RulesAsm/Parser/AsmTokenKinds.h"
#include "ErrorHandle/CrashRecovery.h"
#include "Diagnostics/Diagnostics.h"
#include "RuleInfra/MapNames.h"
#include "TextModification.h"
#include "Utility.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -609,7 +609,7 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
return false;
};

if (CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_red))
if (CurrInst->is(asmtok::op_ld, asmtok::op_red))
OS() << "*";
switch (Dst->getMemoryOpKind()) {
case InlineAsmAddressExpr::Imm:
Expand All @@ -632,8 +632,12 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
std::string Reg;
if (tryEmitStmt(Reg, Dst->getSymbol()))
return SYCLGenSuccess();
OS() << llvm::formatv("(({0} *)((uintptr_t){1} + {2}))", Type, Reg,
Dst->getImmAddr()->getValue().getZExtValue());

if (CurrInst->is(asmtok::op_st))
OS() << llvm::formatv("(uintptr_t){0}", Reg);
else
OS() << llvm::formatv("(({0} *)((uintptr_t){1} + {2}))", Type, Reg,
Dst->getImmAddr()->getValue().getZExtValue());
break;
}
case InlineAsmAddressExpr::Var: {
Expand Down Expand Up @@ -2690,24 +2694,98 @@ class SYCLGen : public SYCLGenBase {
return SYCLGenSuccess();
}

/// Emits the element-wise stores for a vectorized `st` instruction.
///
/// The input operand is rendered to text, which is expected to look like
/// "{x, y, z, w}", and split on commas. Each element is then written through
/// the destination as `*((Type *)(Dst) + Index) = Value`, one statement per
/// element.
///
/// \param Inst   The store instruction being migrated.
/// \param VecNum Number of vector elements to store.
/// \returns SYCLGenSuccess() once all stores are emitted; SYCLGenError() if
///          an operand or the type cannot be emitted, the destination is not
///          an address expression, or the operand list holds fewer than
///          \p VecNum values.
bool HandleStVec(const InlineAsmInstruction *Inst, int VecNum) {
  // Reject a missing / non-address destination before trying to emit it.
  const auto *Dst =
      dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
  if (!Dst)
    return SYCLGenError();

  std::string Ops;
  if (tryEmitStmt(Ops, Inst->getInputOperand(0)))
    return SYCLGenError();

  // The split below treats the first and last characters as the enclosing
  // braces, so anything shorter than "{}" cannot carry a value list (and an
  // empty string would make substr(1, ...) throw).
  if (Ops.size() < 2)
    return SYCLGenError();

  // Extract the values from a string like "{x, y, z, w}" into the Values
  // vector. Tokens consisting only of spaces are dropped.
  std::vector<std::string> Values;
  auto AppendTrimmed = [&Values](const std::string &Token) {
    const size_t First = Token.find_first_not_of(' ');
    if (First == std::string::npos)
      return;
    const size_t Last = Token.find_last_not_of(' ');
    Values.push_back(Token.substr(First, Last - First + 1));
  };

  size_t Start = 1; // Skip the '{' character.
  for (size_t Comma = Ops.find(',', Start); Comma != std::string::npos;
       Comma = Ops.find(',', Start)) {
    AppendTrimmed(Ops.substr(Start, Comma - Start));
    Start = Comma + 1;
  }
  // Last value: everything after the final comma, minus the closing '}'.
  AppendTrimmed(Ops.substr(Start, Ops.size() - Start - 1));

  // The emission loop indexes Values[0 .. VecNum-1]; refuse to migrate when
  // the operand list is shorter than the vector width instead of reading
  // past the end of the vector.
  if (VecNum < 0 || Values.size() < static_cast<size_t>(VecNum))
    return SYCLGenError();

  std::string Output;
  if (tryEmitStmt(Output, Inst->getOutputOperand()))
    return SYCLGenError();

  std::string Type;
  if (tryEmitType(Type, Inst->getType(0)))
    return SYCLGenError();

  for (int Index = 0; Index < VecNum; Index++) {
    // Every statement but the last is terminated here; endstmt() below
    // closes the final one.
    OS() << llvm::formatv("*(({0} *)({1}) + {2}) = {3}{4}", Type, Output,
                          Index, Values[Index],
                          Index == VecNum - 1 ? "" : ";\n");
  }

  endstmt();
  return SYCLGenSuccess();
}

// Handler for the `st` instruction: emits `<destination> = <source>` as one
// statement. Returns SYCLGenError() when the instruction does not have
// exactly one input operand or when emitting any part fails.
// NOTE(review): this listing appears to show removed and added lines of the
// change side by side (no +/- markers survive) - confirm each statement
// against the real file before relying on it.
bool handle_st(const InlineAsmInstruction *Inst) override {
// A store is expected to carry exactly one input operand (the value).
if (Inst->getNumInputOperands() != 1)
return SYCLGenError();
// NOTE(review): CurrInst is saved/assigned twice below - once as a
// SaveAndRestore plus a separate assignment, once as the two-argument
// SaveAndRestore. Presumably only one of the two forms exists in the file.
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
CurrInst = Inst;

llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst, Inst);

// `.cs` stores with a `.v4` / `.v2` attribute are delegated to HandleStVec
// with the matching element count; `.v4` is tested first.
if (Inst->hasAttr(InstAttr::cs)) {
if (Inst->hasAttr(InstAttr::v4))
return HandleStVec(Inst, 4);
if (Inst->hasAttr(InstAttr::v2))
return HandleStVec(Inst, 2);
}

const auto *Src = Inst->getInputOperand(0);
// The output operand must be an address expression.
const auto *Dst =
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
if (!Dst)
// NOTE(review): two consecutive returns - looks like the old `return false;`
// next to its replacement `return SYCLGenError();`.
return false;
return SYCLGenError();

std::string Type;
if (tryEmitType(Type, Inst->getType(0)))
return SYCLGenError();
// NOTE(review): this `if` has no body of its own in the listing; presumably
// a removed line that the OutOp emission below replaces - verify.
if (emitStmt(Dst))

// Render the destination operand into a string so it can be wrapped below.
std::string OutOp;
if (tryEmitStmt(OutOp, Inst->getOutputOperand()))
return SYCLGenError();

// Register+immediate addresses are emitted as `*((Type *)(OutOp + Imm))`;
// every other address kind is emitted as `*OutOp`.
if (Dst->getMemoryOpKind() == InlineAsmAddressExpr::RegImm) {
OS() << llvm::formatv("*(({0} *)({1} + {2}))", Type, OutOp,
Dst->getImmAddr()->getValue().getZExtValue());
} else {
OS() << "*" << OutOp;
}

// Right-hand side of the assignment: the stored value.
OS() << " = ";
if (emitStmt(Src))
return SYCLGenError();

endstmt();
return SYCLGenSuccess();
}
Expand All @@ -2722,7 +2800,7 @@ class SYCLGen : public SYCLGenBase {
const auto *Dst = Inst->getOutputOperand();

if (!Src)
return false;
return SYCLGenError();
std::string Type;
if (tryEmitType(Type, Inst->getType(0)))
return SYCLGenError();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ MODIFIER(rc8, ".rc8")
MODIFIER(ecl, ".ecl")
MODIFIER(ecr, ".ecr")
MODIFIER(rc16, ".rc16")
MODIFIER(cs, ".cs")

#undef LINKAGE
#undef TARGET
Expand Down
24 changes: 24 additions & 0 deletions clang/test/dpct/asm/st.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,28 @@ __device__ void shared_address_store32(uint32_t addr, uint32_t val) {
asm volatile("{st.shared.b32 [%0], %1;}" : : "r"(__addr), "r"(__val) : "memory");
}

// __PTR is the inline-asm constraint string used for pointer operands in the
// asm statements below: "l" when building with MSVC for _WIN64 or when
// __LP64__ is defined, "r" otherwise.
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__)
#define __PTR "l"
#else
#define __PTR "r"
#endif

// Test case for st.cs.global.v4.s16: the expected migration is four scalar
// int16_t stores of x, y, z, w at element offsets 0..3 from addr.
// CHECK: inline void store_streaming_short4(sycl::short4 *addr, short x, short y, short z, short w) {
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 0) = x;
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 1) = y;
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 2) = z;
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 3) = w;
// CHECK-NEXT: }
__device__ inline void store_streaming_short4(short4 *addr, short x, short y, short z, short w) {
asm("st.cs.global.v4.s16 [%0+0], {%1, %2, %3, %4};" ::__PTR(addr), "h"(x), "h"(y), "h"(z), "h"(w));
}

// Test case for st.cs.global.v2.s16: the expected migration is two scalar
// int16_t stores of x and y at element offsets 0 and 1 from addr.
// CHECK: inline void store_streaming_short2(sycl::short2 *addr, short x, short y) {
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 0) = x;
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 1) = y;
// CHECK-NEXT: }
__device__ inline void store_streaming_short2(short2 *addr, short x, short y) {
asm("st.cs.global.v2.s16 [%0+0], {%1, %2};" ::__PTR(addr), "h"(x), "h"(y));
}

// clang-format on