Introduce MethodReferenceUsageCheck

This commit is contained in:
Stephan Schroevers
2017-12-16 14:02:54 +01:00
parent e0fb686bf0
commit dadb21e506
2 changed files with 507 additions and 0 deletions

View File

@@ -0,0 +1,197 @@
package com.picnicinternational.errorprone.bugpatterns;
import static com.google.common.collect.ImmutableList.toImmutableList;
import com.google.auto.service.AutoService;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.BugPattern;
import com.google.errorprone.BugPattern.LinkType;
import com.google.errorprone.BugPattern.ProvidesFix;
import com.google.errorprone.BugPattern.SeverityLevel;
import com.google.errorprone.BugPattern.StandardTags;
import com.google.errorprone.VisitorState;
import com.google.errorprone.bugpatterns.BugChecker;
import com.google.errorprone.bugpatterns.BugChecker.LambdaExpressionTreeMatcher;
import com.google.errorprone.fixes.SuggestedFix;
import com.google.errorprone.fixes.SuggestedFixes;
import com.google.errorprone.matchers.Description;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.BlockTree;
import com.sun.source.tree.ExpressionStatementTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.IdentifierTree;
import com.sun.source.tree.LambdaExpressionTree;
import com.sun.source.tree.MemberSelectTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.ParenthesizedTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.Tree.Kind;
import com.sun.source.tree.VariableTree;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import java.util.List;
import java.util.Optional;
import javax.lang.model.element.Name;
// XXX: Other custom expressions we could rewrite:
// - `a -> "str" + a` to `"str"::concat`. But only if `str` is provably non-null.
// - `(a, b) -> a + b` to `String::concat` or `{Integer,Long,Float,Double}::sum`. Also requires null
// checking.
// - `i -> new int[i]` to `int[]::new`.
// - `() -> new Foo()` to `Foo::new` (and variations).
// XXX: For JDK 9, add support for "var handles"!
@AutoService(BugChecker.class)
@BugPattern(
name = "MethodReferenceUsage",
summary = "Prefer method references over lambda expressions",
linkType = LinkType.NONE,
severity = SeverityLevel.SUGGESTION,
tags = StandardTags.STYLE,
providesFix = ProvidesFix.REQUIRES_HUMAN_ATTENTION
)
public final class MethodReferenceUsageCheck extends BugChecker
implements LambdaExpressionTreeMatcher {
private static final long serialVersionUID = 1L;
@Override
public Description matchLambdaExpression(LambdaExpressionTree tree, VisitorState state) {
/*
* Lambda expressions can be used in several places where method references cannot, either
* because the latter are not syntactically valid or ambiguous. Rather than encoding all these
* edge cases we try to compile the code with the suggested fix, to see whether this works.
*/
return constructMethodRef(tree, tree.getBody())
.map(SuggestedFix.Builder::build)
.filter(fix -> SuggestedFixes.compilesWithFix(fix, state))
.map(fix -> describeMatch(tree, fix))
.orElse(Description.NO_MATCH);
}
private static Optional<SuggestedFix.Builder> constructMethodRef(
LambdaExpressionTree lambdaExpr, Tree subTree) {
switch (subTree.getKind()) {
case BLOCK:
return constructMethodRef(lambdaExpr, (BlockTree) subTree);
case EXPRESSION_STATEMENT:
return constructMethodRef(lambdaExpr, ((ExpressionStatementTree) subTree).getExpression());
case METHOD_INVOCATION:
return constructMethodRef(lambdaExpr, (MethodInvocationTree) subTree);
case PARENTHESIZED:
return constructMethodRef(lambdaExpr, ((ParenthesizedTree) subTree).getExpression());
case RETURN:
return constructMethodRef(lambdaExpr, ((ReturnTree) subTree).getExpression());
default:
return Optional.empty();
}
}
private static Optional<SuggestedFix.Builder> constructMethodRef(
LambdaExpressionTree lambdaExpr, BlockTree subTree) {
return Optional.of(subTree.getStatements())
.filter(statements -> statements.size() == 1)
.flatMap(statements -> constructMethodRef(lambdaExpr, statements.get(0)));
}
private static Optional<SuggestedFix.Builder> constructMethodRef(
LambdaExpressionTree lambdaExpr, MethodInvocationTree subTree) {
return matchArguments(lambdaExpr, subTree)
.flatMap(expectedInstance -> constructMethodRef(lambdaExpr, subTree, expectedInstance));
}
private static Optional<SuggestedFix.Builder> constructMethodRef(
LambdaExpressionTree lambdaExpr,
MethodInvocationTree subTree,
Optional<Name> expectedInstance) {
ExpressionTree methodSelect = subTree.getMethodSelect();
switch (methodSelect.getKind()) {
case IDENTIFIER:
if (expectedInstance.isPresent()) {
/* Direct method call; there is no matching "implicit parameter". */
return Optional.empty();
}
Symbol sym = ASTHelpers.getSymbol(methodSelect);
if (!sym.isStatic()) {
return constructFix(lambdaExpr, "this", methodSelect);
}
return constructFix(lambdaExpr, sym.owner, methodSelect);
case MEMBER_SELECT:
return constructMethodRef(lambdaExpr, (MemberSelectTree) methodSelect, expectedInstance);
default:
throw new VerifyException("Unexpected type of expression: " + methodSelect.getKind());
}
}
private static Optional<SuggestedFix.Builder> constructMethodRef(
LambdaExpressionTree lambdaExpr, MemberSelectTree subTree, Optional<Name> expectedInstance) {
if (subTree.getExpression().getKind() != Kind.IDENTIFIER) {
// XXX: Could be parenthesized. Handle. Also in other classes.
/*
* Only suggest a replacement if the method select's expression provably doesn't have
* side-effects. Otherwise the replacement may not be behavior preserving.
*/
return Optional.empty();
}
Name lhs = ((IdentifierTree) subTree.getExpression()).getName();
if (!expectedInstance.isPresent()) {
return constructFix(lambdaExpr, lhs, subTree.getIdentifier());
}
Type lhsType = ASTHelpers.getType(subTree.getExpression());
if (lhsType == null || !expectedInstance.get().equals(lhs)) {
return Optional.empty();
}
// XXX: Dropping generic type information is in most cases fine or even more likely to
// yield a valid expression, but in some cases it's necessary to keep them.
// Maybe return multiple variants?
return constructFix(lambdaExpr, lhsType.tsym, subTree.getIdentifier());
}
private static Optional<Optional<Name>> matchArguments(
LambdaExpressionTree lambdaExpr, MethodInvocationTree subTree) {
ImmutableList<Name> expectedArguments = getVariables(lambdaExpr);
List<? extends ExpressionTree> args = subTree.getArguments();
int diff = expectedArguments.size() - args.size();
if (diff < 0 || diff > 1) {
return Optional.empty();
}
for (int i = 0; i < args.size(); i++) {
ExpressionTree arg = args.get(i);
if (arg.getKind() != Kind.IDENTIFIER
|| !((IdentifierTree) arg).getName().equals(expectedArguments.get(i + diff))) {
return Optional.empty();
}
}
return Optional.of(diff == 0 ? Optional.empty() : Optional.of(expectedArguments.get(0)));
}
private static ImmutableList<Name> getVariables(LambdaExpressionTree tree) {
return tree.getParameters().stream().map(VariableTree::getName).collect(toImmutableList());
}
private static Optional<SuggestedFix.Builder> constructFix(
LambdaExpressionTree lambdaExpr, Symbol target, Object methodName) {
Name sName = target.getSimpleName();
Optional<SuggestedFix.Builder> fix = constructFix(lambdaExpr, sName, methodName);
if (!"java.lang".equals(target.packge().toString())) {
Name fqName = target.getQualifiedName();
if (!sName.equals(fqName)) {
return fix.map(b -> b.addImport(fqName.toString()));
}
}
return fix;
}
private static Optional<SuggestedFix.Builder> constructFix(
LambdaExpressionTree lambdaExpr, Object target, Object methodName) {
return Optional.of(SuggestedFix.builder().replace(lambdaExpr, target + "::" + methodName));
}
}

View File

@@ -0,0 +1,310 @@
package com.picnicinternational.errorprone.bugpatterns;
import com.google.errorprone.BugCheckerRefactoringTestHelper;
import com.google.errorprone.BugCheckerRefactoringTestHelper.TestMode;
import com.google.errorprone.CompilationTestHelper;
import java.io.IOException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class MethodReferenceUsageCheckTest {
private final CompilationTestHelper compilationTestHelper =
CompilationTestHelper.newInstance(MethodReferenceUsageCheck.class, getClass());
private final BugCheckerRefactoringTestHelper refactoringTestHelper =
BugCheckerRefactoringTestHelper.newInstance(new MethodReferenceUsageCheck(), getClass());
@Test
public void testIdentification() {
compilationTestHelper
.addSourceLines(
"A.java",
"import com.google.common.collect.Streams;",
"import java.util.Map;",
"import java.util.HashMap;",
"import java.util.stream.Stream;",
"import java.util.function.IntConsumer;",
"import java.util.function.IntFunction;",
"",
"class A {",
" private final Stream<Integer> s = Stream.of(1);",
" private final Map<Integer, Integer> m = new HashMap<>();",
" private final Runnable thrower = () -> { throw new RuntimeException(); };",
"",
" void unaryExternalStaticFunctionCalls() {",
" s.forEach(String::valueOf);",
" // BUG: Diagnostic contains:",
" s.forEach(v -> String.valueOf(v));",
" // BUG: Diagnostic contains:",
" s.forEach((v) -> { String.valueOf(v); });",
" // BUG: Diagnostic contains:",
" s.forEach((Integer v) -> { { String.valueOf(v); } });",
" s.forEach(v -> { String.valueOf(v); String.valueOf(v); });",
"",
" s.map(String::valueOf);",
" // BUG: Diagnostic contains:",
" s.map(v -> String.valueOf(v));",
" // BUG: Diagnostic contains:",
" s.map((v) -> (String.valueOf(v)));",
" // BUG: Diagnostic contains:",
" s.map((Integer v) -> { return String.valueOf(v); });",
" // BUG: Diagnostic contains:",
" s.map((final Integer v) -> { return (String.valueOf(v)); });",
" s.map(v -> { String.valueOf(v); return String.valueOf(v); });",
"",
" s.findFirst().orElseGet(() -> Integer.valueOf(\"0\"));",
" m.forEach((k, v) -> String.valueOf(v));",
" m.forEach((k, v) -> String.valueOf(k));",
" }",
"",
" void binaryExternalInstanceFunctionCalls() {",
" m.forEach(m::put);",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> m.put(k, v));",
" m.forEach((k, v) -> m.put(v, k));",
" // BUG: Diagnostic contains:",
" m.forEach((Integer k, Integer v) -> { m.put(k, v); });",
" m.forEach((k, v) -> { m.put(k, k); });",
" // BUG: Diagnostic contains:",
" m.forEach((final Integer k, final Integer v) -> { { m.put(k, v); } });",
" m.forEach((k, v) -> { { m.put(v, v); } });",
" m.forEach((k, v) -> new HashMap<Integer, Integer>().put(k, v));",
" m.forEach((k, v) -> { m.put(k, v); m.put(k, v); });",
"",
" Streams.zip(s, s, m::put);",
" // BUG: Diagnostic contains:",
" Streams.zip(s, s, (a, b) -> m.put(a, b));",
" Streams.zip(s, s, (a, b) -> m.put(b, a));",
" // BUG: Diagnostic contains:",
" Streams.zip(s, s, (Integer a, Integer b) -> (m.put(a, b)));",
" Streams.zip(s, s, (a, b) -> (m.put(a, a)));",
" // BUG: Diagnostic contains:",
" Streams.zip(s, s, (final Integer a, final Integer b) -> { return m.put(a, b); });",
" Streams.zip(s, s, (a, b) -> { return m.put(b, b); });",
" // BUG: Diagnostic contains:",
" Streams.zip(s, s, (a, b) -> { return (m.put(a, b)); });",
" Streams.zip(s, s, (a, b) -> { return (m.put(b, a)); });",
" Streams.zip(s, s, (a, b) -> { m.put(a, b); return m.put(a, b); });",
" }",
"",
" void nullaryExternalInstanceFunctionCalls() {",
" s.map(Integer::doubleValue);",
" // BUG: Diagnostic contains:",
" s.map(i -> i.doubleValue());",
// `s.map(Integer::toString)` is ambiguous
" s.map(i -> i.toString());",
" s.map(i -> s.toString());",
"",
" // BUG: Diagnostic contains:",
" Stream.of(int.class).filter(c -> c.isEnum());",
" Stream.of((Class<?>) int.class).filter(Class::isEnum);",
" // BUG: Diagnostic contains:",
" Stream.of((Class<?>) int.class).filter(c -> c.isEnum());",
" }",
"",
" void localFunctionCalls() {",
" s.forEach(v -> ivoid0());",
" s.forEach(v -> iint0());",
" s.forEach(v -> svoid0());",
" s.forEach(v -> sint0());",
"",
" s.forEach(this::ivoid1);",
" // BUG: Diagnostic contains:",
" s.forEach(v -> ivoid1(v));",
" // BUG: Diagnostic contains:",
" s.forEach(v -> { ivoid1(v); });",
" s.forEach(this::iint1);",
" // BUG: Diagnostic contains:",
" s.forEach(v -> iint1(v));",
" // BUG: Diagnostic contains:",
" s.forEach(v -> { iint1(v); });",
"",
" s.forEach(A::svoid1);",
" // BUG: Diagnostic contains:",
" s.forEach(v -> svoid1(v));",
" // BUG: Diagnostic contains:",
" s.forEach(v -> { svoid1(v); });",
" s.forEach(A::sint1);",
" // BUG: Diagnostic contains:",
" s.forEach(v -> sint1(v));",
" // BUG: Diagnostic contains:",
" s.forEach(v -> { sint1(v); });",
"",
" s.forEach(v -> ivoid2(v, v));",
" s.forEach(v -> iint2(v, v));",
" s.forEach(v -> svoid2(v, v));",
" s.forEach(v -> sint2(v, v));",
"",
" m.forEach((k, v) -> ivoid0());",
" m.forEach((k, v) -> iint0());",
" m.forEach((k, v) -> svoid0());",
" m.forEach((k, v) -> sint0());",
"",
" m.forEach(this::ivoid2);",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> ivoid2(k, v));",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> { ivoid2(k, v); });",
" m.forEach(this::iint2);",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> iint2(k, v));",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> { iint2(k, v); });",
"",
" m.forEach(A::svoid2);",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> svoid2(k, v));",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> { svoid2(k, v); });",
" m.forEach(A::sint2);",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> sint2(k, v));",
" // BUG: Diagnostic contains:",
" m.forEach((k, v) -> { sint2(k, v); });",
" }",
"",
" void functionCallsWhoseReplacementWouldBeAmbiguous() {",
" receiver(i -> { Integer.toString(i); });",
" }",
"",
" void assortedOtherEdgeCases() {",
" s.forEach(v -> String.valueOf(v.toString()));",
" TernaryOp o1 = (a, b, c) -> String.valueOf(a);",
" TernaryOp o2 = (a, b, c) -> String.valueOf(b);",
" TernaryOp o3 = (a, b, c) -> String.valueOf(c);",
" TernaryOp o4 = (a, b, c) -> c.concat(a);",
" TernaryOp o5 = (a, b, c) -> c.concat(b);",
" TernaryOp o6 = (a, b, c) -> a.concat(c);",
" TernaryOp o7 = (a, b, c) -> b.concat(c);",
" }",
"",
" void receiver(IntFunction<?> op) { }",
" void receiver(IntConsumer op) { }",
"",
" void ivoid0() { }",
" void ivoid1(int a) { }",
" void ivoid2(int a, int b) { }",
" int iint0() { return 0; }",
" int iint1(int a) { return 0; }",
" int iint2(int a, int b) { return 0; }",
"",
" static void svoid0() { }",
" static void svoid1(int a) { }",
" static void svoid2(int a, int b) { }",
" static void svoid3(int a, int b, int c) { }",
" static int sint0() { return 0; }",
" static int sint1(int a) { return 0; }",
" static int sint2(int a, int b) { return 0; }",
"",
" interface TernaryOp {",
" String collect(String a, String b, String c);",
" }",
"}")
.doTest();
}
@Test
public void testReplacement() throws IOException {
refactoringTestHelper
.addInputLines(
"in/A.java",
"import static java.util.Collections.emptyList;",
"",
"import java.util.Collections;",
"import java.util.List;",
"import java.util.Map;",
// Don't import `java.util.Set`; it should be added.
"import java.util.function.IntConsumer;",
"import java.util.function.IntFunction;",
"import java.util.function.IntSupplier;",
"import java.util.function.Supplier;",
"import java.util.stream.Stream;",
"",
"class A {",
" static class B extends A {",
" final A a = new B();",
" final B b = new B();",
"",
" IntSupplier intSup;",
" Supplier<List<?>> listSup;",
"",
" void m() {",
" intSup = () -> a.iint0();",
" intSup = () -> b.iint0();",
" intSup = () -> this.iint0();",
" intSup = () -> super.iint0();",
"",
" intSup = () -> a.sint0();",
" intSup = () -> b.sint0();",
" intSup = () -> this.sint0();",
" intSup = () -> super.sint0();",
" intSup = () -> A.sint0();",
" intSup = () -> B.sint0();",
"",
" listSup = () -> Collections.emptyList();",
" listSup = () -> emptyList();",
"",
" Stream.of((Class<?>) int.class).filter(c -> c.isEnum());",
" Stream.of((Map<?, ?>) null).map(Map::keySet).map(s -> s.size());",
" }",
"",
" @Override int iint0() { return 0; }",
" }",
"",
" int iint0() { return 0; }",
"",
" static int sint0() { return 0; }",
"}")
.addOutputLines(
"out/A.java",
"import static java.util.Collections.emptyList;",
"",
"import java.util.Collections;",
"import java.util.List;",
"import java.util.Map;",
"import java.util.Set;",
"import java.util.function.IntConsumer;",
"import java.util.function.IntFunction;",
"import java.util.function.IntSupplier;",
"import java.util.function.Supplier;",
"import java.util.stream.Stream;",
"",
"class A {",
" static class B extends A {",
" final A a = new B();",
" final B b = new B();",
"",
" IntSupplier intSup;",
" Supplier<List<?>> listSup;",
"",
" void m() {",
" intSup = a::iint0;",
" intSup = b::iint0;",
" intSup = this::iint0;",
" intSup = super::iint0;",
"",
" intSup = () -> a.sint0();",
" intSup = () -> b.sint0();",
" intSup = () -> this.sint0();",
" intSup = () -> super.sint0();",
" intSup = A::sint0;",
" intSup = B::sint0;",
"",
" listSup = Collections::emptyList;",
" listSup = Collections::emptyList;",
"",
" Stream.of((Class<?>) int.class).filter(Class::isEnum);",
" Stream.of((Map<?, ?>) null).map(Map::keySet).map(Set::size);",
" }",
"",
" @Override int iint0() { return 0; }",
" }",
"",
" int iint0() { return 0; }",
"",
" static int sint0() { return 0; }",
"}")
.doTest(TestMode.TEXT_MATCH);
}
}