Skip to content
Snippets Groups Projects
Commit 3ad69e52 authored by Sapphie's avatar Sapphie
Browse files

Implement and refactor

parent 95e9f764
No related branches found
No related tags found
No related merge requests found
......@@ -104,8 +104,9 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
case LetP(name, prim, args, body) => {
val replacedArgs = replaceArgs(args, s)
val newState = s.withExp(name, prim, replacedArgs)
lazy val newState = s.withExp(name, prim, replacedArgs)
lazy val shrunkBody = shrink(body, newState)
lazy val noOp = LetP(name, prim, replacedArgs, shrunkBody)
// Dead code elim
if (s.dead(name) && !impure(prim)) {
......@@ -114,24 +115,32 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
val allLitOpt = replacedArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get)
// TODO: fix this. For some reason it doesn't work
if (false && s.eInvEnv.isDefinedAt((prim, replacedArgs))) {
val preComputedAtom = s.eInvEnv((prim, replacedArgs))
shrink(body, newState.withASubst(name, preComputedAtom))
} else {
// Constant folding
if (isAllLit && vEvaluator.isDefinedAt((prim, asLit))) {
val newNewState = newState.withASubst(name, vEvaluator((prim, asLit)))
val newBody = shrink(body, newNewState)
newBody
} else {
lazy val x = replacedArgs(0)
lazy val y = replacedArgs(1)
if (replacedArgs.length == 2 && x == y && sameArgReduce.isDefinedAt((prim, x))) {
val newNewState = newState.withASubst(name, sameArgReduce((prim, x)))
shrink(body, newNewState)
} else {
LetP(name, prim, replacedArgs, shrunkBody)
}
replacedArgs match {
// Constant folding
case Seq(AtomL(l1), AtomL(l2)) if vEvaluator.isDefinedAt((prim, asLit)) =>
shrink(body, s.withASubst(name, vEvaluator((prim, asLit))))
// Same argument reduction
case Seq(a1, a2) if a1 == a2 && sameArgReduce.isDefinedAt(prim, a1) =>
val reduced = sameArgReduce((prim, a1))
shrink(body, s.withASubst(name, sameArgReduce((prim, reduced))))
// Left Neutral
case Seq(AtomL(l1), a2) if leftNeutral((l1, prim)) =>
shrink(body, s.withASubst(name, a2))
// Left Absorbing
case Seq(AtomL(l1), a2) if leftAbsorbing((l1, prim)) =>
shrink(body, s.withASubst(name, l1))
// Right Neutral
case Seq(a1, AtomL(l2)) if rightNeutral((prim, l2)) =>
shrink(body, s.withASubst(name, a1))
// Right Absorbing
case Seq(a1, AtomL(l2)) if rightAbsorbing((prim, l2)) =>
shrink(body, s.withASubst(name, l2))
case _ => noOp
}
}
}
......@@ -157,7 +166,26 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
}
case If(cond, args, thenC, elseC) =>
If(cond, replaceArgs(args, s), replaceCnt(thenC, s), replaceCnt(elseC, s))
val newArgs = replaceArgs(args, s)
def getApp(b: Boolean): AppC = {
val contToUse = if (b) thenC else elseC
AppC(contToUse, newArgs)
}
val allLitOpt = newArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get)
// Constant folding
if (isAllLit && cEvaluator.isDefinedAt((cond, asLit))) {
getApp(cEvaluator((cond, asLit)))
} else {
lazy val x = newArgs(0)
lazy val y = newArgs(1)
if (newArgs.length == 2 && x == y) {
getApp(sameArgReduceC(cond))
} else {
If(cond, newArgs, replaceCnt(thenC, s), replaceCnt(elseC, s))
}
}
case Halt(a) =>
Halt(s.aSubst.getOrElse(a, a))
......
......@@ -34,6 +34,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case H.AtomL(CharLit(c)) => L.Halt(L.AtomL(c.toInt))
case H.AtomL(BooleanLit(b)) => L.Halt(L.AtomL(if (b) 1 else 0))
case v1 @ H.AtomN(_) =>
return L.Halt(rewrite(v))
val haltContName = Symbol.fresh("c-halt")
val haltContArgs = Seq(Symbol.fresh("halt_arg"))
val haltContBody = L.Halt(L.AtomN(haltContArgs(0)))
......@@ -81,27 +82,27 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
val untagUnitCont = mkeUntagCont("unit_untag", 2)
// If it's a unit, untag it, otherwise, immediately skip to halt
// val unitCheckCont = mkeCheckCont("unit_check",
// transformIf(L3.UnitP, Seq(v), untagUnitCont.name, haltCont))
val unitCheckCont = mkeCheckCont("unit_check",
transformIf(L3.UnitP, Seq(v), untagUnitCont.name, haltCont))
// If it's a boolean, untag it, otherwise check if it's a unit
// val boolCheckCont = mkeCheckCont("bool_check",
// transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name))
val boolCheckCont = mkeCheckCont("bool_check",
transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name))
// if it's a character, untag it, otherwise check if it's a boolean
//val charCheckCont = mkeCheckCont("char_check",
// transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name))
val charCheckCont = mkeCheckCont("char_check",
transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name))
//etc
val letCBody = transformIf(L3.IntP, Seq(v), untagIntCont.name, haltCont)
val conts = Seq(untagIntCont,
// untagCharCont,
// untagBoolCont,
// untagUnitCont,
// unitCheckCont,
// charCheckCont,
// boolCheckCont
untagCharCont,
untagBoolCont,
untagUnitCont,
unitCheckCont,
charCheckCont,
boolCheckCont
)
L.LetC(conts, letCBody)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment