Skip to content
Snippets Groups Projects
Commit 3ad69e52 authored by Sapphie's avatar Sapphie
Browse files

Implement and refactor

parent 95e9f764
No related branches found
No related tags found
No related merge requests found
...@@ -104,8 +104,9 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -104,8 +104,9 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
case LetP(name, prim, args, body) => { case LetP(name, prim, args, body) => {
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val newState = s.withExp(name, prim, replacedArgs) lazy val newState = s.withExp(name, prim, replacedArgs)
lazy val shrunkBody = shrink(body, newState) lazy val shrunkBody = shrink(body, newState)
lazy val noOp = LetP(name, prim, replacedArgs, shrunkBody)
// Dead code elim // Dead code elim
if (s.dead(name) && !impure(prim)) { if (s.dead(name) && !impure(prim)) {
...@@ -114,24 +115,32 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -114,24 +115,32 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
val allLitOpt = replacedArgs.map(_.asLiteral) val allLitOpt = replacedArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined) val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get) lazy val asLit = allLitOpt.map(_.get)
// TODO: fix this. For some reason it doesn't work
if (false && s.eInvEnv.isDefinedAt((prim, replacedArgs))) { if (false && s.eInvEnv.isDefinedAt((prim, replacedArgs))) {
val preComputedAtom = s.eInvEnv((prim, replacedArgs)) val preComputedAtom = s.eInvEnv((prim, replacedArgs))
shrink(body, newState.withASubst(name, preComputedAtom)) shrink(body, newState.withASubst(name, preComputedAtom))
} else { } else {
// Constant folding replacedArgs match {
if (isAllLit && vEvaluator.isDefinedAt((prim, asLit))) { // Constant folding
val newNewState = newState.withASubst(name, vEvaluator((prim, asLit))) case Seq(AtomL(l1), AtomL(l2)) if vEvaluator.isDefinedAt((prim, asLit)) =>
val newBody = shrink(body, newNewState) shrink(body, s.withASubst(name, vEvaluator((prim, asLit))))
newBody // Same argument reduction
} else { case Seq(a1, a2) if a1 == a2 && sameArgReduce.isDefinedAt(prim, a1) =>
lazy val x = replacedArgs(0) val reduced = sameArgReduce((prim, a1))
lazy val y = replacedArgs(1) shrink(body, s.withASubst(name, sameArgReduce((prim, reduced))))
if (replacedArgs.length == 2 && x == y && sameArgReduce.isDefinedAt((prim, x))) { // Left Neutral
val newNewState = newState.withASubst(name, sameArgReduce((prim, x))) case Seq(AtomL(l1), a2) if leftNeutral((l1, prim)) =>
shrink(body, newNewState) shrink(body, s.withASubst(name, a2))
} else { // Left Absorbing
LetP(name, prim, replacedArgs, shrunkBody) case Seq(AtomL(l1), a2) if leftAbsorbing((l1, prim)) =>
} shrink(body, s.withASubst(name, l1))
// Right Neutral
case Seq(a1, AtomL(l2)) if rightNeutral((prim, l2)) =>
shrink(body, s.withASubst(name, a1))
// Right Absorbing
case Seq(a1, AtomL(l2)) if rightAbsorbing((prim, l2)) =>
shrink(body, s.withASubst(name, l2))
case _ => noOp
} }
} }
} }
...@@ -157,7 +166,26 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -157,7 +166,26 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
} }
case If(cond, args, thenC, elseC) => case If(cond, args, thenC, elseC) =>
If(cond, replaceArgs(args, s), replaceCnt(thenC, s), replaceCnt(elseC, s)) val newArgs = replaceArgs(args, s)
def getApp(b: Boolean): AppC = {
val contToUse = if (b) thenC else elseC
AppC(contToUse, newArgs)
}
val allLitOpt = newArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get)
// Constant folding
if (isAllLit && cEvaluator.isDefinedAt((cond, asLit))) {
getApp(cEvaluator((cond, asLit)))
} else {
lazy val x = newArgs(0)
lazy val y = newArgs(1)
if (newArgs.length == 2 && x == y) {
getApp(sameArgReduceC(cond))
} else {
If(cond, newArgs, replaceCnt(thenC, s), replaceCnt(elseC, s))
}
}
case Halt(a) => case Halt(a) =>
Halt(s.aSubst.getOrElse(a, a)) Halt(s.aSubst.getOrElse(a, a))
......
...@@ -34,6 +34,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) { ...@@ -34,6 +34,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case H.AtomL(CharLit(c)) => L.Halt(L.AtomL(c.toInt)) case H.AtomL(CharLit(c)) => L.Halt(L.AtomL(c.toInt))
case H.AtomL(BooleanLit(b)) => L.Halt(L.AtomL(if (b) 1 else 0)) case H.AtomL(BooleanLit(b)) => L.Halt(L.AtomL(if (b) 1 else 0))
case v1 @ H.AtomN(_) => case v1 @ H.AtomN(_) =>
return L.Halt(rewrite(v))
val haltContName = Symbol.fresh("c-halt") val haltContName = Symbol.fresh("c-halt")
val haltContArgs = Seq(Symbol.fresh("halt_arg")) val haltContArgs = Seq(Symbol.fresh("halt_arg"))
val haltContBody = L.Halt(L.AtomN(haltContArgs(0))) val haltContBody = L.Halt(L.AtomN(haltContArgs(0)))
...@@ -81,27 +82,27 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) { ...@@ -81,27 +82,27 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
val untagUnitCont = mkeUntagCont("unit_untag", 2) val untagUnitCont = mkeUntagCont("unit_untag", 2)
// If it's a unit, untag it, otherwise, immediately skip to halt // If it's a unit, untag it, otherwise, immediately skip to halt
// val unitCheckCont = mkeCheckCont("unit_check", val unitCheckCont = mkeCheckCont("unit_check",
// transformIf(L3.UnitP, Seq(v), untagUnitCont.name, haltCont)) transformIf(L3.UnitP, Seq(v), untagUnitCont.name, haltCont))
// If it's a boolean, untag it, otherwise check if it's a unit // If it's a boolean, untag it, otherwise check if it's a unit
// val boolCheckCont = mkeCheckCont("bool_check", val boolCheckCont = mkeCheckCont("bool_check",
// transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name)) transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name))
// if it's a character, untag it, otherwise check if it's a boolean // if it's a character, untag it, otherwise check if it's a boolean
//val charCheckCont = mkeCheckCont("char_check", val charCheckCont = mkeCheckCont("char_check",
// transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name)) transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name))
//etc //etc
val letCBody = transformIf(L3.IntP, Seq(v), untagIntCont.name, haltCont) val letCBody = transformIf(L3.IntP, Seq(v), untagIntCont.name, haltCont)
val conts = Seq(untagIntCont, val conts = Seq(untagIntCont,
// untagCharCont, untagCharCont,
// untagBoolCont, untagBoolCont,
// untagUnitCont, untagUnitCont,
// unitCheckCont, unitCheckCont,
// charCheckCont, charCheckCont,
// boolCheckCont boolCheckCont
) )
L.LetC(conts, letCBody) L.LetC(conts, letCBody)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment