Commit d10e6c58 authored by Luca Bataillard's avatar Luca Bataillard
Browse files

fix if translation and clean up code

parent 256ce02f
......@@ -9,100 +9,57 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
C.Halt(C.AtomL(IntLit(L3Int(0))))
)
def tail(tree: S.Tree)(t: C.Name): C.Tree = ???
def tail(tree: S.Tree)(c: C.Name): C.Tree =
nonTail(tree)(v => C.AppC(c, Seq(v)))
def cond(condition: S.Tree)(thnCnt: C.Name, elsCnt: C.Name): C.Tree = {
def isTrue(l: CL3Literal): Boolean = l match {
case BooleanLit(false) => false
case _ => true
def transformTestPrim(p: L3TestPrimitive, args: Seq[S.Tree]): C.Tree = {
def transformArgs(args: Seq[S.Tree])(atoms: Seq[C.Atom]): C.Tree = args match {
case Nil => C.If(p, atoms, thnCnt, elsCnt)
case x::xs =>
nonTail(x)(v => transformArgs(xs)(atoms :+ v))
}
transformArgs(args)(Seq())
}
// When e2 is a literal
def thenLit(e1: S.Tree, e2lit: CL3Literal, e3: S.Tree): C.Tree = {
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e3)(thnCnt, elsCnt))
if (isTrue(e2lit)) {
C.LetC(Seq(aCnt), cond(e1)(thnCnt, aName))
} else {
C.LetC(Seq(aCnt), cond(e1)(elsCnt, aName))
}
def litSelect[T](lit: CL3Literal, tru: T, fls: T): T = lit match {
case BooleanLit(false) => fls
case _ => tru
}
def elseLit(e1: S.Tree, e2: S.Tree, e3lit: CL3Literal): C.Tree = {
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
if (isTrue(e3lit)) {
C.LetC(Seq(aCnt), cond(e1)(aName, thnCnt))
} else {
C.LetC(Seq(aCnt), cond(e1)(aName, elsCnt))
}
def singleLiteral(condE: S.Tree, argE: S.Tree, lit: CL3Literal, literalLeft: Boolean) = {
val branchCntName = Symbol.fresh("ac")
val literalCntName = litSelect(lit, thnCnt, elsCnt)
val branchCnt = C.Cnt(branchCntName, Seq(), cond(argE)(thnCnt, elsCnt))
val leftCnt = if (literalLeft) literalCntName else branchCntName
val rightCnt = if (literalLeft) branchCntName else literalCntName
C.LetC(Seq(branchCnt), cond(condE)(leftCnt, rightCnt))
}
condition match {
case S.If(e1, e2, e3) =>
e1 match {
// Result depends entirely on e3
case S.Lit(BooleanLit(false)) =>
cond(e3)(thnCnt, elsCnt)
// Result depends entirely on e2
case S.Lit(_) =>
cond(e2)(thnCnt, elsCnt)
// e1 is not constant
case _ =>
(e2, e3) match {
case (S.Lit(l1), S.Lit(l2)) =>
// The result is entirely determined by e1.
// The continuations are determined by the constants
// present in e2 and e3
val cnts = Seq(l1, l2).map(_ match {
case BooleanLit(false) => elsCnt
case _ => thnCnt
})
// Technically, here we could compile out e1 (we'd need
// to check for side effects tho).
cond(e1)(cnts(0), cnts(1))
// If e1 evaluates to true, jump directly to one of the conts
case (S.Lit(l), _) => thenLit(e1, l, e3)
// If e1 evaluates to false, jump directly to the else cont
case (_, S.Lit(l)) => elseLit(e1, e2, l)
case _ => // No optimisation.
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
val bName = Symbol.fresh("bc")
val bCnt = C.Cnt(bName, Seq(), cond(e3)(thnCnt, elsCnt))
// Not sure if declaring both continuations at once is allowed
// I don't see why not, but I'm not sure
C.LetC(Seq(aCnt, bCnt), cond(e1)(aName, bName))
}
}
/*
case S.If(e1, e2, S.Lit(BooleanLit(false))) =>
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
C.LetC(Seq(aCnt), cond(e1)(aName, elsCnt))
// If e1 is false, jump directly to then
case S.If(e1, e2, S.Lit(BooleanLit(true))) =>
val aName = Symbol.fresh("ac")
val aCnt = C.Cnt(aName, Seq(), cond(e2)(thnCnt, elsCnt))
C.LetC(Seq(aCnt), cond(e1)(aName, thnCnt))
*/
case S.If(S.Lit(lit), e2, e3) =>
cond(litSelect(lit, e2, e3))(thnCnt, elsCnt)
case S.If(e1, S.Lit(lit2), S.Lit(lit3)) =>
cond(e1)(litSelect(lit2, thnCnt, elsCnt), litSelect(lit3, thnCnt, elsCnt))
case S.If(e1, e2, S.Lit(lit3)) =>
singleLiteral(e1, e2, lit3, literalLeft=false)
case S.If(e1, S.Lit(lit2), e3) =>
singleLiteral(e1, e3, lit2, literalLeft=true)
case S.Prim(p: L3TestPrimitive, primArgs) =>
transformTestPrim(p, primArgs)
case _ => {
implicit val pos = condition.pos
val primCond = S.Prim(L3.Eq, Seq(condition, S.Lit(BooleanLit(false))))
cond(primCond)(elsCnt, thnCnt)
}
}
}
def nonTail(tree: S.Tree)(ctx: C.Atom => C.Tree): C.Tree = {
def transformApp(fun: S.Tree, args: Seq[S.Tree]) = {
// Build return continuation
val cntName = Symbol.fresh("c")
......@@ -120,12 +77,9 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
nonTail(fun)(fId => transformArgs(args)(fId, Seq()))
}
def transformLetRec(funs: Seq[S.Fun], body: S.Tree) = {
val cpsFuns = funs map { case S.Fun(name, args, body) =>
val cName = Symbol.fresh("c-" + name)
val funCtx = (v: C.Atom) => C.AppC(cName, Seq(v))
C.Fun(name, cName, args, tail(body)(cName))
}
......@@ -135,7 +89,7 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
def transformValPrim(p: L3ValuePrimitive, primArgs: Seq[S.Tree]) = {
val name = Symbol.fresh("p")
val body = ctx(C.AtomN(name))
// Slight misnomer
def transformArgs(args: Seq[S.Tree])(atoms: Seq[C.Atom]): C.Tree = args match {
case Nil =>
// Finished gathering atoms, now we can just place them into the primitive
......@@ -147,19 +101,6 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
transformArgs(primArgs)(Seq())
}
// Builds a CPS If out of its arguments
def transformTestPrim(p: L3TestPrimitive,
primArgs: Seq[S.Tree],
thenC: C.Name,
elseC: C.Name): C.Tree = {
def transformArgs(args: Seq[S.Tree])(atoms: Seq[C.Atom]): C.Tree = args match {
case Nil => C.If(p, atoms, thenC, elseC)
case x::xs =>
nonTail(x)(v => transformArgs(xs)(atoms :+ v))
}
transformArgs(primArgs)(Seq())
}
tree match {
case S.Lit(v) => ctx(C.AtomL(v))
case S.Ident(id) => ctx(C.AtomN(id))
......@@ -168,7 +109,7 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
case ((name, expr), tree) => nonTail(expr)(v => C.LetP(name, L3.Id, Seq(v), tree))
}
case S.If(iff, thn, els) => {
case S.If(condition, thn, els) => {
val retCntName = Symbol.fresh("c")
val retArgName = Symbol.fresh("r")
val retCnt = C.Cnt(retCntName, Seq(retArgName), ctx(C.AtomN(retArgName)))
......@@ -181,22 +122,20 @@ object CL3ToCPSTranslator extends (S.Tree => C.Tree) {
val elsCntBody = tail(els)(retCntName)
val elsCnt = C.Cnt(elsCntName, Seq(), elsCntBody)
val ifBody = cond(iff)(thnCntName, elsCntName)
val ifBody = cond(condition)(thnCntName, elsCntName)
C.LetC(Seq(retCnt), C.LetC(Seq(thnCnt), C.LetC(Seq(elsCnt), ifBody)))
}
case S.App(fun, args) => transformApp(fun, args)
case S.LetRec(funs, body) => transformLetRec(funs, body)
case S.Prim(p, args) => p match {
case _ : L3TestPrimitive =>
val newPrim = S.If( tree,
S.Lit(BooleanLit(true))(tree.pos),
S.Lit(BooleanLit(false))(tree.pos)
)(tree.pos)
nonTail(newPrim)(ctx)
case p : L3ValuePrimitive => transformValPrim(p, args)
}
case S.Prim(p: L3TestPrimitive, args) => {
implicit val pos = tree.pos
val newPrim = S.If(tree, S.Lit(BooleanLit(true)), S.Lit(BooleanLit(false)))
nonTail(newPrim)(ctx)
}
case S.Prim(p: L3ValuePrimitive, args) => transformValPrim(p, args)
case S.Halt(arg) => nonTail(arg)(a => C.Halt(a))
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment