Commit 3ad69e52 authored by Sapphie's avatar Sapphie
Browse files

Implement and refactor

parent 95e9f764
...@@ -104,8 +104,9 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -104,8 +104,9 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
case LetP(name, prim, args, body) => { case LetP(name, prim, args, body) => {
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val newState = s.withExp(name, prim, replacedArgs) lazy val newState = s.withExp(name, prim, replacedArgs)
lazy val shrunkBody = shrink(body, newState) lazy val shrunkBody = shrink(body, newState)
lazy val noOp = LetP(name, prim, replacedArgs, shrunkBody)
// Dead code elim // Dead code elim
if (s.dead(name) && !impure(prim)) { if (s.dead(name) && !impure(prim)) {
...@@ -114,24 +115,32 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -114,24 +115,32 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
val allLitOpt = replacedArgs.map(_.asLiteral) val allLitOpt = replacedArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined) val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get) lazy val asLit = allLitOpt.map(_.get)
// TODO: fix this. For some reason it doesn't work
if (false && s.eInvEnv.isDefinedAt((prim, replacedArgs))) { if (false && s.eInvEnv.isDefinedAt((prim, replacedArgs))) {
val preComputedAtom = s.eInvEnv((prim, replacedArgs)) val preComputedAtom = s.eInvEnv((prim, replacedArgs))
shrink(body, newState.withASubst(name, preComputedAtom)) shrink(body, newState.withASubst(name, preComputedAtom))
} else { } else {
replacedArgs match {
// Constant folding // Constant folding
if (isAllLit && vEvaluator.isDefinedAt((prim, asLit))) { case Seq(AtomL(l1), AtomL(l2)) if vEvaluator.isDefinedAt((prim, asLit)) =>
val newNewState = newState.withASubst(name, vEvaluator((prim, asLit))) shrink(body, s.withASubst(name, vEvaluator((prim, asLit))))
val newBody = shrink(body, newNewState) // Same argument reduction
newBody case Seq(a1, a2) if a1 == a2 && sameArgReduce.isDefinedAt(prim, a1) =>
} else { val reduced = sameArgReduce((prim, a1))
lazy val x = replacedArgs(0) shrink(body, s.withASubst(name, sameArgReduce((prim, reduced))))
lazy val y = replacedArgs(1) // Left Neutral
if (replacedArgs.length == 2 && x == y && sameArgReduce.isDefinedAt((prim, x))) { case Seq(AtomL(l1), a2) if leftNeutral((l1, prim)) =>
val newNewState = newState.withASubst(name, sameArgReduce((prim, x))) shrink(body, s.withASubst(name, a2))
shrink(body, newNewState) // Left Absorbing
} else { case Seq(AtomL(l1), a2) if leftAbsorbing((l1, prim)) =>
LetP(name, prim, replacedArgs, shrunkBody) shrink(body, s.withASubst(name, l1))
} // Right Neutral
case Seq(a1, AtomL(l2)) if rightNeutral((prim, l2)) =>
shrink(body, s.withASubst(name, a1))
// Right Absorbing
case Seq(a1, AtomL(l2)) if rightAbsorbing((prim, l2)) =>
shrink(body, s.withASubst(name, l2))
case _ => noOp
} }
} }
} }
...@@ -157,7 +166,26 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -157,7 +166,26 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
} }
case If(cond, args, thenC, elseC) => case If(cond, args, thenC, elseC) =>
If(cond, replaceArgs(args, s), replaceCnt(thenC, s), replaceCnt(elseC, s)) val newArgs = replaceArgs(args, s)
def getApp(b: Boolean): AppC = {
val contToUse = if (b) thenC else elseC
AppC(contToUse, newArgs)
}
val allLitOpt = newArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get)
// Constant folding
if (isAllLit && cEvaluator.isDefinedAt((cond, asLit))) {
getApp(cEvaluator((cond, asLit)))
} else {
lazy val x = newArgs(0)
lazy val y = newArgs(1)
if (newArgs.length == 2 && x == y) {
getApp(sameArgReduceC(cond))
} else {
If(cond, newArgs, replaceCnt(thenC, s), replaceCnt(elseC, s))
}
}
case Halt(a) => case Halt(a) =>
Halt(s.aSubst.getOrElse(a, a)) Halt(s.aSubst.getOrElse(a, a))
......
...@@ -34,6 +34,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) { ...@@ -34,6 +34,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case H.AtomL(CharLit(c)) => L.Halt(L.AtomL(c.toInt)) case H.AtomL(CharLit(c)) => L.Halt(L.AtomL(c.toInt))
case H.AtomL(BooleanLit(b)) => L.Halt(L.AtomL(if (b) 1 else 0)) case H.AtomL(BooleanLit(b)) => L.Halt(L.AtomL(if (b) 1 else 0))
case v1 @ H.AtomN(_) => case v1 @ H.AtomN(_) =>
return L.Halt(rewrite(v))
val haltContName = Symbol.fresh("c-halt") val haltContName = Symbol.fresh("c-halt")
val haltContArgs = Seq(Symbol.fresh("halt_arg")) val haltContArgs = Seq(Symbol.fresh("halt_arg"))
val haltContBody = L.Halt(L.AtomN(haltContArgs(0))) val haltContBody = L.Halt(L.AtomN(haltContArgs(0)))
...@@ -81,27 +82,27 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) { ...@@ -81,27 +82,27 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
val untagUnitCont = mkeUntagCont("unit_untag", 2) val untagUnitCont = mkeUntagCont("unit_untag", 2)
// If it's a unit, untag it, otherwise, immediately skip to halt // If it's a unit, untag it, otherwise, immediately skip to halt
// val unitCheckCont = mkeCheckCont("unit_check", val unitCheckCont = mkeCheckCont("unit_check",
// transformIf(L3.UnitP, Seq(v), untagUnitCont.name, haltCont)) transformIf(L3.UnitP, Seq(v), untagUnitCont.name, haltCont))
// If it's a boolean, untag it, otherwise check if it's a unit // If it's a boolean, untag it, otherwise check if it's a unit
// val boolCheckCont = mkeCheckCont("bool_check", val boolCheckCont = mkeCheckCont("bool_check",
// transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name)) transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name))
// if it's a character, untag it, otherwise check if it's a boolean // if it's a character, untag it, otherwise check if it's a boolean
//val charCheckCont = mkeCheckCont("char_check", val charCheckCont = mkeCheckCont("char_check",
// transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name)) transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name))
//etc //etc
val letCBody = transformIf(L3.IntP, Seq(v), untagIntCont.name, haltCont) val letCBody = transformIf(L3.IntP, Seq(v), untagIntCont.name, haltCont)
val conts = Seq(untagIntCont, val conts = Seq(untagIntCont,
// untagCharCont, untagCharCont,
// untagBoolCont, untagBoolCont,
// untagUnitCont, untagUnitCont,
// unitCheckCont, unitCheckCont,
// charCheckCont, charCheckCont,
// boolCheckCont boolCheckCont
) )
L.LetC(conts, letCBody) L.LetC(conts, letCBody)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment