@@ -331,7 +331,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext {
331331 }
332332
333333 void LiftPhi (State &S, PHINode *Phi);
334- bool LiftSelect (State &S, SelectInst *SI);
334+ void LiftSelect (State &S, SelectInst *SI);
335335 Value *MaybeExtractScalar (State &S, std::pair<Value*,int > ValExpr, Instruction *InsertBefore);
336336 std::vector<Value*> MaybeExtractVector (State &S, Value *BaseVec, Instruction *InsertBefore);
337337 Value *GetPtrForNumber (State &S, unsigned Num, Instruction *InsertBefore);
@@ -600,12 +600,12 @@ Value *LateLowerGCFrame::GetPtrForNumber(State &S, unsigned Num, Instruction *In
600600 return MaybeExtractScalar (S, std::make_pair (Val, Idx), InsertBefore);
601601}
602602
603- bool LateLowerGCFrame::LiftSelect (State &S, SelectInst *SI) {
603+ void LateLowerGCFrame::LiftSelect (State &S, SelectInst *SI) {
604604 if (isa<PointerType>(SI->getType ()) ?
605605 S.AllPtrNumbering .count (SI) :
606606 S.AllCompositeNumbering .count (SI)) {
607607 // already visited here--nothing to do
608- return true ;
608+ return ;
609609 }
610610 std::vector<int > Numbers;
611611 unsigned NumRoots = 1 ;
@@ -617,68 +617,60 @@ bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
617617 // find the base root for the arguments
618618 Value *TrueBase = MaybeExtractScalar (S, FindBaseValue (S, SI->getTrueValue (), false ), SI);
619619 Value *FalseBase = MaybeExtractScalar (S, FindBaseValue (S, SI->getFalseValue (), false ), SI);
620- Value *V_null = ConstantPointerNull::get (cast<PointerType>(T_prjlvalue));
621- bool didsplit = false ;
622- if (TrueBase != V_null && FalseBase != V_null) {
623- std::vector<Value*> TrueBases;
624- std::vector<Value*> FalseBases;
625- if (!isa<PointerType>(TrueBase->getType ())) {
626- TrueBases = MaybeExtractVector (S, TrueBase, SI);
627- assert (TrueBases.size () == Numbers.size ());
628- NumRoots = TrueBases.size ();
629- }
630- if (!isa<PointerType>(FalseBase->getType ())) {
631- FalseBases = MaybeExtractVector (S, FalseBase, SI);
632- assert (FalseBases.size () == Numbers.size ());
633- NumRoots = FalseBases.size ();
634- }
635- if (isa<PointerType>(SI->getType ()) ?
636- S.AllPtrNumbering .count (SI) :
637- S.AllCompositeNumbering .count (SI)) {
638- // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
639- return true ;
640- }
641- // need to handle each element (may just be one scalar)
642- for (unsigned i = 0 ; i < NumRoots; ++i) {
643- Value *TrueElem;
644- if (isa<PointerType>(TrueBase->getType ()))
645- TrueElem = TrueBase;
646- else
647- TrueElem = TrueBases[i];
648- Value *FalseElem;
649- if (isa<PointerType>(FalseBase->getType ()))
650- FalseElem = FalseBase;
651- else
652- FalseElem = FalseBases[i];
653- if (TrueElem != V_null || FalseElem != V_null) {
654- Value *Cond = SI->getCondition ();
655- if (isa<VectorType>(Cond->getType ())) {
656- Cond = ExtractElementInst::Create (Cond,
657- ConstantInt::get (Type::getInt32Ty (Cond->getContext ()), i),
658- " " , SI);
659- }
660- SelectInst *SelectBase = SelectInst::Create (Cond, TrueElem, FalseElem, " gclift" , SI);
661- int Number = ++S.MaxPtrNumber ;
662- S.AllPtrNumbering [SelectBase] = Number;
663- S.ReversePtrNumbering [Number] = SelectBase;
664- if (isa<PointerType>(SI->getType ()))
665- S.AllPtrNumbering [SI] = Number;
666- else
667- Numbers[i] = Number;
668- didsplit = true ;
669- }
670- }
671- if (isa<VectorType>(SI->getType ()) && NumRoots != Numbers.size ()) {
672- // broadcast the scalar root number to fill the vector
673- assert (NumRoots == 1 );
674- int Number = Numbers[0 ];
675- Numbers.resize (0 );
676- Numbers.resize (SI->getType ()->getVectorNumElements (), Number);
677- }
620+ std::vector<Value*> TrueBases;
621+ std::vector<Value*> FalseBases;
622+ if (!isa<PointerType>(TrueBase->getType ())) {
623+ TrueBases = MaybeExtractVector (S, TrueBase, SI);
624+ assert (TrueBases.size () == Numbers.size ());
625+ NumRoots = TrueBases.size ();
626+ }
627+ if (!isa<PointerType>(FalseBase->getType ())) {
628+ FalseBases = MaybeExtractVector (S, FalseBase, SI);
629+ assert (FalseBases.size () == Numbers.size ());
630+ NumRoots = FalseBases.size ();
631+ }
632+ if (isa<PointerType>(SI->getType ()) ?
633+ S.AllPtrNumbering .count (SI) :
634+ S.AllCompositeNumbering .count (SI)) {
635+ // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
636+ return ;
637+ }
638+ // need to handle each element (may just be one scalar)
639+ for (unsigned i = 0 ; i < NumRoots; ++i) {
640+ Value *TrueElem;
641+ if (isa<PointerType>(TrueBase->getType ()))
642+ TrueElem = TrueBase;
643+ else
644+ TrueElem = TrueBases[i];
645+ Value *FalseElem;
646+ if (isa<PointerType>(FalseBase->getType ()))
647+ FalseElem = FalseBase;
648+ else
649+ FalseElem = FalseBases[i];
650+ Value *Cond = SI->getCondition ();
651+ if (isa<VectorType>(Cond->getType ())) {
652+ Cond = ExtractElementInst::Create (Cond,
653+ ConstantInt::get (Type::getInt32Ty (Cond->getContext ()), i),
654+ " " , SI);
655+ }
656+ SelectInst *SelectBase = SelectInst::Create (Cond, TrueElem, FalseElem, " gclift" , SI);
657+ int Number = ++S.MaxPtrNumber ;
658+ S.AllPtrNumbering [SelectBase] = Number;
659+ S.ReversePtrNumbering [Number] = SelectBase;
660+ if (isa<PointerType>(SI->getType ()))
661+ S.AllPtrNumbering [SI] = Number;
662+ else
663+ Numbers[i] = Number;
664+ }
665+ if (isa<VectorType>(SI->getType ()) && NumRoots != Numbers.size ()) {
666+ // broadcast the scalar root number to fill the vector
667+ assert (NumRoots == 1 );
668+ int Number = Numbers[0 ];
669+ Numbers.resize (0 );
670+ Numbers.resize (SI->getType ()->getVectorNumElements (), Number);
678671 }
679672 if (!isa<PointerType>(SI->getType ()))
680673 S.AllCompositeNumbering [SI] = Numbers;
681- return didsplit;
682674}
683675
684676void LateLowerGCFrame::LiftPhi (State &S, PHINode *Phi) {
@@ -754,9 +746,8 @@ int LateLowerGCFrame::NumberBase(State &S, Value *CurrentV)
754746 // input IR)
755747 Number = -1 ;
756748 } else if (isa<SelectInst>(CurrentV) && !isTrackedValue (CurrentV)) {
757- Number = -1 ;
758- if (LiftSelect (S, cast<SelectInst>(CurrentV))) // lifting a scalar pointer (if necessary)
759- Number = S.AllPtrNumbering .at (CurrentV);
749+ LiftSelect (S, cast<SelectInst>(CurrentV));
750+ Number = S.AllPtrNumbering .at (CurrentV);
760751 return Number;
761752 } else if (isa<PHINode>(CurrentV) && !isTrackedValue (CurrentV)) {
762753 LiftPhi (S, cast<PHINode>(CurrentV));
0 commit comments