Skip to content

[SER] GetAttributes(out udt) instead of templated return #7606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/dxc/HLSL/HLOperations.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ const unsigned kHitObjectInvoke_PayloadOpIdx = 2;
const unsigned kHitObjectFromRayQuery_WithAttrs_AttributeOpIdx = 4;
const unsigned kHitObjectFromRayQuery_WithAttrs_NumOp = 5;

// HitObject::GetAttributes
const unsigned kHitObjectGetAttributes_AttributeOpIdx = 2;

// Linear Algebra Operations

// MatVecMul
Expand Down
17 changes: 5 additions & 12 deletions lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6378,18 +6378,11 @@ Value *TranslateHitObjectGetAttributes(CallInst *CI, IntrinsicOp IOP,

Value *HitObjectPtr = CI->getArgOperand(1);
Value *HitObject = Builder.CreateLoad(HitObjectPtr);

Type *AttrTy = cast<PointerType>(CI->getType())->getPointerElementType();

IRBuilder<> EntryBuilder(
dxilutil::FindAllocaInsertionPt(CI->getParent()->getParent()));
unsigned AttrAlign = Helper.dataLayout.getABITypeAlignment(AttrTy);
AllocaInst *AttrMem = EntryBuilder.CreateAlloca(AttrTy);
AttrMem->setAlignment(AttrAlign);
Constant *opArg = OP->GetU32Const((unsigned)OpCode);
TrivialDxilOperation(OpCode, {opArg, HitObject, AttrMem}, CI->getType(),
Helper.voidTy, OP, Builder);
return AttrMem;
Value *AttrOutPtr =
CI->getArgOperand(HLOperandIndex::kHitObjectGetAttributes_AttributeOpIdx);
TrivialDxilOperation(OpCode, {nullptr, HitObject, AttrOutPtr},
AttrOutPtr->getType(), CI, OP);
return nullptr;
}

Value *TranslateHitObjectScalarGetter(CallInst *CI, IntrinsicOp IOP,
Expand Down
4 changes: 4 additions & 0 deletions lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,10 @@ static bool isUDTIntrinsicArg(CallInst *CI, unsigned OpIdx) {
if (OpIdx == HLOperandIndex::kHitObjectInvoke_PayloadOpIdx)
return true;
break;
case IntrinsicOp::MOP_DxHitObject_GetAttributes:
if (OpIdx == HLOperandIndex::kHitObjectGetAttributes_AttributeOpIdx)
return true;
break;
default:
break;
}
Expand Down
5 changes: 1 addition & 4 deletions tools/clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -3806,8 +3806,7 @@ class Sema {
void DiagnoseHLSLDeclAttr(const Decl *D, const Attr *A);
void DiagnoseCoherenceMismatch(const Expr *SrcExpr, QualType TargetType,
SourceLocation Loc);
void CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
const FunctionProtoType *Proto);
void CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall);
void DiagnoseReachableHLSLCall(CallExpr *CE, const hlsl::ShaderModel *SM,
hlsl::DXIL::ShaderKind EntrySK,
hlsl::DXIL::NodeLaunchType NodeLaunchTy,
Expand Down Expand Up @@ -8826,8 +8825,6 @@ class Sema {
bool AllowOnePastEnd=true, bool IndexNegated=false);
// HLSL Change Starts - checking array subscript access to vector or matrix member
void CheckHLSLArrayAccess(const Expr *expr);
bool CheckHLSLIntrinsicCall(FunctionDecl *FDecl, CallExpr *TheCall);
bool CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall);
// HLSL Change ends
void CheckArrayAccess(const Expr *E);
// Used to grab the relevant information from a FormatAttr and a
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ bool Sema::CheckFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
CheckMemaccessArguments(TheCall, CMId, FnInfo);
#endif // HLSL Change Ends

CheckHLSLFunctionCall(FDecl, TheCall, Proto); // HLSL Change
CheckHLSLFunctionCall(FDecl, TheCall); // HLSL Change

return false;
}
Expand Down
2 changes: 0 additions & 2 deletions tools/clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5349,8 +5349,6 @@ Sema::BuildResolvedCallExpr(Expr *Fn, NamedDecl *NDecl,
if (FDecl) {
if (CheckFunctionCall(FDecl, TheCall, Proto))
return ExprError();
if (CheckHLSLFunctionCall(FDecl, TheCall))
return ExprError();
if (BuiltinID)
return CheckBuiltinFunctionCall(FDecl, BuiltinID, TheCall);
} else if (NDecl) {
Expand Down
236 changes: 107 additions & 129 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10792,18 +10792,24 @@ HLSLExternalSource::ApplyTypeSpecSignToParsedType(clang::QualType &type,
}
}

bool DiagnoseIntersectionAttributes(Sema &S, SourceLocation Loc, QualType Ty) {
// Must be a UDT
bool CheckIntersectionAttributeArg(Sema &S, Expr *E) {
SourceLocation Loc = E->getExprLoc();
QualType Ty = E->getType();

// Identify problematic fields first (high diagnostic accuracy, may miss some
// invalid cases)
const TypeDiagContext DiagContext = TypeDiagContext::Attributes;
if (DiagnoseTypeElements(S, Loc, Ty, DiagContext, DiagContext))
return true;

// Must be a UDT (low diagnostic accuracy, catches remaining invalid cases)
if (Ty.isNull() || !hlsl::IsHLSLCopyableAnnotatableRecord(Ty)) {
S.Diag(Loc, diag::err_payload_attrs_must_be_udt)
<< /*payload|attributes|callable*/ 1 << /*parameter %2|type*/ 1;
return false;
return true;
}

const TypeDiagContext DiagContext = TypeDiagContext::Attributes;
if (DiagnoseTypeElements(S, Loc, Ty, DiagContext, DiagContext))
return false;
return true;
return false;
}

Sema::TemplateDeductionResult
Expand Down Expand Up @@ -10914,7 +10920,6 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
LPCSTR tableName = cursor.GetTableName();
// Currently only intrinsic we allow for explicit template arguments are
// for Load/Store for ByteAddressBuffer/RWByteAddressBuffer
// and HitObject::GetAttributes with user-defined intersection attributes.

// Check Explicit template arguments
UINT intrinsicOp = (*cursor)->Op;
Expand All @@ -10929,11 +10934,9 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
IsBABLoad = intrinsicOp == (UINT)IntrinsicOp::MOP_Load;
IsBABStore = intrinsicOp == (UINT)IntrinsicOp::MOP_Store;
}
bool IsHitObjectGetAttributes =
intrinsicOp == (UINT)IntrinsicOp::MOP_DxHitObject_GetAttributes;
if (ExplicitTemplateArgs && ExplicitTemplateArgs->size() >= 1) {
SourceLocation Loc = ExplicitTemplateArgs->getLAngleLoc();
if (!IsBABLoad && !IsBABStore && !IsHitObjectGetAttributes) {
if (!IsBABLoad && !IsBABStore) {
getSema()->Diag(Loc, diag::err_hlsl_intrinsic_template_arg_unsupported)
<< intrinsicName;
return Sema::TemplateDeductionResult::TDK_Invalid;
Expand Down Expand Up @@ -10963,10 +10966,6 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
return Sema::TemplateDeductionResult::TDK_Invalid;
}
}
if (IsHitObjectGetAttributes &&
!DiagnoseIntersectionAttributes(*getSema(), Loc,
functionTemplateTypeArg))
return Sema::TemplateDeductionResult::TDK_Invalid;
} else if (IsBABStore) {
// Prior to HLSL 2018, Store operation only stored scalar uint.
if (!Is2018) {
Expand Down Expand Up @@ -12240,9 +12239,78 @@ static bool CheckVKBufferPointerCast(Sema &S, FunctionDecl *FD, CallExpr *CE,
}
#endif

static bool isRelatedDeclMarkedNointerpolation(Expr *E) {
if (!E)
return false;
E = E->IgnoreCasts();
if (auto *DRE = dyn_cast<DeclRefExpr>(E))
return DRE->getDecl()->hasAttr<HLSLNoInterpolationAttr>();

if (auto *ME = dyn_cast<MemberExpr>(E))
return ME->getMemberDecl()->hasAttr<HLSLNoInterpolationAttr>() ||
isRelatedDeclMarkedNointerpolation(ME->getBase());

if (auto *HVE = dyn_cast<HLSLVectorElementExpr>(E))
return isRelatedDeclMarkedNointerpolation(HVE->getBase());

if (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
return isRelatedDeclMarkedNointerpolation(ASE->getBase());

return false;
}

static bool CheckIntrinsicGetAttributeAtVertex(Sema &S, FunctionDecl *FDecl,
CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 0);
auto argument = TheCall->getArg(0)->IgnoreCasts();

if (!isRelatedDeclMarkedNointerpolation(argument)) {
S.Diag(argument->getExprLoc(), diag::err_hlsl_parameter_requires_attribute)
<< 0 << FDecl->getName() << "nointerpolation";
return true;
}

return false;
}

static bool CheckNoInterpolationParams(Sema &S, FunctionDecl *FDecl,
CallExpr *TheCall) {
// See #hlsl-specs/issues/181. Feature is broken. For SPIR-V we want
// to limit the scope, and fail gracefully in some cases.
if (!S.getLangOpts().SPIRV)
return false;

bool error = false;
for (unsigned i = 0; i < FDecl->getNumParams(); i++) {
assert(i < TheCall->getNumArgs());

if (!FDecl->getParamDecl(i)->hasAttr<HLSLNoInterpolationAttr>())
continue;

if (!isRelatedDeclMarkedNointerpolation(TheCall->getArg(i))) {
S.Diag(TheCall->getArg(i)->getExprLoc(),
diag::err_hlsl_parameter_requires_attribute)
<< i << FDecl->getName() << "nointerpolation";
error = true;
}
}

return error;
}

// Verify that user-defined intrinsic struct args contain no long vectors
static bool CheckUDTIntrinsicArg(Sema &S, Expr *Arg) {
const TypeDiagContext DiagContext =
TypeDiagContext::UserDefinedStructParameter;
return DiagnoseTypeElements(S, Arg->getExprLoc(), Arg->getType(), DiagContext,
DiagContext);
}

// Check HLSL call constraints, not fatal to creating the AST.
void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
const FunctionProtoType *Proto) {
void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) {
if (CheckNoInterpolationParams(*this, FDecl, TheCall))
return;

HLSLIntrinsicAttr *IntrinsicAttr = FDecl->getAttr<HLSLIntrinsicAttr>();
if (!IntrinsicAttr)
return;
Expand Down Expand Up @@ -12270,6 +12338,28 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
case hlsl::IntrinsicOp::IOP___builtin_OuterProductAccumulate:
CheckOuterProductAccumulateCall(*this, FDecl, TheCall);
break;
case hlsl::IntrinsicOp::IOP_GetAttributeAtVertex:
// See #hlsl-specs/issues/181. Feature is broken. For SPIR-V we want
// to limit the scope, and fail gracefully in some cases.
if (!getLangOpts().SPIRV)
return;
CheckIntrinsicGetAttributeAtVertex(*this, FDecl, TheCall);
break;
case hlsl::IntrinsicOp::IOP_DispatchMesh:
CheckUDTIntrinsicArg(*this, TheCall->getArg(3)->IgnoreCasts());
break;
case hlsl::IntrinsicOp::IOP_CallShader:
CheckUDTIntrinsicArg(*this, TheCall->getArg(1)->IgnoreCasts());
break;
case hlsl::IntrinsicOp::IOP_TraceRay:
CheckUDTIntrinsicArg(*this, TheCall->getArg(7)->IgnoreCasts());
break;
case hlsl::IntrinsicOp::IOP_ReportHit:
CheckIntersectionAttributeArg(*this, TheCall->getArg(2)->IgnoreCasts());
break;
case hlsl::IntrinsicOp::MOP_DxHitObject_GetAttributes:
CheckIntersectionAttributeArg(*this, TheCall->getArg(0)->IgnoreCasts());
break;
#ifdef ENABLE_SPIRV_CODEGEN
case hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast:
CheckVKBufferPointerCast(*this, FDecl, TheCall, false);
Expand Down Expand Up @@ -16804,118 +16894,6 @@ QualType Sema::getHLSLDefaultSpecialization(TemplateDecl *Decl) {
return QualType();
}

static bool isRelatedDeclMarkedNointerpolation(Expr *E) {
if (!E)
return false;
E = E->IgnoreCasts();
if (auto *DRE = dyn_cast<DeclRefExpr>(E))
return DRE->getDecl()->hasAttr<HLSLNoInterpolationAttr>();

if (auto *ME = dyn_cast<MemberExpr>(E))
return ME->getMemberDecl()->hasAttr<HLSLNoInterpolationAttr>() ||
isRelatedDeclMarkedNointerpolation(ME->getBase());

if (auto *HVE = dyn_cast<HLSLVectorElementExpr>(E))
return isRelatedDeclMarkedNointerpolation(HVE->getBase());

if (auto *ASE = dyn_cast<ArraySubscriptExpr>(E))
return isRelatedDeclMarkedNointerpolation(ASE->getBase());

return false;
}

// Verify that user-defined intrinsic struct args contain no long vectors
static bool CheckUDTIntrinsicArg(Sema *S, Expr *Arg) {
const TypeDiagContext DiagContext =
TypeDiagContext::UserDefinedStructParameter;
return DiagnoseTypeElements(*S, Arg->getExprLoc(), Arg->getType(),
DiagContext, DiagContext);
}

static bool CheckIntrinsicGetAttributeAtVertex(Sema *S, FunctionDecl *FDecl,
CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 0);
auto argument = TheCall->getArg(0)->IgnoreCasts();

if (!isRelatedDeclMarkedNointerpolation(argument)) {
S->Diag(argument->getExprLoc(), diag::err_hlsl_parameter_requires_attribute)
<< 0 << FDecl->getName() << "nointerpolation";
return true;
}

return false;
}

bool Sema::CheckHLSLIntrinsicCall(FunctionDecl *FDecl, CallExpr *TheCall) {
auto attr = FDecl->getAttr<HLSLIntrinsicAttr>();

if (!attr)
return false;

if (!IsBuiltinTable(attr->getGroup()))
return false;

switch (hlsl::IntrinsicOp(attr->getOpcode())) {
case hlsl::IntrinsicOp::IOP_GetAttributeAtVertex:
// See #hlsl-specs/issues/181. Feature is broken. For SPIR-V we want
// to limit the scope, and fail gracefully in some cases.
if (!getLangOpts().SPIRV)
return false;
// This should never happen for SPIR-V. But on the DXIL side, extension can
// be added by inserting new intrinsics, meaning opcodes can collide with
// existing ones. See the ExtensionTest.EvalAttributeCollision test.
assert(FDecl->getName() == "GetAttributeAtVertex");
return CheckIntrinsicGetAttributeAtVertex(this, FDecl, TheCall);
case hlsl::IntrinsicOp::IOP_DispatchMesh:
assert(TheCall->getNumArgs() > 3);
assert(FDecl->getName() == "DispatchMesh");
return CheckUDTIntrinsicArg(this, TheCall->getArg(3)->IgnoreCasts());
case hlsl::IntrinsicOp::IOP_CallShader:
assert(TheCall->getNumArgs() > 1);
assert(FDecl->getName() == "CallShader");
return CheckUDTIntrinsicArg(this, TheCall->getArg(1)->IgnoreCasts());
case hlsl::IntrinsicOp::IOP_TraceRay:
assert(TheCall->getNumArgs() > 7);
assert(FDecl->getName() == "TraceRay");
return CheckUDTIntrinsicArg(this, TheCall->getArg(7)->IgnoreCasts());
case hlsl::IntrinsicOp::IOP_ReportHit:
assert(TheCall->getNumArgs() > 2);
assert(FDecl->getName() == "ReportHit");
return CheckUDTIntrinsicArg(this, TheCall->getArg(2)->IgnoreCasts());
default:
break;
}

return false;
}

bool Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) {
if (hlsl::IsIntrinsicOp(FDecl) && CheckHLSLIntrinsicCall(FDecl, TheCall))
return true;

// See #hlsl-specs/issues/181. Feature is broken. For SPIR-V we want
// to limit the scope, and fail gracefully in some cases.
if (!getLangOpts().SPIRV)
return false;

bool error = false;
for (unsigned i = 0; i < FDecl->getNumParams(); i++) {
assert(i < TheCall->getNumArgs());

if (!FDecl->getParamDecl(i)->hasAttr<HLSLNoInterpolationAttr>())
continue;

if (!isRelatedDeclMarkedNointerpolation(TheCall->getArg(i))) {
Diag(TheCall->getArg(i)->getExprLoc(),
diag::err_hlsl_parameter_requires_attribute)
<< i << FDecl->getName() << "nointerpolation";
error = true;
}
}

return error;
}

namespace hlsl {

static bool nodeInputIsCompatible(DXIL::NodeIOKind IOType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ CustomAttrs {
[shader("raygeneration")]
void main() {
dx::HitObject hit;
CustomAttrs attrs = hit.GetAttributes<CustomAttrs>();
CustomAttrs attrs;
hit.GetAttributes(attrs);
float sum = attrs.v.x + attrs.v.y + attrs.v.z + attrs.v.w + attrs.y;
outbuf.Store(0, sum);
}
Loading