Skip to content

Commit 9f3931b

Browse files
authored
[AMDGPU] Fold fmed3 when inputs include infinity (#144824)
1 parent 4785832 commit 9f3931b

File tree

2 files changed

+124
-13
lines changed

2 files changed

+124
-13
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,14 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
10311031
// s1 _nan: min(s0, s2)
10321032
// s2 _nan: min(s0, s1)
10331033

1034+
// med3 behavior with infinity
1035+
// s0 +inf: max(s1, s2)
1036+
// s1 +inf: max(s0, s2)
1037+
// s2 +inf: max(s0, s1)
1038+
// s0 -inf: min(s1, s2)
1039+
// s1 -inf: min(s0, s2)
1040+
// s2 -inf: min(s0, s1)
1041+
10341042
// Checking for NaN before canonicalization provides better fidelity when
10351043
// mapping other operations onto fmed3 since the order of operands is
10361044
// unchanged.
@@ -1039,51 +1047,64 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
10391047
const APFloat *ConstSrc1 = nullptr;
10401048
const APFloat *ConstSrc2 = nullptr;
10411049

1042-
// TODO: Also can fold to 2 operands with infinities.
1043-
if ((match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isNaN()) ||
1050+
if ((match(Src0, m_APFloat(ConstSrc0)) &&
1051+
(ConstSrc0->isNaN() || ConstSrc0->isInfinity())) ||
10441052
isa<UndefValue>(Src0)) {
1053+
const bool IsPosInfinity = ConstSrc0 && ConstSrc0->isPosInfinity();
10451054
switch (fpenvIEEEMode(II)) {
10461055
case KnownIEEEMode::On:
10471056
// TODO: If Src2 is snan, does it need quieting?
1048-
if (ConstSrc0 && ConstSrc0->isSignaling())
1057+
if (ConstSrc0 && ConstSrc0->isNaN() && ConstSrc0->isSignaling())
10491058
return IC.replaceInstUsesWith(II, Src2);
1050-
V = IC.Builder.CreateMinNum(Src1, Src2);
1059+
1060+
V = IsPosInfinity ? IC.Builder.CreateMaxNum(Src1, Src2)
1061+
: IC.Builder.CreateMinNum(Src1, Src2);
10511062
break;
10521063
case KnownIEEEMode::Off:
1053-
V = IC.Builder.CreateMinimumNum(Src1, Src2);
1064+
V = IsPosInfinity ? IC.Builder.CreateMaximumNum(Src1, Src2)
1065+
: IC.Builder.CreateMinimumNum(Src1, Src2);
10541066
break;
10551067
case KnownIEEEMode::Unknown:
10561068
break;
10571069
}
1058-
} else if ((match(Src1, m_APFloat(ConstSrc1)) && ConstSrc1->isNaN()) ||
1070+
} else if ((match(Src1, m_APFloat(ConstSrc1)) &&
1071+
(ConstSrc1->isNaN() || ConstSrc1->isInfinity())) ||
10591072
isa<UndefValue>(Src1)) {
1073+
const bool IsPosInfinity = ConstSrc1 && ConstSrc1->isPosInfinity();
10601074
switch (fpenvIEEEMode(II)) {
10611075
case KnownIEEEMode::On:
10621076
// TODO: If Src2 is snan, does it need quieting?
1063-
if (ConstSrc1 && ConstSrc1->isSignaling())
1077+
if (ConstSrc1 && ConstSrc1->isNaN() && ConstSrc1->isSignaling())
10641078
return IC.replaceInstUsesWith(II, Src2);
10651079

1066-
V = IC.Builder.CreateMinNum(Src0, Src2);
1080+
V = IsPosInfinity ? IC.Builder.CreateMaxNum(Src0, Src2)
1081+
: IC.Builder.CreateMinNum(Src0, Src2);
10671082
break;
10681083
case KnownIEEEMode::Off:
1069-
V = IC.Builder.CreateMinimumNum(Src0, Src2);
1084+
V = IsPosInfinity ? IC.Builder.CreateMaximumNum(Src0, Src2)
1085+
: IC.Builder.CreateMinimumNum(Src0, Src2);
10701086
break;
10711087
case KnownIEEEMode::Unknown:
10721088
break;
10731089
}
1074-
} else if ((match(Src2, m_APFloat(ConstSrc2)) && ConstSrc2->isNaN()) ||
1090+
} else if ((match(Src2, m_APFloat(ConstSrc2)) &&
1091+
(ConstSrc2->isNaN() || ConstSrc2->isInfinity())) ||
10751092
isa<UndefValue>(Src2)) {
10761093
switch (fpenvIEEEMode(II)) {
10771094
case KnownIEEEMode::On:
1078-
if (ConstSrc2 && ConstSrc2->isSignaling()) {
1095+
if (ConstSrc2 && ConstSrc2->isNaN() && ConstSrc2->isSignaling()) {
10791096
auto *Quieted = ConstantFP::get(II.getType(), ConstSrc2->makeQuiet());
10801097
return IC.replaceInstUsesWith(II, Quieted);
10811098
}
10821099

1083-
V = IC.Builder.CreateMinNum(Src0, Src1);
1100+
V = (ConstSrc2 && ConstSrc2->isPosInfinity())
1101+
? IC.Builder.CreateMaxNum(Src0, Src1)
1102+
: IC.Builder.CreateMinNum(Src0, Src1);
10841103
break;
10851104
case KnownIEEEMode::Off:
1086-
V = IC.Builder.CreateMaximumNum(Src0, Src1);
1105+
V = (ConstSrc2 && ConstSrc2->isNegInfinity())
1106+
? IC.Builder.CreateMinimumNum(Src0, Src1)
1107+
: IC.Builder.CreateMaximumNum(Src0, Src1);
10871108
break;
10881109
case KnownIEEEMode::Unknown:
10891110
break;

llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,96 @@ define float @fmed3_neg2_3_snan1_f32(float %x, float %y) #1 {
521521
ret float %med3
522522
}
523523

524+
; fmed3 with +inf (0x7FF0000000000000) as operand 0 and variable operands 1 and 2.
; The checks below expect the call to fold to llvm.maxnum(x, y) under the IEEE1
; prefix and to llvm.maximumnum(x, y) under the IEEE0 prefix.
define float @fmed3_inf_x_y_f32(float %x, float %y) #1 {
525+
; IEEE1-LABEL: define float @fmed3_inf_x_y_f32(
526+
; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
527+
; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
528+
; IEEE1-NEXT: ret float [[MED3]]
529+
;
530+
; IEEE0-LABEL: define float @fmed3_inf_x_y_f32(
531+
; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
532+
; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
533+
; IEEE0-NEXT: ret float [[MED3]]
534+
;
535+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF0000000000000, float %x, float %y)
536+
ret float %med3
537+
}
538+
539+
; fmed3 with +inf (0x7FF0000000000000) as operand 1 and variable operands 0 and 2.
; The checks below expect the call to fold to llvm.maxnum(x, y) under the IEEE1
; prefix and to llvm.maximumnum(x, y) under the IEEE0 prefix.
define float @fmed3_x_inf_y_f32(float %x, float %y) #1 {
540+
; IEEE1-LABEL: define float @fmed3_x_inf_y_f32(
541+
; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
542+
; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
543+
; IEEE1-NEXT: ret float [[MED3]]
544+
;
545+
; IEEE0-LABEL: define float @fmed3_x_inf_y_f32(
546+
; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
547+
; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
548+
; IEEE0-NEXT: ret float [[MED3]]
549+
;
550+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0x7FF0000000000000, float %y)
551+
ret float %med3
552+
}
553+
554+
; fmed3 with +inf (0x7FF0000000000000) as operand 2 and variable operands 0 and 1.
; The checks below expect the call to fold to llvm.maxnum(x, y) under the IEEE1
; prefix and to llvm.maximumnum(x, y) under the IEEE0 prefix.
define float @fmed3_x_y_inf_f32(float %x, float %y) #1 {
555+
; IEEE1-LABEL: define float @fmed3_x_y_inf_f32(
556+
; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
557+
; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
558+
; IEEE1-NEXT: ret float [[MED3]]
559+
;
560+
; IEEE0-LABEL: define float @fmed3_x_y_inf_f32(
561+
; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
562+
; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
563+
; IEEE0-NEXT: ret float [[MED3]]
564+
;
565+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0x7FF0000000000000)
566+
ret float %med3
567+
}
568+
569+
; fmed3 with -inf (0xFFF0000000000000) as operand 0 and variable operands 1 and 2.
; The checks below expect the call to fold to llvm.minnum(x, y) under the IEEE1
; prefix and to llvm.minimumnum(x, y) under the IEEE0 prefix.
define float @fmed3_ninf_x_y_f32(float %x, float %y) #1 {
570+
; IEEE1-LABEL: define float @fmed3_ninf_x_y_f32(
571+
; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
572+
; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
573+
; IEEE1-NEXT: ret float [[MED3]]
574+
;
575+
; IEEE0-LABEL: define float @fmed3_ninf_x_y_f32(
576+
; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
577+
; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
578+
; IEEE0-NEXT: ret float [[MED3]]
579+
;
580+
%med3 = call float @llvm.amdgcn.fmed3.f32(float 0xFFF0000000000000, float %x, float %y)
581+
ret float %med3
582+
}
583+
584+
; fmed3 with -inf (0xFFF0000000000000) as operand 1 and variable operands 0 and 2.
; The checks below expect the call to fold to llvm.minnum(x, y) under the IEEE1
; prefix and to llvm.minimumnum(x, y) under the IEEE0 prefix.
define float @fmed3_x_ninf_y_f32(float %x, float %y) #1 {
585+
; IEEE1-LABEL: define float @fmed3_x_ninf_y_f32(
586+
; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
587+
; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
588+
; IEEE1-NEXT: ret float [[MED3]]
589+
;
590+
; IEEE0-LABEL: define float @fmed3_x_ninf_y_f32(
591+
; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
592+
; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
593+
; IEEE0-NEXT: ret float [[MED3]]
594+
;
595+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0xFFF0000000000000, float %y)
596+
ret float %med3
597+
}
598+
599+
; fmed3 with -inf (0xFFF0000000000000) as operand 2 and variable operands 0 and 1.
; The checks below expect the call to fold to llvm.minnum(x, y) under the IEEE1
; prefix and to llvm.minimumnum(x, y) under the IEEE0 prefix.
define float @fmed3_x_y_ninf_f32(float %x, float %y) #1 {
600+
; IEEE1-LABEL: define float @fmed3_x_y_ninf_f32(
601+
; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
602+
; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
603+
; IEEE1-NEXT: ret float [[MED3]]
604+
;
605+
; IEEE0-LABEL: define float @fmed3_x_y_ninf_f32(
606+
; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
607+
; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
608+
; IEEE0-NEXT: ret float [[MED3]]
609+
;
610+
%med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0xFFF0000000000000)
611+
ret float %med3
612+
}
613+
524614
; --------------------------------------------------------------------
525615
; llvm.amdgcn.fmed3 with default mode implied by shader CC
526616
; --------------------------------------------------------------------

0 commit comments

Comments
 (0)