Commit 37d9acec authored by Luca Bataillard's avatar Luca Bataillard
Browse files

write LetF impr. translation

parent 146e5f46
......@@ -7,13 +7,20 @@ import l3.{CPSValuePrimitive => CPS}
import l3.{CPSTestPrimitive => CPST}
object CPSValueRepresenter extends (H.Tree => L.Tree) {
def apply(tree: H.Tree): L.Tree = tree match {
case H.LetP(n, prim, args, body) => applyLetP(n, prim, args, body)
private type Worker = Symbol
private type Wrapper = Symbol
private type FreeVars = Seq[Symbol]
private type KnownFunsMap = Map[Symbol, (Worker, Wrapper, FreeVars)]
def apply(tree: H.Tree): L.Tree = transform(tree)(Map())
private def transform(tree: H.Tree)(implicit knownFuns: KnownFunsMap): L.Tree = tree match {
case H.LetP(n, prim, args, body) => transformLetP(n, prim, args, body)
case H.LetF(funs, body) =>
transformLetF(funs, body)
case H.LetC(cnts, body) =>
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, apply(c.body)))
L.LetC(lCnts, apply(body))
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, transform(c.body)))
L.LetC(lCnts, transform(body))
case H.AppF(fun, retC, args) =>
val f = Symbol.fresh("f")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
......@@ -29,17 +36,53 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case _ => throw new Exception("Unimplemented: " + tree.getClass.toString)
}
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree): L.LetF = {
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree)(implicit oldKnownFuns: KnownFunsMap): L.LetF = {
def bindArguments(wName: Symbol, retC: Symbol, envName: Symbol,
freeVars: Seq[Symbol], counter: Int, wArgs: Seq[L.Atom]): L.Tree = freeVars match {
case Nil => L.AppF(L.AtomN(wName), retC, wArgs)
case fv :: fvs =>
val v = Symbol.fresh("binding_" + fv.name)
L.LetP(v, CPS.BlockGet, Seq(L.AtomN(envName), L.AtomL(counter)),
bindArguments(wName, retC, envName, fvs, counter + 1, wArgs :+ L.AtomN(v)))
}
val definedFuns = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val wName = Symbol.fresh(fName.name + "_worker")
val sName = Symbol.fresh(fName.name + "_wrapper")
val fv = ((freeVars(fBody) - fName) -- fArgs).toList
(fName -> (wName, sName, fv))
}
val knownFuns = oldKnownFuns ++ definedFuns
val workers = initialFuns map { case H.Fun(fName, fRetC, fArgs, fBody) =>
val (wName, _, fvs) = knownFuns(fName)
val us = fvs.map(f => Symbol.fresh("fv_" + f.name))
val wBody = substitute(apply(fBody))((fvs zip us).toMap)
L.Fun(wName, fRetC, fArgs ++ us, wBody)
}
val wrappers = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val (wName, sName, fvs) = knownFuns(fName)
val sCntName = Symbol.fresh("c_wrapper")
val envName = Symbol.fresh("env")
val sArgs = fArgs map (f => Symbol.fresh("n_" + f.name))
val sBody = bindArguments(wName, sCntName, envName, fvs, 1, sArgs map (L.AtomN(_)))
L.Fun(sName, sCntName, sArgs, sBody)
}
// for each function, closes it and returns all the variables that used be free in it
// as well as the associated original function name
def transformFunAbs(funs: Seq[H.Fun]): Seq[(L.Fun, Seq[Symbol], Symbol)] = funs match {
/*def transformFunAbs(funs: Seq[H.Fun]): Seq[(L.Fun, Seq[Symbol], Symbol)] = funs match {
case Nil => Nil
case f :: fs =>
val workerFun = Symbol.fresh("worker_function")
val envName = Symbol.fresh("env")
val newArgs = envName +: f.args
val funBody = apply(f.body)
val funBody = transform(f.body)
// Get free variables for this function, then order them
val fv = ((freeVars(f.body) - f.name) -- f.args).toList
......@@ -57,9 +100,9 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
val newFunBody = argsBindings(fv, 1, subst(f.name, envName))
val newFun = L.Fun(workerFun, f.retC, newArgs, newFunBody)
(newFun, fv, f.name) +: transformFunAbs(fs)
}
}*/
def initFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], lastBody: L.Tree): L.Tree = {
/*def initFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], lastBody: L.Tree): L.Tree = {
def initFunHelper(remVars: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(L.Fun, Seq[Symbol], Symbol)]): L.Tree = remVars match {
case Nil => initFuns(rest, lastBody)
case v :: vs =>
......@@ -77,21 +120,39 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(workerFun.name))
L.LetP(t1, CPS.BlockSet, blockSetArgs, varInits)
}
}*/
def initFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], lastBody: L.Tree): L.Tree = {
def initFunHelper(fvs: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(Symbol, (Worker, Wrapper, FreeVars))]): L.Tree = fvs match {
case Nil => initFuns(rest, lastBody)
case fv :: fvs =>
val nextBody = initFunHelper(fvs, counter + 1, blockAtom, rest)
val args: Seq[L.Atom] = Seq(blockAtom, L.AtomL(counter), L.AtomN(fv))
L.LetP(Symbol.fresh("blockset_unused"), CPS.BlockSet, args, nextBody)
}
funsAndVars match {
case Nil => lastBody
case (fName, (worker, wrapper, fvs)) :: rest =>
val blockAtom = L.AtomN(fName)
val varInits = initFunHelper(fvs, 1, blockAtom, rest)
val t1 = Symbol.fresh("blockset_unused")
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(worker))
L.LetP(t1, CPS.BlockSet, blockSetArgs, varInits)
}
}
def allocFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], closureInits: L.Tree): L.Tree =
funsAndVars.foldRight(closureInits) { case ((worker, vars, fName), prevBody) =>
L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(vars.length + 1)), prevBody)
def allocFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], closureInits: L.Tree): L.Tree =
funsAndVars.foldRight(closureInits) { case ((fName, (worker, wrapper, fvs)), prevBody) =>
L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(fvs.length + 1)), prevBody)
}
val funsAndVars = transformFunAbs(initialFuns)
val lastBody = apply(body)
val closureInits = initFuns(funsAndVars, lastBody)
val closureAllocsInits = allocFuns(funsAndVars, closureInits)
val lastBody = transform(body)(knownFuns)
val closureInits = initFuns(definedFuns, lastBody)
val closureAllocsInits = allocFuns(definedFuns, closureInits)
val res = L.LetF(funsAndVars.unzip3._1, closureAllocsInits)
res
L.LetF(workers ++ wrappers, closureAllocsInits)
}
// Substitutes _free_ variables in `tree`
......@@ -187,7 +248,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
private def getMaskR(numBits: Int): Either[H.Atom, L.Atom] = Right(L.AtomL((1 << numBits) -1))
private def applyLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree): L.LetP = {
private def transformLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree)(implicit knownFuns: KnownFunsMap): L.LetP = {
val lAtomOne: Either[H.Atom, L.Atom] = Right(L.AtomL(1))
lazy val x = args(0)
......@@ -202,7 +263,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
tempLetP(op, Seq(Right(x1), Right(y1))) { truDiv =>
// Retag the result
tempLetP(CPS.ShiftLeft, Seq(Right(truDiv), lAtomOne)) { shiftedRes =>
L.LetP(n, CPS.Add, Seq(shiftedRes, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(shiftedRes, L.AtomL(1)), transform(body))
}
}
}
......@@ -212,17 +273,17 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
prim match {
case L3.IntAdd =>
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.Add, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.Add, Seq(x1, rewrite(y)), transform(body))
}
case L3.IntSub =>
tempLetP(CPS.Add, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.Sub, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.Sub, Seq(x1, rewrite(y)), transform(body))
}
case L3.IntMul =>
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.Mul, Seq(Right(x1), Right(y1))) { z =>
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), transform(body))
}
}
}
......@@ -235,7 +296,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(x1), Right(y1))) { z =>
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), transform(body))
}
}
}
......@@ -243,67 +304,67 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case L3.IntShiftRight =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.ShiftRight, Seq(Left(x), Right(y1))) { z =>
L.LetP(n, CPS.Or, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Or, Seq(z, L.AtomL(1)), transform(body))
}
}
case L3.IntBitwiseAnd =>
L.LetP(n, CPS.And, Seq(rewrite(x), rewrite(y)), apply(body))
L.LetP(n, CPS.And, Seq(rewrite(x), rewrite(y)), transform(body))
case L3.IntBitwiseOr =>
L.LetP(n, CPS.Or, Seq(rewrite(x), rewrite(y)), apply(body))
L.LetP(n, CPS.Or, Seq(rewrite(x), rewrite(y)), transform(body))
case L3.IntBitwiseXOr =>
tempLetP(CPS.XOr, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.XOr, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.XOr, Seq(x1, rewrite(y)), transform(body))
}
case L3.BlockAlloc(tag) =>
tempLetP(CPS.ShiftRight, Seq(Left(x), lAtomOne)) { t1 =>
L.LetP(n, CPS.BlockAlloc(tag), Seq(t1), apply(body))
L.LetP(n, CPS.BlockAlloc(tag), Seq(t1), transform(body))
}
case L3.BlockTag =>
tempLetP(CPS.BlockTag, args map (Left(_))) { t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.BlockLength =>
tempLetP(CPS.BlockLength, args map (Left(_))) { t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.BlockSet =>
val block = rewrite(x)
val value = rewrite(z)
tempLetP(CPS.ShiftRight, Seq(Left(y), Right(L.AtomL(1)))){idx =>
L.LetP(n, CPS.BlockSet, Seq(block, idx, value), apply(body))
L.LetP(n, CPS.BlockSet, Seq(block, idx, value), transform(body))
}
case L3.BlockGet =>
val block = rewrite(x)
tempLetP(CPS.ShiftRight, Seq(Left(y), Right(L.AtomL(1)))){
idx => L.LetP(n, CPS.BlockGet, Seq(block, idx), apply(body))
idx => L.LetP(n, CPS.BlockGet, Seq(block, idx), transform(body))
}
case L3.ByteRead =>
tempLetP(CPS.ByteRead, Seq()){ t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.ByteWrite =>
tempLetP(CPS.ShiftRight, Seq(Left(x), lAtomOne)) { t1 =>
L.LetP(n, CPS.ByteWrite, Seq(t1), apply(body))
L.LetP(n, CPS.ByteWrite, Seq(t1), transform(body))
}
case L3.CharToInt =>
L.LetP(n, CPS.ShiftRight, Seq(rewrite(x), L.AtomL(2)), apply(body))
L.LetP(n, CPS.ShiftRight, Seq(rewrite(x), L.AtomL(2)), transform(body))
case L3.Id =>
L.LetP(n, CPS.Id, Seq(rewrite(x)), apply(body))
L.LetP(n, CPS.Id, Seq(rewrite(x)), transform(body))
case L3.IntToChar =>
tempLetP(CPS.ShiftLeft, Seq(Left(x), Right(L.AtomL(2)))){ t1 =>
L.LetP(n, CPS.Add, Seq(t1, L.AtomL(2)), apply(body))
L.LetP(n, CPS.Add, Seq(t1, L.AtomL(2)), transform(body))
}
case _ => throw new Exception("Unreachable code (unary letP) " + prim.getClass)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment