Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
Revert "[GlobalIsel] Combine select of binops (#76763)"
Browse files Browse the repository at this point in the history
This reverts commit 1687555.
  • Loading branch information
tschuett committed Jan 6, 2024
1 parent 61bb3d4 commit a085402
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 322 deletions.
3 changes: 0 additions & 3 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -910,9 +910,6 @@ class CombinerHelper {

bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);

/// Try to fold select(cc, binop(), binop()) -> binop(select(), X)
bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);

bool isOneOrOneSplat(Register Src, bool AllowUndefs);
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
bool isConstantSplatVector(Register Src, int64_t SplatValue,
Expand Down
103 changes: 0 additions & 103 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,109 +558,6 @@ class GVecReduce : public GenericMachineInstr {
}
};

// Represents a binary operation, i.e, x = y op z.
class GBinOp : public GenericMachineInstr {
public:
Register getLHSReg() const { return getReg(1); }
Register getRHSReg() const { return getReg(2); }

static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
// Integer.
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_SMIN:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMIN:
case TargetOpcode::G_UMAX:
// Floating point.
case TargetOpcode::G_FMINNUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXIMUM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FPOW:
// Logical.
case TargetOpcode::G_AND:
case TargetOpcode::G_OR:
case TargetOpcode::G_XOR:
return true;
default:
return false;
}
};
};

// Represents an integer binary operation.
class GIntBinOp : public GBinOp {
public:
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_SMIN:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMIN:
case TargetOpcode::G_UMAX:
return true;
default:
return false;
}
};
};

// Represents a floating point binary operation.
class GFBinOp : public GBinOp {
public:
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_FMINNUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXIMUM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FPOW:
return true;
default:
return false;
}
};
};

// Represents a logical binary operation.
class GLogicalBinOp : public GBinOp {
public:
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_AND:
case TargetOpcode::G_OR:
case TargetOpcode::G_XOR:
return true;
default:
return false;
}
};
};

} // namespace llvm

Expand Down
93 changes: 28 additions & 65 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6390,7 +6390,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isOne()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Inner = B.buildNot(CondTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
B.buildZExtOrTrunc(Dest, Inner);
};
return true;
Expand All @@ -6400,7 +6401,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Inner = B.buildNot(CondTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
B.buildSExtOrTrunc(Dest, Inner);
};
return true;
Expand All @@ -6410,7 +6412,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue - 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Inner, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
Expand All @@ -6420,7 +6423,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue + 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
Expand All @@ -6430,7 +6434,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Inner, Cond);
// The shift amount must be scalar.
LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
Expand All @@ -6442,7 +6447,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
B.buildOr(Dest, Inner, False, Flags);
};
return true;
Expand All @@ -6452,8 +6458,10 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Not = B.buildNot(CondTy, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
Register Not = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Not, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Not);
B.buildOr(Dest, Inner, True, Flags);
};
return true;
Expand Down Expand Up @@ -6488,7 +6496,8 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Cond);
B.buildOr(DstReg, Ext, False, Flags);
};
return true;
Expand All @@ -6499,7 +6508,8 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Cond);
B.buildAnd(DstReg, Ext, True);
};
return true;
Expand All @@ -6510,9 +6520,11 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
auto Inner = B.buildNot(CondTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
// Then an ext to match the destination register.
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Inner);
B.buildOr(DstReg, Ext, True, Flags);
};
return true;
Expand All @@ -6523,9 +6535,11 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
auto Inner = B.buildNot(CondTy, Cond);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
// Then an ext to match the destination register.
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Inner);
B.buildAnd(DstReg, Ext, False);
};
return true;
Expand All @@ -6534,54 +6548,6 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
return false;
}

bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
BuildFnTy &MatchInfo) {
Register DstReg = Select->getReg(0);
Register Cond = Select->getCondReg();
Register False = Select->getFalseReg();
Register True = Select->getTrueReg();
LLT DstTy = MRI.getType(DstReg);

GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI);
GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI);

// We need two binops of the same kind on the true/false registers.
if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
return false;

// Note that there are no constraints on CondTy.
unsigned Flags = (LHS->getFlags() & RHS->getFlags()) | Select->getFlags();
unsigned Opcode = LHS->getOpcode();

// Fold select(cond, binop(x, y), binop(z, y))
// --> binop(select(cond, x, z), y)
if (LHS->getRHSReg() == RHS->getRHSReg()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg(),
Select->getFlags());
B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
};
return true;
}

// Fold select(cond, binop(x, y), binop(x, z))
// --> binop(x, select(cond, y, z))
if (LHS->getLHSReg() == RHS->getLHSReg()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg(),
Select->getFlags());
B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
};
return true;
}

// FIXME: use isCommutable().

return false;
}

bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
GSelect *Select = cast<GSelect>(&MI);

Expand All @@ -6591,8 +6557,5 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
if (tryFoldBoolSelectToLogic(Select, MatchInfo))
return true;

if (tryFoldSelectOfBinOps(Select, MatchInfo))
return true;

return false;
}
Loading

0 comments on commit a085402

Please sign in to comment.