Merging subsequent .filter()'s

This commit is contained in:
Valentin Kipyatkov
2016-04-05 22:39:08 +03:00
parent 6dbd9c944a
commit fcbf68617e
11 changed files with 75 additions and 17 deletions

View File

@@ -30,6 +30,8 @@ interface Transformation {
}
interface SequenceTransformation : Transformation {
fun mergeWithPrevious(previousTransformation: SequenceTransformation): SequenceTransformation? = null
val affectsIndex: Boolean
fun generateCode(chainedCallGenerator: ChainedCallGenerator): KtExpression

View File

@@ -71,8 +71,9 @@ fun match(loop: KtForExpression): ResultTransformationMatch? {
val match = matcher.match(state)
if (match != null) {
sequenceTransformations.addAll(match.sequenceTransformations)
val resultMatch = ResultTransformationMatch(match.resultTransformation, sequenceTransformations)
return resultMatch.check { checkSmartCastsPreserved(loop, it) }
return ResultTransformationMatch(match.resultTransformation, sequenceTransformations)
.let { mergeTransformations(it) }
.check { checkSmartCastsPreserved(loop, it) }
}
}
@@ -182,3 +183,29 @@ private fun ResultTransformationMatch.generateCallChain(loop: KtForExpression):
return callChain
}
private fun mergeTransformations(match: ResultTransformationMatch): ResultTransformationMatch {
val transformations = ArrayList<Transformation>().apply { addAll(match.sequenceTransformations); add(match.resultTransformation) }
var anyChange: Boolean
do {
anyChange = false
for (index in 0..transformations.lastIndex - 1) {
val transformation = transformations[index] as SequenceTransformation
val next = transformations[index + 1]
val merged = when (next) {
is SequenceTransformation -> next.mergeWithPrevious(transformation)
is ResultTransformation -> next.mergeWithPrevious(transformation)
else -> error("Unknown transformation type: $next")
} ?: continue
transformations[index] = merged
transformations.removeAt(index + 1)
anyChange = true
break
}
} while(anyChange)
@Suppress("UNCHECKED_CAST")
return ResultTransformationMatch(transformations.last() as ResultTransformation, transformations.dropLast(1) as List<SequenceTransformation>)
}

View File

@@ -34,7 +34,7 @@ class FindAndAssignTransformation(
override fun mergeWithPrevious(previousTransformation: SequenceTransformation): ResultTransformation? {
if (previousTransformation !is FilterTransformation) return null
if (filter != null) return null //TODO
assert(filter == null) { "Should not happen because no 2 consecutive FilterTransformation's possible"}
return FindAndAssignTransformation(loop, previousTransformation.inputVariable, stdlibFunName, initialDeclaration, previousTransformation.buildRealCondition())
}

View File

@@ -34,7 +34,7 @@ class FindAndReturnTransformation(
override fun mergeWithPrevious(previousTransformation: SequenceTransformation): ResultTransformation? {
if (previousTransformation !is FilterTransformation) return null
if (filter != null) return null //TODO
assert(filter == null) { "Should not happen because no 2 consecutive FilterTransformation's possible"}
return FindAndReturnTransformation(loop, previousTransformation.inputVariable, stdlibFunName, endReturn, previousTransformation.buildRealCondition())
}

View File

@@ -29,12 +29,16 @@ class FilterTransformation(
val isInverse: Boolean
) : SequenceTransformation {
init {
assert(condition.isPhysical)
}
fun buildRealCondition() = if (isInverse) condition.negate() else condition
override fun mergeWithPrevious(previousTransformation: SequenceTransformation): SequenceTransformation? {
if (previousTransformation !is FilterTransformation) return null
assert(previousTransformation.inputVariable == inputVariable)
val mergedCondition = KtPsiFactory(condition).createExpressionByPattern(
"$0 && $1", previousTransformation.buildRealCondition(), buildRealCondition())
return FilterTransformation(inputVariable, mergedCondition, isInverse = false) //TODO: build filterNot in some cases?
}
override val affectsIndex: Boolean
get() = true
@@ -44,7 +48,6 @@ class FilterTransformation(
return chainedCallGenerator.generate("$0$1:'{}'", name, lambda)
}
//TODO: merge subsequent filters
/**
* Matches:
* for (...) {

View File

@@ -30,10 +30,6 @@ class FlatMapTransformation(
private val transform: KtExpression
) : SequenceTransformation {
init {
assert(transform.isPhysical)
}
override val affectsIndex: Boolean
get() = true

View File

@@ -26,10 +26,6 @@ class MapTransformation(
val mapping: KtExpression
) : SequenceTransformation {
init {
assert(mapping.isPhysical)
}
override val affectsIndex: Boolean
get() = false

View File

@@ -0,0 +1,12 @@
// WITH_RUNTIME
fun foo(list: List<String>): String? {
<caret>for (s in list) {
if (s.isEmpty()) continue
if (s.length < 10 && s != "abc") {
if (s == "def") continue
val s1 = s + "x"
return s1
}
}
return null
}

View File

@@ -0,0 +1,7 @@
// WITH_RUNTIME
fun foo(list: List<String>): String? {
return list
.filter { !it.isEmpty() && it.length < 10 && it != "abc" && it != "def" }
.map { it + "x" }
.firstOrNull()
}

View File

@@ -0,0 +1,11 @@
// WITH_RUNTIME
fun foo(list: List<String>): String? {
<caret>for (s in list) {
if (s.isEmpty()) continue
if (s.length < 10 && s != "abc") {
if (s == "def") continue
return s
}
}
return null
}

View File

@@ -0,0 +1,4 @@
// WITH_RUNTIME
fun foo(list: List<String>): String? {
<caret>return list.firstOrNull { !it.isEmpty() && it.length < 10 && it != "abc" && it != "def" }
}