diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java b/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java index 75424e139f3..b746698e995 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizer.java @@ -30,6 +30,7 @@ import static com.google.errorprone.util.ASTHelpers.getThrownExceptions; import static com.google.errorprone.util.ASTHelpers.getType; import static com.google.errorprone.util.ASTHelpers.isCheckedExceptionType; +import static com.google.errorprone.util.ASTHelpers.isSubtype; import static java.util.stream.Collectors.toCollection; import com.google.common.base.CaseFormat; @@ -60,8 +61,10 @@ import com.sun.source.tree.NewClassTree; import com.sun.source.tree.StatementTree; import com.sun.source.tree.Tree; +import com.sun.source.tree.TypeCastTree; +import com.sun.source.util.TreePath; import com.sun.tools.javac.code.Symbol; -import com.sun.tools.javac.code.Symbol.VarSymbol; +import com.sun.tools.javac.code.Symbol.MethodSymbol; import com.sun.tools.javac.code.Type; import com.sun.tools.javac.code.Types; import java.util.HashSet; @@ -145,14 +148,19 @@ private Optional matchMethodInvocation( return Optional.empty(); } } + MethodSymbol sym = getSymbol(runnable); ImmutableList toHoist = Streams.concat( Stream.ofNullable(getReceiver(runnable)) .map(r -> new Hoist(r, receiverVariableName(r))), Streams.zip( runnable.getArguments().stream(), - getSymbol(runnable).getParameters().stream(), - (ExpressionTree a, VarSymbol p) -> new Hoist(a, p.getSimpleName().toString()))) + Streams.concat( + sym.getParameters().stream(), + // if there are varargs, there may be more arguments than parameters + Stream.generate(() -> sym.getParameters().getLast())) + .map(p -> p.getSimpleName().toString()), + (ExpressionTree a, String p) -> new Hoist(a, p))) .filter(h -> needsHoisting(h.site(), exceptionType, state)) .collect(toImmutableList()); if (toHoist.isEmpty()) { @@ -188,7 +196,9 @@ public Description matchMethod(MethodTree tree, VisitorState state) { state); } - VariableNamer variableNamer = new VariableNamer(state); + // update the tree path so VariableName considers the method parameters + VariableNamer variableNamer = + new VariableNamer(state.withPath(new TreePath(state.getPath(), toFix.getFirst().runnable))); for (AssertThrows current : toFix) { StringBuilder hoistedVariables = new StringBuilder(); for (Hoist hoist : current.toHoist) { @@ -203,7 +213,7 @@ public Description matchMethod(MethodTree tree, VisitorState state) { "%s %s = %s;\n", useVarType ? "var" : SuggestedFixes.qualifyType(state, fix, type), identifier, - state.getSourceForNode(hoist.site()))); + state.getSourceForNode(initializer(state, hoist.site(), type)))); fix.replace(hoist.site(), identifier); } fix.prefixWith(current.parent(), hoistedVariables.toString()); @@ -215,6 +225,18 @@ public Description matchMethod(MethodTree tree, VisitorState state) { return describeMatch(toFix.getFirst().parent(), fix.build()); } + private ExpressionTree initializer(VisitorState state, ExpressionTree site, Type type) { + if (useVarType) { + return site; + } + if (site instanceof TypeCastTree typeCastTree + && isSubtype(getType(typeCastTree.getExpression()), type, state)) { + // avoid unnecessary casts in hoisted variables + return typeCastTree.getExpression(); + } + return site; + } + private void addThrows( MethodTree tree, Type exceptionType, SuggestedFix.Builder fix, VisitorState state) { var types = state.getTypes(); @@ -278,6 +300,12 @@ private boolean needsHoisting(ExpressionTree tree, Type exceptionType, VisitorSt .orElse(Stream.empty())) .anyMatch(t -> needsHoisting(t, exceptionType, state)); case NewClassTree newClassTree -> newClassTreeNeedsHoisting(newClassTree); + case TypeCastTree typeCastTree -> + needsHoisting(typeCastTree.getExpression(), exceptionType, state) + || !isSubtype( + getType(typeCastTree.getExpression()), + getType(typeCastTree.getType()), + state); default -> true; }; if (!needsHoisting) { diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java index bc5026626e4..1cd77f2df3e 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/AssertThrowsMinimizerTest.java @@ -906,4 +906,110 @@ void f() { """) .doTest(TEXT_MATCH); } + + @Test + public void varArgs() { + compilationHelper + .addInputLines( + "Test.java", + """ + import java.util.ArrayList; + import java.util.List; + import java.util.function.Supplier; + import static org.junit.Assert.assertThrows; + + abstract class Test { + void f() { + assertThrows(IllegalStateException.class, () -> doSomething(getString(), getString())); + } + + abstract String getString(); + + abstract void doSomething(String... strings); + } + """) + .addOutputLines( + "Test.java", + """ + import java.util.ArrayList; + import java.util.List; + import java.util.function.Supplier; + import static org.junit.Assert.assertThrows; + + abstract class Test { + void f() { + String strings = getString(); + String strings2 = getString(); + assertThrows(IllegalStateException.class, () -> doSomething(strings, strings2)); + } + + abstract String getString(); + + abstract void doSomething(String... strings); + } + """) + .doTest(TEXT_MATCH); + } + + @Test + public void cast() { + compilationHelper + .addInputLines( + "Test.java", + """ + import java.util.ArrayList; + import java.util.List; + import java.util.function.Supplier; + import static org.junit.Assert.assertThrows; + + abstract class Test { + void f(String s, Object o) { + assertThrows(IllegalStateException.class, () -> doSomething((String) getString())); + assertThrows(IllegalStateException.class, () -> doSomething((String) s)); + assertThrows(IllegalStateException.class, () -> doSomething((String) o)); + assertThrows(IllegalStateException.class, () -> doSomething((Object) s)); + assertThrows(IllegalStateException.class, () -> doSomething((Object) o)); + assertThrows(IllegalStateException.class, () -> doSomething((String) null)); + } + + abstract String getString(); + + abstract Object getObject(); + + abstract void doSomething(String s); + + abstract void doSomething(Object o); + } + """) + .addOutputLines( + "Test.java", + """ + import java.util.ArrayList; + import java.util.List; + import java.util.function.Supplier; + import static org.junit.Assert.assertThrows; + + abstract class Test { + void f(String s, Object o) { + String s2 = getString(); + assertThrows(IllegalStateException.class, () -> doSomething(s2)); + assertThrows(IllegalStateException.class, () -> doSomething((String) s)); + String s3 = (String) o; + assertThrows(IllegalStateException.class, () -> doSomething(s3)); + assertThrows(IllegalStateException.class, () -> doSomething((Object) s)); + assertThrows(IllegalStateException.class, () -> doSomething((Object) o)); + assertThrows(IllegalStateException.class, () -> doSomething((String) null)); + } + + abstract String getString(); + + abstract Object getObject(); + + abstract void doSomething(String s); + + abstract void doSomething(Object o); + } + """) + .doTest(TEXT_MATCH); + } }