diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 8fb7554ec84b..b1581aeabdc7 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -130,7 +130,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) || rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes)) || - rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes))) { + rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes)) || + rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, 1) * (c0 / lanes), lanes), c0 % lanes == 0) || + false) { return mutate(rewrite.result, info); } break; @@ -142,8 +145,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_min(max(x, broadcast(y, arg_lanes)), lanes), max(h_min(x, lanes), broadcast(y, lanes))) || rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) || rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) || - rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) || + rewrite(h_min(broadcast(x, c0), 1), h_min(x, 1)) || + rewrite(h_min(broadcast(x, c0), lanes), broadcast(h_min(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_min(ramp(x, y, arg_lanes), 1), x + min(y * (arg_lanes - 1), 0)) || + rewrite(h_min(ramp(x, y, arg_lanes), lanes), ramp(x + min(y * (factor - 1), 0), y * factor, lanes)) || false) { return mutate(rewrite.result, info); } @@ -156,8 +161,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_max(max(x, broadcast(y, arg_lanes)), lanes), max(h_max(x, lanes), broadcast(y, lanes))) || rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) || rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) || - rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) || + rewrite(h_max(broadcast(x, c0), 1), h_max(x, 1)) || + rewrite(h_max(broadcast(x, c0), lanes), broadcast(h_max(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_max(ramp(x, y, arg_lanes), 1), x + max(y * (arg_lanes - 1), 0)) || + rewrite(h_max(ramp(x, y, arg_lanes), lanes), ramp(x + max(y * (factor - 1), 0), y * factor, lanes)) || false) { return mutate(rewrite.result, info); } @@ -170,15 +177,16 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_and(x && broadcast(y, arg_lanes), lanes), h_and(x, lanes) && broadcast(y, lanes)) || rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) || rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_and(broadcast(x, c0), lanes), h_and(x, lanes), factor % c0 == 0) || - rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), - x + max(y * (arg_lanes - 1), 0) < z) || - rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), - x + max(y * (arg_lanes - 1), 0) <= z) || - rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x < y + min(z * (arg_lanes - 1), 0)) || - rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x <= y + min(z * (arg_lanes - 1), 0)) || + rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, 1), lanes), c0 >= lanes) || + (lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), + x + max(y * (arg_lanes - 1), 0) < z)) || + (lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), + x + max(y * (arg_lanes - 1), 0) <= z)) || + (lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + x < y + min(z * (arg_lanes - 1), 0))) || + (lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + x <= y + min(z * (arg_lanes - 1), 0))) || false) { return mutate(rewrite.result, info); } @@ -191,7 +199,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_or(x && broadcast(y, arg_lanes), lanes), h_or(x, lanes) && broadcast(y, lanes)) || rewrite(h_or(broadcast(x, arg_lanes) && y, lanes), h_or(y, lanes) && broadcast(x, lanes)) || rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) || + rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, 1), lanes), c0 >= lanes) || // type of arg_lanes is somewhat indeterminate rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0) < z) || diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index de10bde5a1b9..ff934e1b82ba 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -805,11 +805,64 @@ void check_vectors() { check(VectorReduce::make(VectorReduce::And, Broadcast::make(bool_vector, 4), 1), VectorReduce::make(VectorReduce::And, bool_vector, 1)); check(VectorReduce::make(VectorReduce::Or, Broadcast::make(bool_vector, 4), 2), - VectorReduce::make(VectorReduce::Or, bool_vector, 2)); + Broadcast::make(VectorReduce::make(VectorReduce::Or, bool_vector, 1), 2)); check(VectorReduce::make(VectorReduce::Min, Broadcast::make(int_vector, 4), 4), - int_vector); + Broadcast::make(VectorReduce::make(VectorReduce::Min, int_vector, 1), 4)); check(VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8), - VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8)); + Broadcast::make(VectorReduce::make(VectorReduce::Max, int_vector, 2), 4)); + + { + Expr x = Variable::make(Int(32), "x"); + Expr y = Variable::make(Int(32), "y"); + + // == Symbolic Strides == + + // 1. Min: Scalar Reduction (arg_lanes=4, lanes=1 -> factor=4) + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, y, 4), 1), + min(y, 0) * 3 + x); + + // 2. Min: Vector Reduction (arg_lanes=6, lanes=2 -> factor=3) + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, y, 6), 2), + Ramp::make(min(y, 0) * 2 + x, y * 3, 2)); + + // 3. Max: Scalar Reduction (arg_lanes=4, lanes=1 -> factor=4) + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, y, 4), 1), + max(y, 0) * 3 + x); + + // 4. Max: Vector Reduction (arg_lanes=6, lanes=2 -> factor=3) + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, y, 6), 2), + Ramp::make(max(y, 0) * 2 + x, y * 3, 2)); + + // == Constant Strides (Positive & Negative) == + + // 5. Min: Positive Stride (arg_lanes=8, lanes=2 -> factor=4, stride=2) + // Block 1: min(x, x+2, x+4, x+6) -> x + // Expected Base: x + min(2 * 3, 0) -> x + 0 -> x + // Expected Stride: 2 * 4 = 8 + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, 2, 8), 2), + Ramp::make(x, 8, 2)); + + // 6. Max: Positive Stride (arg_lanes=8, lanes=2 -> factor=4, stride=2) + // Block 1: max(x, x+2, x+4, x+6) -> x+6 + // Expected Base: x + max(2 * 3, 0) -> x + 6 + // Expected Stride: 2 * 4 = 8 + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, 2, 8), 2), + Ramp::make(x + 6, 8, 2)); + + // 7. Min: Negative Stride (arg_lanes=8, lanes=2 -> factor=4, stride=-2) + // Block 1: min(x, x-2, x-4, x-6) -> x-6 + // Expected Base: x + min(-2 * 3, 0) -> x - 6 + // Expected Stride: -2 * 4 = -8 + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, -2, 8), 2), + Ramp::make(x + -6, -8, 2)); + + // 8. Max: Negative Stride (arg_lanes=8, lanes=2 -> factor=4, stride=-2) + // Block 1: max(x, x-2, x-4, x-6) -> x + // Expected Base: x + max(-2 * 3, 0) -> x + 0 -> x + // Expected Stride: -2 * 4 = -8 + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, -2, 8), 2), + Ramp::make(x, -8, 2)); + } { // h_add(broadcast(x, 8), 4) should simplify to broadcast(x * 2, 4)