Skip to content

Commit

Permalink
Select condition vector lanes must match the true and false value (#8465)
Browse files Browse the repository at this point in the history

The ability for the condition to be a scalar while the value is a vector
was a mistake that added complexity in a few places. It's easy enough to
just check if the condition is a broadcast if you really need to know.

One place it added a bunch of complexity is that it meant in the RHS of
the simplifier rules, sometimes you needed an implicit broadcast. This
was responsible for about a third of the code size in Simplify_Sub.o!
It's also unclear if it respects the reduction order we use to prove the
simplifier terminates, because those rules were turning an implicit
broadcast in the IR into a new actual Broadcast node.
  • Loading branch information
abadams authored Nov 6, 2024
1 parent 2e73b3c commit b3d42e5
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2038,7 +2038,7 @@ void CodeGen_C::visit(const Select *op) {

// clang doesn't support the ternary operator on OpenCL style vectors.
// See: https://bugs.llvm.org/show_bug.cgi?id=33103
if (op->condition.type().is_scalar()) {
if (op->type.is_scalar()) {
rhs << "(" << type << ")"
<< "(" << cond
<< " ? " << true_val
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2104,10 +2104,11 @@ void CodeGen_Hexagon::visit(const Min *op) {
}

void CodeGen_Hexagon::visit(const Select *op) {
if (op->condition.type().is_scalar() && op->type.is_vector()) {
const Broadcast *b = op->condition.as<Broadcast>();
if (op->type.is_vector() && b && b->type.is_scalar()) {
// Implement scalar conditions on vector values with if-then-else.
value = codegen(Call::make(op->type, Call::if_then_else,
{op->condition, op->true_value, op->false_value},
{b->value, op->true_value, op->false_value},
Call::PureIntrinsic));
} else {
CodeGen_Posix::visit(op);
Expand Down
6 changes: 0 additions & 6 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1880,12 +1880,6 @@ void CodeGen_LLVM::visit(const Not *op) {

void CodeGen_LLVM::visit(const Select *op) {
Value *cmp = codegen(op->condition);
if (use_llvm_vp_intrinsics &&
op->type.is_vector() &&
op->condition.type().is_scalar()) {
cmp = create_broadcast(cmp, op->type.lanes());
}

Value *a = codegen(op->true_value);
Value *b = codegen(op->false_value);
if (a->getType()->isVectorTy()) {
Expand Down
2 changes: 1 addition & 1 deletion src/CodeGen_OpenCL_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Cast *op) {
}

void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Select *op) {
if (!op->condition.type().is_scalar()) {
if (op->type.is_vector()) {
// A vector of bool was recursively introduced while
// performing codegen. Eliminate it.
Expr equiv = eliminate_bool_vectors(op);
Expand Down
2 changes: 1 addition & 1 deletion src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ void CodeGen_X86::visit(const NE *op) {
}

void CodeGen_X86::visit(const Select *op) {
if (op->condition.type().is_vector()) {
if (op->type.is_vector()) {
// LLVM handles selects on vector conditions much better at native width
Value *cond = codegen(op->condition);
Value *true_val = codegen(op->true_value);
Expand Down
5 changes: 2 additions & 3 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,8 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
internal_assert(false_value.defined()) << "Select of undefined\n";
internal_assert(condition.type().is_bool()) << "First argument to Select is not a bool: " << condition.type() << "\n";
internal_assert(false_value.type() == true_value.type()) << "Select of mismatched types\n";
internal_assert(condition.type().is_scalar() ||
condition.type().lanes() == true_value.type().lanes())
<< "In Select, vector lanes of condition must either be 1, or equal to vector lanes of arguments\n";
internal_assert(condition.type().lanes() == true_value.type().lanes())
<< "In Select, vector lanes of condition must be equal to vector lanes of arguments\n";

Select *node = new Select;
node->type = true_value.type();
Expand Down
70 changes: 28 additions & 42 deletions src/IRMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty) {
return Expr();
}
if (lanes > 1) {
e = Broadcast::make(e, lanes);
e = Broadcast::make(std::move(e), lanes);
}
return e;
}
Expand Down Expand Up @@ -713,14 +713,6 @@ struct BinOp {
ea = a.make(state, type_hint);
eb = b.make(state, ea.type());
}
// We sometimes mix vectors and scalars in the rewrite rules,
// so insert a broadcast if necessary.
if (ea.type().is_vector() && !eb.type().is_vector()) {
eb = Broadcast::make(eb, ea.type().lanes());
}
if (eb.type().is_vector() && !ea.type().is_vector()) {
ea = Broadcast::make(ea, eb.type().lanes());
}
return Op::make(std::move(ea), std::move(eb));
}
};
Expand Down Expand Up @@ -815,14 +807,6 @@ struct CmpOp {
ea = a.make(state, {});
eb = b.make(state, ea.type());
}
// We sometimes mix vectors and scalars in the rewrite rules,
// so insert a broadcast if necessary.
if (ea.type().is_vector() && !eb.type().is_vector()) {
eb = Broadcast::make(eb, ea.type().lanes());
}
if (eb.type().is_vector() && !ea.type().is_vector()) {
ea = Broadcast::make(ea, eb.type().lanes());
}
return Op::make(std::move(ea), std::move(eb));
}
};
Expand Down Expand Up @@ -1405,55 +1389,55 @@ struct Intrin {
Expr make(MatcherState &state, halide_type_t type_hint) const {
Expr arg0 = std::get<0>(args).make(state, type_hint);
if (intrin == Call::likely) {
return likely(arg0);
return likely(std::move(arg0));
} else if (intrin == Call::likely_if_innermost) {
return likely_if_innermost(arg0);
return likely_if_innermost(std::move(arg0));
} else if (intrin == Call::abs) {
return abs(arg0);
return abs(std::move(arg0));
} else if (intrin == Call::saturating_cast) {
return saturating_cast(optional_type_hint, arg0);
return saturating_cast(optional_type_hint, std::move(arg0));
}

Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
if (intrin == Call::absd) {
return absd(arg0, arg1);
return absd(std::move(arg0), std::move(arg1));
} else if (intrin == Call::widen_right_add) {
return widen_right_add(arg0, arg1);
return widen_right_add(std::move(arg0), std::move(arg1));
} else if (intrin == Call::widen_right_mul) {
return widen_right_mul(arg0, arg1);
return widen_right_mul(std::move(arg0), std::move(arg1));
} else if (intrin == Call::widen_right_sub) {
return widen_right_sub(arg0, arg1);
return widen_right_sub(std::move(arg0), std::move(arg1));
} else if (intrin == Call::widening_add) {
return widening_add(arg0, arg1);
return widening_add(std::move(arg0), std::move(arg1));
} else if (intrin == Call::widening_sub) {
return widening_sub(arg0, arg1);
return widening_sub(std::move(arg0), std::move(arg1));
} else if (intrin == Call::widening_mul) {
return widening_mul(arg0, arg1);
return widening_mul(std::move(arg0), std::move(arg1));
} else if (intrin == Call::saturating_add) {
return saturating_add(arg0, arg1);
return saturating_add(std::move(arg0), std::move(arg1));
} else if (intrin == Call::saturating_sub) {
return saturating_sub(arg0, arg1);
return saturating_sub(std::move(arg0), std::move(arg1));
} else if (intrin == Call::halving_add) {
return halving_add(arg0, arg1);
return halving_add(std::move(arg0), std::move(arg1));
} else if (intrin == Call::halving_sub) {
return halving_sub(arg0, arg1);
return halving_sub(std::move(arg0), std::move(arg1));
} else if (intrin == Call::rounding_halving_add) {
return rounding_halving_add(arg0, arg1);
return rounding_halving_add(std::move(arg0), std::move(arg1));
} else if (intrin == Call::shift_left) {
return arg0 << arg1;
return std::move(arg0) << std::move(arg1);
} else if (intrin == Call::shift_right) {
return arg0 >> arg1;
return std::move(arg0) >> std::move(arg1);
} else if (intrin == Call::rounding_shift_left) {
return rounding_shift_left(arg0, arg1);
return rounding_shift_left(std::move(arg0), std::move(arg1));
} else if (intrin == Call::rounding_shift_right) {
return rounding_shift_right(arg0, arg1);
return rounding_shift_right(std::move(arg0), std::move(arg1));
}

Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
if (intrin == Call::mul_shift_right) {
return mul_shift_right(arg0, arg1, arg2);
return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
} else if (intrin == Call::rounding_mul_shift_right) {
return rounding_mul_shift_right(arg0, arg1, arg2);
return rounding_mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
}

internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
Expand Down Expand Up @@ -1840,7 +1824,7 @@ struct RampOp {
Expr ea, eb;
eb = b.make(state, type_hint);
ea = a.make(state, eb.type());
return Ramp::make(ea, eb, l);
return Ramp::make(std::move(ea), std::move(eb), l);
}

constexpr static bool foldable = false;
Expand Down Expand Up @@ -2210,8 +2194,7 @@ struct Fold {
ty.bits = type_hint.bits;
}

Expr e = make_const_expr(c, ty);
return e;
return make_const_expr(c, ty);
}

constexpr static bool foldable = A::foldable;
Expand Down Expand Up @@ -2864,6 +2847,9 @@ struct Rewriter {

// Builds the replacement expression from the `after` pattern: calls
// after.make(state, output_type) and stores the returned value in `result`.
// NOTE(review): `state`, `output_type`, `result` and `instance` are declared
// outside this excerpt (presumably members of the enclosing Rewriter struct) — confirm.
template<typename After>
HALIDE_NEVER_INLINE void build_replacement(After after) {
#if HALIDE_DEBUG_MATCHED_RULES
// Opt-in tracing: only compiled in when HALIDE_DEBUG_MATCHED_RULES is non-zero.
// Streams "<instance> -> <after>" to debug(0) before the replacement is built.
debug(0) << instance << " -> " << after << "\n";
#endif
result = after.make(state, output_type);
}

Expand Down
6 changes: 5 additions & 1 deletion src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ Expr saturating_cast(Type t, Expr e) {
Expr select(Expr condition, Expr true_value, Expr false_value) {
if (as_const_int(condition)) {
// Why are you doing this? We'll preserve the select node until constant folding for you.
condition = cast(Bool(), std::move(condition));
condition = cast(Bool(true_value.type().lanes()), std::move(condition));
}

// Coerce int literals to the type of the other argument
Expand All @@ -1517,6 +1517,10 @@ Expr select(Expr condition, Expr true_value, Expr false_value) {
<< " " << true_value << " has type " << true_value.type() << "\n"
<< " " << false_value << " has type " << false_value.type() << "\n";

if (true_value.type().is_vector() && condition.type().is_scalar()) {
condition = Broadcast::make(std::move(condition), true_value.type().lanes());
}

return Select::make(std::move(condition), std::move(true_value), std::move(false_value));
}

Expand Down
16 changes: 12 additions & 4 deletions src/Simplify_Mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,18 @@ Expr Simplify::visit(const Mod *op, ExprInfo *info) {
rewrite(ramp(x, c0, lanes) % broadcast(c1, lanes), ramp(x % c1, c0, lanes),
// First and last lanes are the same when...
can_prove((x % c1 + c0 * (lanes - 1)) / c1 == 0, this)) ||
rewrite(ramp(x * c0, c2, c3) % broadcast(c1, c3), ramp(x * fold(c0 % c1), fold(c2 % c1), c3) % c1, c1 > 0 && (c0 >= c1 || c0 < 0)) ||
rewrite(ramp(x + c0, c2, c3) % broadcast(c1, c3), ramp(x + fold(c0 % c1), fold(c2 % c1), c3) % c1, c1 > 0 && (c0 >= c1 || c0 < 0)) ||
rewrite(ramp(x * c0 + y, c2, c3) % broadcast(c1, c3), ramp(y, fold(c2 % c1), c3) % c1, c0 % c1 == 0) ||
rewrite(ramp(y + x * c0, c2, c3) % broadcast(c1, c3), ramp(y, fold(c2 % c1), c3) % c1, c0 % c1 == 0))))) {
rewrite(ramp(x * c0, c2, c3) % broadcast(c1, c3),
ramp(x * fold(c0 % c1), fold(c2 % c1), c3) % broadcast(c1, c3),
c1 > 0 && (c0 >= c1 || c0 < 0)) ||
rewrite(ramp(x + c0, c2, c3) % broadcast(c1, c3),
ramp(x + fold(c0 % c1), fold(c2 % c1), c3) % broadcast(c1, c3),
c1 > 0 && (c0 >= c1 || c0 < 0)) ||
rewrite(ramp(x * c0 + y, c2, c3) % broadcast(c1, c3),
ramp(y, fold(c2 % c1), c3) % broadcast(c1, c3),
c0 % c1 == 0) ||
rewrite(ramp(y + x * c0, c2, c3) % broadcast(c1, c3),
ramp(y, fold(c2 % c1), c3) % broadcast(c1, c3),
c0 % c1 == 0))))) {
return mutate(rewrite.result, info);
}
// clang-format on
Expand Down
4 changes: 1 addition & 3 deletions src/Simplify_Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {
}

if (may_simplify(op->type)) {
int lanes = op->type.lanes();
auto rewrite = IRMatcher::rewriter(IRMatcher::select(condition, true_value, false_value), op->type);

// clang-format off
Expand All @@ -40,8 +39,7 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {

// clang-format off
if (EVAL_IN_LAMBDA
(rewrite(select(broadcast(x, lanes), y, z), select(x, y, z)) ||
rewrite(select(x != y, z, w), select(x == y, w, z)) ||
(rewrite(select(x != y, z, w), select(x == y, w, z)) ||
rewrite(select(x <= y, z, w), select(y < x, w, z)) ||
rewrite(select(x, select(y, z, w), z), select(x && !y, w, z)) ||
rewrite(select(x, select(y, z, w), w), select(x && y, z, w)) ||
Expand Down
1 change: 1 addition & 0 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ class VectorSubs : public IRMutator {
// Widen the true and false values, and the condition too: Select::make
// requires the condition's lanes to equal the lanes of the values.
true_value = widen(true_value, lanes);
false_value = widen(false_value, lanes);
condition = widen(condition, lanes);
return Select::make(condition, true_value, false_value);
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/correctness/fuzz_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Expr random_expr(std::mt19937 &rng, Type t, int depth, bool overflow_undef) {
auto c = random_condition(rng, t, depth, true);
auto e1 = random_expr(rng, t, depth, overflow_undef);
auto e2 = random_expr(rng, t, depth, overflow_undef);
return Select::make(c, e1, e2);
return select(c, e1, e2);
},
[&]() {
if (t.lanes() != 1) {
Expand Down

0 comments on commit b3d42e5

Please sign in to comment.