Commit 79a9e1d0 authored by Luca Bataillard's avatar Luca Bataillard
Browse files

add identity inlining (one test fails)

parent e86031e1
...@@ -9,7 +9,7 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -9,7 +9,7 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
protected def rewrite(tree: Tree): Tree = { protected def rewrite(tree: Tree): Tree = {
val simplifiedTree = fixedPoint(tree)(shrink) val simplifiedTree = fixedPoint(tree)(shrink)
val maxSize = size(simplifiedTree) * 3 / 2 val maxSize = size(simplifiedTree) * 3 / 2
fixedPoint(simplifiedTree, 8) { t => inline(t, maxSize) } fixedPoint(simplifiedTree, 2) { t => inline(t, maxSize) }
} }
private case class Count(applied: Int = 0, asValue: Int = 0) private case class Count(applied: Int = 0, asValue: Int = 0)
...@@ -113,44 +113,38 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -113,44 +113,38 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
shrunkBody shrunkBody
} else { } else {
val allLitOpt = replacedArgs.map(_.asLiteral) val allLitOpt = replacedArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get) lazy val asLit = allLitOpt.map(_.get)
if (!unstable(prim) && !impure(prim) && s.eInvEnv.isDefinedAt((prim, replacedArgs))) {
if (!unstable(prim) && !impure(prim) && s.eInvEnv.contains((prim, replacedArgs))) {
val preComputedAtom = s.eInvEnv((prim, replacedArgs)) val preComputedAtom = s.eInvEnv((prim, replacedArgs))
shrink(body, newState.withASubst(name, preComputedAtom)) shrink(body, newState.withASubst(name, preComputedAtom))
} else { } else if (prim == identity) {
replacedArgs match { shrink(body, s.withASubst(name, replacedArgs(0)))
// Constant folding } else replacedArgs match {
case Seq(AtomL(l1), AtomL(l2)) if vEvaluator.isDefinedAt((prim, asLit)) => // Constant folding
shrink(body, s.withASubst(name, vEvaluator((prim, asLit)))) case Seq(AtomL(l1), AtomL(l2)) if vEvaluator.isDefinedAt((prim, asLit)) =>
// Same argument reduction shrink(body, s.withASubst(name, vEvaluator((prim, asLit))))
case Seq(a1, a2) if a1 == a2 && sameArgReduce.isDefinedAt(prim, a1) => // Same argument reduction
val reduced = sameArgReduce((prim, a1)) case Seq(a1, a2) if a1 == a2 && sameArgReduce.isDefinedAt(prim, a1) =>
shrink(body, s.withASubst(name, sameArgReduce((prim, reduced)))) val reduced = sameArgReduce((prim, a1))
// Left Neutral shrink(body, s.withASubst(name, sameArgReduce((prim, reduced))))
case Seq(AtomL(l1), a2) if leftNeutral((l1, prim)) => // Left Neutral
shrink(body, s.withASubst(name, a2)) case Seq(AtomL(l1), a2) if leftNeutral((l1, prim)) =>
// Left Absorbing shrink(body, s.withASubst(name, a2))
case Seq(AtomL(l1), a2) if leftAbsorbing((l1, prim)) => // Left Absorbing
shrink(body, s.withASubst(name, l1)) case Seq(AtomL(l1), a2) if leftAbsorbing((l1, prim)) =>
// Right Neutral shrink(body, s.withASubst(name, l1))
case Seq(a1, AtomL(l2)) if rightNeutral((prim, l2)) => // Right Neutral
shrink(body, s.withASubst(name, a1)) case Seq(a1, AtomL(l2)) if rightNeutral((prim, l2)) =>
// Right Absorbing shrink(body, s.withASubst(name, a1))
case Seq(a1, AtomL(l2)) if rightAbsorbing((prim, l2)) => // Right Absorbing
shrink(body, s.withASubst(name, l2)) case Seq(a1, AtomL(l2)) if rightAbsorbing((prim, l2)) =>
case _ => { shrink(body, s.withASubst(name, l2))
if (false && prim == identity) { case _ => noOp
shrink(body, s.withASubst(name, replacedArgs(0)))
} else {
noOp
}
}
}
} }
} }
} }
case LetC(cnts, body) => { case LetC(cnts, body) => {
val undeadConts = cnts.filterNot(cnt => s.dead(cnt.name)) val undeadConts = cnts.filterNot(cnt => s.dead(cnt.name))
val undeadShrunkConts = undeadConts.map { cnt => val undeadShrunkConts = undeadConts.map { cnt =>
...@@ -274,9 +268,11 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -274,9 +268,11 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
funName match { funName match {
case AtomN(n) if s.fEnv.contains(n) => { case AtomN(n) if s.fEnv.contains(n) => {
val oldFun = s.fEnv(n) val oldFun = s.fEnv(n)
val notUsedAsValue = s.census.get(oldFun.name).fold(false)(count => count.asValue == 0)
val nonRecursive = notUsedAsValue && !census(oldFun.body).contains(oldFun.name)
/* TODO ask about stopping inlining of recursive functions */ /* TODO ask about stopping inlining of recursive functions */
if (!census(oldFun.body).contains(oldFun.name) && sameLen(oldFun.args, args)) { if (nonRecursive && sameLen(oldFun.args, args)) {
val fun = copyF(oldFun, s.aSubst, s.cSubst) val fun = copyF(oldFun, s.aSubst, s.cSubst)
val newState = s.withASubst(fun.args, args).withCSubst(fun.retC, retC) val newState = s.withASubst(fun.args, args).withCSubst(fun.retC, retC)
inlineT(fun.body)(newState) inlineT(fun.body)(newState)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment