Commit f462ed60 authored by Sapphie's avatar Sapphie
Browse files

Implement constant cases for cond translation

parent a81981a1
......@@ -5,11 +5,93 @@ import l3.{ SymbolicCL3TreeModule => S}
import l3.{ SymbolicCPSTreeModule => C}
object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
def apply(tree: S.Tree): C.Tree = transform(tree)(_ =>
def apply(tree: S.Tree): C.Tree = nonTail(tree)(_ =>
C.Halt(C.AtomL(IntLit(L3Int(0))))
)
def tail(tree: S.Tree)(t: C.Name): C.Tree = ???
def cond(condition: S.Tree)(thnCnt: C.Name, elsCnt: C.Name): C.Tree = {
def isTrue(l: CL3Literal): Boolean = l match {
case BooleanLit(false) => false
case _ => true
}
// When e2 is a literal
def thenLit(e1: S.Tree, e2lit: CL3Literal, e3: S.Tree): C.Tree = {
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e3)(thnCnt, elsCnt))
if (isTrue(e2lit)) {
C.LetC(Seq(aCnt), cond(e1)(thnCnt, aName))
} else {
C.LetC(Seq(aCnt), cond(e1)(elsCnt, aName))
}
}
def elseLit(e1: S.Tree, e2: S.Tree, e3lit: CL3Literal): C.Tree = {
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
if (isTrue(e3lit)) {
C.LetC(Seq(aCnt), cond(e1)(aName, thnCnt))
} else {
C.LetC(Seq(aCnt), cond(e1)(aName, elsCnt))
}
}
condition match {
case S.If(e1, e2, e3) =>
e1 match {
// Result depends entirely on e3
case S.Lit(BooleanLit(false)) =>
cond(e3)(thnCnt, elsCnt)
// Result depends entirely on e2
case S.Lit(_) =>
cond(e2)(thnCnt, elsCnt)
// e1 is not constant
case _ =>
(e2, e3) match {
case (S.Lit(l1), S.Lit(l2)) =>
// The result is entirely determined by e1.
// The continuations are determined by the constants
// present in e2 and e3
val cnts = Seq(l1, l2).map(_ match {
case BooleanLit(false) => elsCnt
case _ => thnCnt
})
// Technically, here we could compile out e1 (we'd need
// to check for side effects tho).
cond(e1)(cnts(0), cnts(1))
// If e1 evaluates to true, jump directly to one of the conts
case (S.Lit(l), _) => thenLit(e1, l, e3)
// If e1 evaluates to false, jump directly to the else cont
case (_, S.Lit(l)) => elseLit(e1, e2, l)
case _ => ???
}
}
/*
case S.If(e1, e2, S.Lit(BooleanLit(false))) =>
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
C.LetC(Seq(aCnt), cond(e1)(aName, elsCnt))
// If e1 is false, jump directly to then
case S.If(e1, e2, S.Lit(BooleanLit(true))) =>
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
C.LetC(Seq(aCnt), cond(e1)(aName, thnCnt))
*/
}
}
def transform(tree: S.Tree)(ctx: C.Atom => C.Tree): C.Tree = {
def nonTail(tree: S.Tree)(ctx: C.Atom => C.Tree): C.Tree = {
def transformApp(fun: S.Tree, args: Seq[S.Tree]) = {
......@@ -22,11 +104,11 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
// expression. Return resulting tree
def transformArgs(args: Seq[S.Tree])(fId: C.Atom, ids: Seq[C.Atom]): C.Tree = args match {
case Nil => C.LetC(Seq(returnCnt), C.AppF(fId, cntName, ids))
case arg :: args => transform(arg)(v => transformArgs(args)(fId, ids :+ v))
case arg :: args => nonTail(arg)(v => transformArgs(args)(fId, ids :+ v))
}
// Transform initial function then start argument transformation
transform(fun)(fId => transformArgs(args)(fId, Seq()))
nonTail(fun)(fId => transformArgs(args)(fId, Seq()))
}
......@@ -35,10 +117,10 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
val cName = Symbol.fresh("c-" + name)
val funCtx = (v: C.Atom) => C.AppC(cName, Seq(v))
C.Fun(name, cName, args, transform(body)(funCtx))
C.Fun(name, cName, args, tail(body)(cName))
}
C.LetF(cpsFuns, transform(body)(ctx))
C.LetF(cpsFuns, nonTail(body)(ctx))
}
def transformValPrim(p: L3ValuePrimitive, primArgs: Seq[S.Tree]) = {
......@@ -51,7 +133,7 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
C.LetP(name, p, atoms, body)
case x :: xs =>
// Transform the current argument and add it to the list
transform(x)(v => transformArgs(xs)(atoms :+ v))
nonTail(x)(v => transformArgs(xs)(atoms :+ v))
}
transformArgs(primArgs)(Seq())
}
......@@ -64,7 +146,7 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
def transformArgs(args: Seq[S.Tree])(atoms: Seq[C.Atom]): C.Tree = args match {
case Nil => C.If(p, atoms, thenC, elseC)
case x::xs =>
transform(x)(v => transformArgs(xs)(atoms :+ v))
nonTail(x)(v => transformArgs(xs)(atoms :+ v))
}
transformArgs(primArgs)(Seq())
}
......@@ -73,34 +155,28 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
case S.Lit(v) => ctx(C.AtomL(v))
case S.Ident(id) => ctx(C.AtomN(id))
case S.Let(bindings, body) =>
bindings.foldRight(transform(body)(ctx)) {
case ((name, expr), tree) => transform(expr)(v => C.LetP(name, L3.Id, Seq(v), tree))
bindings.foldRight(nonTail(body)(ctx)) {
case ((name, expr), tree) => nonTail(expr)(v => C.LetP(name, L3.Id, Seq(v), tree))
}
case S.If(S.Prim(p: L3TestPrimitive, primArgs), thn, els) => {
case S.If(iff, thn, els) => {
val retCntName = Symbol.fresh("c")
val retArgName = Symbol.fresh("r")
val retCnt = C.Cnt(retCntName, Seq(retArgName), ctx(C.AtomN(retArgName)))
val thnCntName = Symbol.fresh("ct")
val thnCntBody = transform(thn)(v => C.AppC(retCntName, Seq(v)))
val thnCntBody = tail(thn)(retCntName)
val thnCnt = C.Cnt(thnCntName, Seq(), thnCntBody)
val elsCntName = Symbol.fresh("cf")
val elsCntBody = transform(els)(v => C.AppC(retCntName, Seq(v)))
val elsCntBody = tail(els)(retCntName)
val elsCnt = C.Cnt(elsCntName, Seq(), elsCntBody)
val ifBody = transformTestPrim(p, primArgs, thnCntName, elsCntName)
val ifBody = cond(iff)(thnCntName, elsCntName)
C.LetC(Seq(retCnt), C.LetC(Seq(thnCnt), C.LetC(Seq(elsCnt), ifBody)))
}
case S.If(cond, thn, els) => {
implicit val pos = tree.pos
val primCond = S.Prim(L3.Eq, Seq(cond, S.Lit(BooleanLit(false))))
transform(S.If(primCond, els, thn))(ctx)
}
case S.App(fun, args) => transformApp(fun, args)
case S.LetRec(funs, body) => transformLetRec(funs, body)
case S.Prim(p, args) => p match {
......@@ -109,10 +185,10 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
S.Lit(BooleanLit(true))(tree.pos),
S.Lit(BooleanLit(false))(tree.pos)
)(tree.pos)
transform(newPrim)(ctx)
nonTail(newPrim)(ctx)
case p : L3ValuePrimitive => transformValPrim(p, args)
}
case S.Halt(arg) => transform(arg)(a => C.Halt(a))
case S.Halt(arg) => nonTail(arg)(a => C.Halt(a))
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment