Skip to content

Commit 9404c5b

Browse files
committed
[CIR] Data member pointer comparison and casts
This patch adds CIRGen and LLVM lowering support for the following language features related to pointers to data members: - Comparisons between pointers to data members. - Casting from pointers to data members to boolean. - Reinterpret casts between pointers to data members.
1 parent 04d7dcf commit 9404c5b

File tree

9 files changed

+210
-15
lines changed

9 files changed

+210
-15
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def CK_FloatComplexToIntegralComplex
123123
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
124124
def CK_IntegralComplexToFloatComplex
125125
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;
126+
def CK_MemberPtrToBoolean : I32EnumAttrCase<"member_ptr_to_bool", 25>;
126127

127128
def CastKind : I32EnumAttr<
128129
"CastKind",
@@ -135,7 +136,7 @@ def CastKind : I32EnumAttr<
135136
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
136137
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
137138
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
138-
CK_IntegralComplexToFloatComplex]> {
139+
CK_IntegralComplexToFloatComplex, CK_MemberPtrToBoolean]> {
139140
let cppNamespace = "::cir";
140141
}
141142

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+17-5
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
932932
};
933933

934934
if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
935-
assert(0 && "not implemented");
935+
assert(E->getOpcode() == BO_EQ || E->getOpcode() == BO_NE);
936+
mlir::Value lhs = CGF.emitScalarExpr(E->getLHS());
937+
mlir::Value rhs = CGF.emitScalarExpr(E->getRHS());
938+
cir::CmpOpKind kind = ClangCmpToCIRCmp(E->getOpcode());
939+
Result =
940+
Builder.createCompare(CGF.getLoc(E->getExprLoc()), kind, lhs, rhs);
936941
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
937942
BinOpInfo BOInfo = emitBinOps(E);
938943
mlir::Value LHS = BOInfo.LHS;
@@ -1741,8 +1746,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17411746
auto Ty = mlir::cast<cir::DataMemberType>(CGF.getCIRType(DestTy));
17421747
return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc()));
17431748
}
1744-
case CK_ReinterpretMemberPointer:
1745-
llvm_unreachable("NYI");
1749+
case CK_ReinterpretMemberPointer: {
1750+
mlir::Value src = Visit(E);
1751+
return Builder.createBitcast(CGF.getLoc(E->getExprLoc()), src,
1752+
CGF.getCIRType(DestTy));
1753+
}
17461754
case CK_BaseToDerivedMemberPointer:
17471755
case CK_DerivedToBaseMemberPointer: {
17481756
mlir::Value src = Visit(E);
@@ -1875,8 +1883,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
18751883
return emitPointerToBoolConversion(Visit(E), E->getType());
18761884
case CK_FloatingToBoolean:
18771885
return emitFloatToBoolConversion(Visit(E), CGF.getLoc(E->getExprLoc()));
1878-
case CK_MemberPointerToBoolean:
1879-
llvm_unreachable("NYI");
1886+
case CK_MemberPointerToBoolean: {
1887+
mlir::Value memPtr = Visit(E);
1888+
return Builder.createCast(CGF.getLoc(CE->getSourceRange()),
1889+
cir::CastKind::member_ptr_to_bool, memPtr,
1890+
ConvertType(DestTy));
1891+
}
18801892
case CK_FloatingComplexToReal:
18811893
case CK_IntegralComplexToReal:
18821894
case CK_FloatingComplexToBoolean:

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ LogicalResult cir::CastOp::verify() {
529529
return success();
530530
}
531531

532+
// Handle the data member pointer types.
533+
if (mlir::isa<cir::DataMemberType>(srcType) &&
534+
mlir::isa<cir::DataMemberType>(resType))
535+
return success();
536+
532537
// This is the only cast kind where we don't want vector types to decay
533538
// into the element type.
534539
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||
@@ -704,6 +709,13 @@ LogicalResult cir::CastOp::verify() {
704709
<< "requires !cir.complex<!cir.float> type for result";
705710
return success();
706711
}
712+
case cir::CastKind::member_ptr_to_bool: {
713+
if (!mlir::isa<cir::DataMemberType>(srcType))
714+
return emitOpError() << "requires !cir.data_member type for source";
715+
if (!mlir::isa<cir::BoolType>(resType))
716+
return emitOpError() << "requires !cir.bool type for result";
717+
return success();
718+
}
707719
}
708720

709721
llvm_unreachable("Unknown CastOp kind?");

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

+13
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ class CIRCXXABI {
9797
virtual mlir::Value
9898
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
9999
mlir::OpBuilder &builder) const = 0;
100+
101+
virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
102+
mlir::Value loweredRhs,
103+
mlir::OpBuilder &builder) const = 0;
104+
105+
virtual mlir::Value
106+
lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
107+
mlir::Value loweredSrc,
108+
mlir::OpBuilder &builder) const = 0;
109+
110+
virtual mlir::Value
111+
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
112+
mlir::OpBuilder &builder) const = 0;
100113
};
101114

102115
/// Creates an Itanium-family ABI.

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

+48-4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ class ItaniumCXXABI : public CIRCXXABI {
7373
mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
7474
mlir::Value loweredSrc,
7575
mlir::OpBuilder &builder) const override;
76+
77+
mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
78+
mlir::Value loweredRhs,
79+
mlir::OpBuilder &builder) const override;
80+
81+
mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
82+
mlir::Value loweredSrc,
83+
mlir::OpBuilder &builder) const override;
84+
85+
mlir::Value
86+
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
87+
mlir::OpBuilder &builder) const override;
7688
};
7789

7890
} // namespace
@@ -89,18 +101,23 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
89101
return false;
90102
}
91103

92-
mlir::Type ItaniumCXXABI::lowerDataMemberType(
93-
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
104+
static mlir::Type getABITypeForDataMember(LowerModule &lowerMod) {
94105
// Itanium C++ ABI 2.3:
95106
// A pointer to data member is an offset from the base address of
96107
// the class object containing it, represented as a ptrdiff_t
97-
const clang::TargetInfo &target = LM.getTarget();
108+
const clang::TargetInfo &target = lowerMod.getTarget();
98109
clang::TargetInfo::IntType ptrdiffTy =
99110
target.getPtrDiffType(clang::LangAS::Default);
100-
return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy),
111+
return cir::IntType::get(lowerMod.getMLIRContext(),
112+
target.getTypeWidth(ptrdiffTy),
101113
target.isTypeSigned(ptrdiffTy));
102114
}
103115

116+
mlir::Type ItaniumCXXABI::lowerDataMemberType(
117+
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
118+
return getABITypeForDataMember(LM);
119+
}
120+
104121
mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant(
105122
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
106123
const mlir::TypeConverter &typeConverter) const {
@@ -175,6 +192,33 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
175192
/*isDerivedToBase=*/false, builder);
176193
}
177194

195+
mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
196+
mlir::Value loweredLhs,
197+
mlir::Value loweredRhs,
198+
mlir::OpBuilder &builder) const {
199+
return builder.create<cir::CmpOp>(op.getLoc(), op.getKind(), loweredLhs,
200+
loweredRhs);
201+
}
202+
203+
mlir::Value
204+
ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
205+
mlir::Value loweredSrc,
206+
mlir::OpBuilder &builder) const {
207+
return builder.create<cir::CastOp>(op.getLoc(), loweredDstTy,
208+
cir::CastKind::bitcast, loweredSrc);
209+
}
210+
211+
mlir::Value
212+
ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
213+
mlir::OpBuilder &builder) const {
214+
// Itanium C++ ABI 2.3:
215+
// A NULL pointer is represented as -1.
216+
auto nullAttr = cir::IntAttr::get(getABITypeForDataMember(LM), -1);
217+
auto nullValue = builder.create<cir::ConstantOp>(op.getLoc(), nullAttr);
218+
return builder.create<cir::CmpOp>(op.getLoc(), cir::CmpOpKind::ne, loweredSrc,
219+
nullValue);
220+
}
221+
178222
CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
179223
switch (LM.getCXXABIKind()) {
180224
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+32-3
Original file line numberDiff line numberDiff line change
@@ -1179,8 +1179,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
11791179
}
11801180
case cir::CastKind::bitcast: {
11811181
auto dstTy = castOp.getType();
1182-
auto llvmSrcVal = adaptor.getOperands().front();
11831182
auto llvmDstTy = getTypeConverter()->convertType(dstTy);
1183+
1184+
if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType())) {
1185+
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
1186+
castOp, llvmDstTy, src, rewriter);
1187+
rewriter.replaceOp(castOp, loweredResult);
1188+
return mlir::success();
1189+
}
1190+
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
1191+
llvm_unreachable("NYI");
1192+
1193+
auto llvmSrcVal = adaptor.getOperands().front();
11841194
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
11851195
llvmSrcVal);
11861196
return mlir::success();
@@ -1204,6 +1214,16 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
12041214
llvmSrcVal);
12051215
break;
12061216
}
1217+
case cir::CastKind::member_ptr_to_bool: {
1218+
mlir::Value loweredResult;
1219+
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
1220+
llvm_unreachable("NYI");
1221+
else
1222+
loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast(
1223+
castOp, src, rewriter);
1224+
rewriter.replaceOp(castOp, loweredResult);
1225+
break;
1226+
}
12071227
default: {
12081228
return castOp.emitError("Unhandled cast kind: ")
12091229
<< castOp.getKindAttrName();
@@ -2748,6 +2768,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
27482768
cir::CmpOp cmpOp, OpAdaptor adaptor,
27492769
mlir::ConversionPatternRewriter &rewriter) const {
27502770
auto type = cmpOp.getLhs().getType();
2771+
2772+
if (mlir::isa<cir::DataMemberType>(type)) {
2773+
assert(lowerMod && "lowering module is not available");
2774+
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
2775+
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
2776+
rewriter.replaceOp(cmpOp, loweredResult);
2777+
return mlir::success();
2778+
}
2779+
27512780
mlir::Value llResult;
27522781

27532782
// Lower to LLVM comparison op.
@@ -3963,6 +3992,8 @@ void populateCIRToLLVMConversionPatterns(
39633992
patterns.add<
39643993
// clang-format off
39653994
CIRToLLVMBaseDataMemberOpLowering,
3995+
CIRToLLVMCastOpLowering,
3996+
CIRToLLVMCmpOpLowering,
39663997
CIRToLLVMConstantOpLowering,
39673998
CIRToLLVMDerivedDataMemberOpLowering,
39683999
CIRToLLVMGetRuntimeMemberOpLowering,
@@ -3994,10 +4025,8 @@ void populateCIRToLLVMConversionPatterns(
39944025
CIRToLLVMBrOpLowering,
39954026
CIRToLLVMByteswapOpLowering,
39964027
CIRToLLVMCallOpLowering,
3997-
CIRToLLVMCastOpLowering,
39984028
CIRToLLVMCatchParamOpLowering,
39994029
CIRToLLVMClearCacheOpLowering,
4000-
CIRToLLVMCmpOpLowering,
40014030
CIRToLLVMCmpThreeWayOpLowering,
40024031
CIRToLLVMComplexCreateOpLowering,
40034032
CIRToLLVMComplexImagOpLowering,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,17 @@ class CIRToLLVMBrCondOpLowering
216216
};
217217

218218
class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
219+
cir::LowerModule *lowerMod;
220+
219221
mlir::Type convertTy(mlir::Type ty) const;
220222

221223
public:
222-
using mlir::OpConversionPattern<cir::CastOp>::OpConversionPattern;
224+
CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter,
225+
mlir::MLIRContext *context,
226+
cir::LowerModule *lowerModule)
227+
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
228+
setHasBoundedRewriteRecursion();
229+
}
223230

224231
mlir::LogicalResult
225232
matchAndRewrite(cir::CastOp op, OpAdaptor,
@@ -615,8 +622,15 @@ class CIRToLLVMShiftOpLowering
615622
};
616623

617624
class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
625+
cir::LowerModule *lowerMod;
626+
618627
public:
619-
using mlir::OpConversionPattern<cir::CmpOp>::OpConversionPattern;
628+
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
629+
mlir::MLIRContext *context,
630+
cir::LowerModule *lowerModule)
631+
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
632+
setHasBoundedRewriteRecursion();
633+
}
620634

621635
mlir::LogicalResult
622636
matchAndRewrite(cir::CmpOp op, OpAdaptor,

clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,29 @@ auto derived_to_base_zero_offset(int Derived::*ptr) -> int Base1::* {
7474
// LLVM-NEXT: %[[#ret:]] = load i64, ptr %[[#ret_slot]]
7575
// LLVM-NEXT: ret i64 %[[#ret]]
7676
}
77+
78+
struct Foo {
79+
int a;
80+
};
81+
82+
struct Bar {
83+
int a;
84+
};
85+
86+
bool to_bool(int Foo::*x) {
87+
return x;
88+
}
89+
90+
// CIR-LABEL: @_Z7to_boolM3Fooi
91+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
92+
// CIR-NEXT: %{{.+}} = cir.cast(member_ptr_to_bool, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.bool
93+
// CIR: }
94+
95+
auto bitcast(int Foo::*x) {
96+
return reinterpret_cast<int Bar::*>(x);
97+
}
98+
99+
// CIR-LABEL: @_Z7bitcastM3Fooi
100+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
101+
// CIR-NEXT: %{{.+}} = cir.cast(bitcast, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.data_member<!s32i in !ty_Bar>
102+
// CIR: }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir --check-prefix=CIR %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t.ll
4+
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s
5+
6+
struct Foo {
7+
int a;
8+
};
9+
10+
struct Bar {
11+
int a;
12+
};
13+
14+
bool eq(int Foo::*x, int Foo::*y) {
15+
return x == y;
16+
}
17+
18+
// CIR-LABEL: @_Z2eqM3FooiS0_
19+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
20+
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
21+
// CIR-NEXT: %{{.+}} = cir.cmp(eq, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
22+
// CIR: }
23+
24+
// LLVM-LABEL: @_Z2eqM3FooiS0_
25+
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
26+
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
27+
// LLVM-NEXT: %{{.+}} = icmp eq i64 %[[#x]], %[[#y]]
28+
// LLVM: }
29+
30+
bool ne(int Foo::*x, int Foo::*y) {
31+
return x != y;
32+
}
33+
34+
// CIR-LABEL: @_Z2neM3FooiS0_
35+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
36+
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
37+
// CIR-NEXT: %{{.+}} = cir.cmp(ne, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
38+
// CIR: }
39+
40+
// LLVM-LABEL: @_Z2neM3FooiS0_
41+
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
42+
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
43+
// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#x]], %[[#y]]
44+
// LLVM: }

0 commit comments

Comments
 (0)