Commit 4dc52f1b authored by Luca Bataillard's avatar Luca Bataillard
Browse files

fix cpshighoptimizer (tests pass)

parent cab13317
...@@ -55,74 +55,71 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -55,74 +55,71 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
private def shrink(tree: Tree): Tree = private def shrink(tree: Tree): Tree =
shrink(tree, State(census(tree))) shrink(tree, State(census(tree)))
private def replaceArgs(args: Seq[Atom], s: State): Seq[Atom] = args map { a => private def replaceArgs(args: Seq[Atom], s: State): Seq[Atom] =
s.aSubst.getOrElse(a, a) args map { a => s.aSubst.getOrElse(a, a) }
}
private def shrink(tree: Tree, s: State): Tree = tree match { private def shrink(tree: Tree, s: State): Tree = tree match {
case AppC(cntName, args) => case AppC(oldCntName, args) => {
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
if (s.appliedOnce(cntName)) { val cntName = s.cSubst.getOrElse(oldCntName, oldCntName)
// Inline
if (s.appliedOnce(cntName) && s.cEnv.contains(cntName)) {
val cnt = s.cEnv(cntName) val cnt = s.cEnv(cntName)
val newState = s.withASubst(cnt.args, replacedArgs) val newState = s.withASubst(cnt.args, replacedArgs)
shrink(cnt.body, newState) shrink(cnt.body, newState)
} else { } else {
AppC(cntName, replacedArgs) AppC(cntName, replacedArgs)
} }
case AppF(funAtom, retC, args) => }
case AppF(oldFunAtom, oldRetC, args) => {
val funAtom = s.aSubst.getOrElse(oldFunAtom, oldFunAtom)
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val retC = s.cSubst.getOrElse(oldRetC, oldRetC)
funAtom match { funAtom match {
case AtomN(n) if s.fEnv.contains(n) && s.appliedOnce(n) => case AtomN(n) if s.fEnv.contains(n) && s.appliedOnce(n) =>
// Inline
val fun = s.fEnv(n) val fun = s.fEnv(n)
val newState = s.withASubst(fun.args, replacedArgs) val newState = s.withASubst(fun.args, replacedArgs).withCSubst(fun.retC, retC)
.withCSubst(fun.retC, retC) shrink(fun.body, newState)
val newBody = shrink(fun.body, newState)
newBody
case _ => AppF(funAtom, retC, replacedArgs) case _ => AppF(funAtom, retC, replacedArgs)
} }
}
case LetF(funs, body) => case LetF(funs, body) => {
val undeadFuns = funs.filter(fun => !s.dead(fun.name)) val undeadFuns = funs.filterNot(f => s.dead(f.name))
val undeadShrunkFuns = undeadFuns.map { fun => val nonInlined = undeadFuns.filter(fun => !s.appliedOnce(fun.name))
Fun(fun.name, fun.retC, fun.args, shrink(fun.body, s))} val newState = s.withFuns(undeadFuns)
val nonInlined = undeadShrunkFuns
.filter(fun => !s.appliedOnce(fun.name)) val shrunkFuns = nonInlined.map(f => Fun(f.name, f.retC, f.args, shrink(f.body, newState)))
val shrunkBody = shrink(body, newState)
val newState = s.withFuns(undeadShrunkFuns)
val newBody = shrink(body, newState) if (shrunkFuns.isEmpty) shrunkBody
if (nonInlined.isEmpty) { else LetF(shrunkFuns, shrunkBody)
newBody
} else {
LetF(nonInlined, newBody)
} }
case LetP(name, prim, args, body) => case LetP(name, prim, args, body) => {
if (s.dead(name) && !impure(prim)) {
shrink(body, s)
} else {
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val primArgPair = (prim, replacedArgs) val newState = s.withExp(name, prim, replacedArgs)
if (!unstable(prim) && !impure(prim) && s.eInvEnv.contains(primArgPair)) { val shrunkBody = shrink(body, newState)
val newS = s.withASubst(name, s.eInvEnv(primArgPair))
val newBody = shrink(body, newS) if (s.dead(name) && !impure(prim)) {
newBody shrunkBody
} else { } else {
val newS = s.withExp(name, prim, replacedArgs) LetP(name, prim, replacedArgs, shrunkBody)
LetP(name, prim, replacedArgs, shrink(body, newS))
} }
} }
case LetC(cnts, body) => case LetC(cnts, body) => {
val undeadConts = cnts.filter(cnt => !s.dead(cnt.name)) val undeadConts = cnts.filterNot(cnt => s.dead(cnt.name))
val undeadShrunkConts = undeadConts.map { cnt => val undeadShrunkConts = undeadConts.map { cnt =>
val newBody = shrink(cnt.body, s) val newBody = shrink(cnt.body, s)
Cnt(cnt.name, cnt.args, newBody) Cnt(cnt.name, cnt.args, newBody)
} }
val nonInlinedConts = undeadShrunkConts.filter { cnt => val nonInlinedConts = undeadShrunkConts.filterNot { cnt =>
!s.appliedOnce(cnt.name) s.appliedOnce(cnt.name)
} }
val newBody = shrink(body, s.withCnts(undeadShrunkConts)) val newBody = shrink(body, s.withCnts(undeadShrunkConts))
...@@ -131,6 +128,10 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -131,6 +128,10 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
} else { } else {
LetC(nonInlinedConts, newBody) LetC(nonInlinedConts, newBody)
} }
}
case If(cond, args, thenC, elseC) =>
If(cond, replaceArgs(args, s), thenC, elseC)
case Halt(a) => case Halt(a) =>
Halt(s.aSubst.getOrElse(a, a)) Halt(s.aSubst.getOrElse(a, a))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment