Commit aa62f3c2 authored by Sapphie's avatar Sapphie
Browse files

Implement constant folding for value primitives

parent c6f5bb5f
...@@ -105,12 +105,21 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }] ...@@ -105,12 +105,21 @@ abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
case LetP(name, prim, args, body) => { case LetP(name, prim, args, body) => {
val replacedArgs = replaceArgs(args, s) val replacedArgs = replaceArgs(args, s)
val newState = s.withExp(name, prim, replacedArgs) val newState = s.withExp(name, prim, replacedArgs)
val shrunkBody = shrink(body, newState) lazy val shrunkBody = shrink(body, newState)
if (s.dead(name) && !impure(prim)) { if (s.dead(name) && !impure(prim)) {
shrunkBody shrunkBody
} else { } else {
LetP(name, prim, replacedArgs, shrunkBody) val allLitOpt = replacedArgs.map(_.asLiteral)
val isAllLit = allLitOpt.forall(_.isDefined)
lazy val asLit = allLitOpt.map(_.get)
if (isAllLit && vEvaluator.isDefinedAt((prim, asLit))) {
val newNewState = newState.withASubst(name, vEvaluator((prim, asLit)))
val newBody = shrink(body, newNewState)
newBody
} else {
LetP(name, prim, replacedArgs, shrunkBody)
}
} }
} }
......
...@@ -11,6 +11,7 @@ import CPSTreeChecker._ // Implicits required for CPS tree checking ...@@ -11,6 +11,7 @@ import CPSTreeChecker._ // Implicits required for CPS tree checking
object Main { object Main {
def main(args: Array[String]): Unit = { def main(args: Array[String]): Unit = {
val stats = new Statistics()
val backEnd: Tree => TerminalPhaseResult = ( val backEnd: Tree => TerminalPhaseResult = (
CL3ToCPSTranslator CL3ToCPSTranslator
andThen treePrinter("---------- After CPS translation") andThen treePrinter("---------- After CPS translation")
...@@ -25,7 +26,7 @@ object Main { ...@@ -25,7 +26,7 @@ object Main {
andThen CPSOptimizerLow andThen CPSOptimizerLow
andThen treePrinter("---------- After Optimization (Low)") andThen treePrinter("---------- After Optimization (Low)")
andThen treeChecker andThen treeChecker
andThen CPSInterpreterLow andThen (new CPSInterpreterLow(stats.log _))
) )
...@@ -37,6 +38,7 @@ object Main { ...@@ -37,6 +38,7 @@ object Main {
.flatMap(backEnd) match { .flatMap(backEnd) match {
case Right((retCode, maybeMsg)) => case Right((retCode, maybeMsg)) =>
maybeMsg foreach println maybeMsg foreach println
println(stats)
sys.exit(retCode) sys.exit(retCode)
case Left(errMsg) => case Left(errMsg) =>
println(s"Error: $errMsg") println(s"Error: $errMsg")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment