Commit e3128c7f authored by Luca Bataillard's avatar Luca Bataillard
Browse files

implement non-shrinking inlining (all tests pass)

parent 4dc52f1b
...@@ -57,12 +57,15 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -57,12 +57,15 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
private def replaceArgs(args: Seq[Atom], s: State): Seq[Atom] = private def replaceArgs(args: Seq[Atom], s: State): Seq[Atom] =
args map { a => s.aSubst.getOrElse(a, a) } args map { a => s.aSubst.getOrElse(a, a) }
private def replaceCnt(cnt: Symbol, s: State): Symbol =
s.cSubst.getOrElse(cnt, cnt)
private def shrink(tree: Tree, s: State): Tree = tree match { private def shrink(tree: Tree, s: State): Tree = tree match {
case AppC(oldCntName, args) => { case AppC(oldCntName, args) => {
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val cntName = s.cSubst.getOrElse(oldCntName, oldCntName) val cntName = replaceCnt(oldCntName, s)
if (s.appliedOnce(cntName) && s.cEnv.contains(cntName)) { if (s.appliedOnce(cntName) && s.cEnv.contains(cntName)) {
val cnt = s.cEnv(cntName) val cnt = s.cEnv(cntName)
...@@ -76,7 +79,7 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -76,7 +79,7 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
case AppF(oldFunAtom, oldRetC, args) => { case AppF(oldFunAtom, oldRetC, args) => {
val funAtom = s.aSubst.getOrElse(oldFunAtom, oldFunAtom) val funAtom = s.aSubst.getOrElse(oldFunAtom, oldFunAtom)
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val retC = s.cSubst.getOrElse(oldRetC, oldRetC) val retC = replaceCnt(oldRetC, s)
funAtom match { funAtom match {
case AtomN(n) if s.fEnv.contains(n) && s.appliedOnce(n) => case AtomN(n) if s.fEnv.contains(n) && s.appliedOnce(n) =>
...@@ -131,7 +134,7 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -131,7 +134,7 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
} }
case If(cond, args, thenC, elseC) => case If(cond, args, thenC, elseC) =>
If(cond, replaceArgs(args, s), thenC, elseC) If(cond, replaceArgs(args, s), replaceCnt(thenC, s), replaceCnt(elseC, s))
case Halt(a) => case Halt(a) =>
Halt(s.aSubst.getOrElse(a, a)) Halt(s.aSubst.getOrElse(a, a))
...@@ -192,7 +195,63 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -192,7 +195,63 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
def sameLen[T,U](formalArgs: Seq[T], actualArgs: Seq[U]): Boolean = def sameLen[T,U](formalArgs: Seq[T], actualArgs: Seq[U]): Boolean =
formalArgs.length == actualArgs.length formalArgs.length == actualArgs.length
def inlineT(tree: Tree)(implicit s: State): Tree = ??? def inlineT(tree: Tree)(implicit s: State): Tree = tree match {
case AppC(oldCntName, oldArgs) => {
val cntName = replaceCnt(oldCntName, s)
val args = replaceArgs(oldArgs, s)
if (s.cEnv.contains(cntName)) {
val cnt = copyC(s.cEnv(cntName), s.aSubst, s.cSubst)
val newState = s.withASubst(cnt.args, args)
inlineT(cnt.body)(newState)
} else {
AppC(cntName, args)
}
}
case AppF(oldFunName, oldRetC, oldArgs) => {
val funName = s.aSubst.getOrElse(oldFunName, oldFunName)
val retC = replaceCnt(oldRetC, s)
val args = replaceArgs(oldArgs, s)
funName match {
case AtomN(n) if s.fEnv.contains(n) => {
val oldFun = s.fEnv(n)
/* TODO ask about stopping inlining of recursive functions */
if (!census(oldFun.body).contains(oldFun.name) && sameLen(oldFun.args, args)) {
val fun = copyF(oldFun, s.aSubst, s.cSubst)
val newState = s.withASubst(fun.args, args).withCSubst(fun.retC, retC)
inlineT(fun.body)(newState)
} else {
AppF(funName, retC, args)
}
}
case _ => AppF(funName, retC, args)
}
}
case Halt(arg) => Halt(s.aSubst.getOrElse(arg, arg))
case If(cond, args, thenC, elseC) => If(cond, replaceArgs(args, s), replaceCnt(thenC, s), replaceCnt(elseC, s))
case LetP(name, prim, args, body) => LetP(name, prim, replaceArgs(args, s), inlineT(body))
case LetC(oldCnts, body) => {
val cnts = oldCnts map (c => Cnt(c.name, c.args, inlineT(c.body)))
val inlinedCnts = cnts.filter(c => size(c.body) <= cntLimit)
val newState = s.withCnts(inlinedCnts)
LetC(cnts, inlineT(body)(newState))
}
case LetF(oldFuns, body) => {
val funs = oldFuns map (f => Fun(f.name, f.retC, f.args, inlineT(f.body)))
val inlinedFuns = funs.filter(f => size(f.body) <= funLimit)
val newState = s.withFuns(inlinedFuns)
LetF(funs, inlineT(body)(newState))
}
}
(i + 1, fixedPoint(inlineT(tree)(State(census(tree))))(shrink)) (i + 1, fixedPoint(inlineT(tree)(State(census(tree))))(shrink))
} }
......
...@@ -14,11 +14,17 @@ object Main { ...@@ -14,11 +14,17 @@ object Main {
val backEnd: Tree => TerminalPhaseResult = ( val backEnd: Tree => TerminalPhaseResult = (
CL3ToCPSTranslator CL3ToCPSTranslator
andThen treePrinter("---------- After CPS translation") andThen treePrinter("---------- After CPS translation")
andThen CPSOptimizerHigh
andThen treePrinter("---------- After Optimization (High)")
andThen treeChecker
andThen CPSValueRepresenter andThen CPSValueRepresenter
andThen treePrinter("---------- After value representation") andThen treePrinter("---------- After value representation")
andThen treeChecker andThen treeChecker
andThen treePrinter("---------- After hoisting")
andThen CPSHoister andThen CPSHoister
andThen treePrinter("---------- After hoisting")
andThen CPSOptimizerLow
andThen treePrinter("---------- After Optimization (Low)")
andThen treeChecker
andThen CPSInterpreterLow andThen CPSInterpreterLow
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment