Commit 849e0c6e authored by Luca Bataillard's avatar Luca Bataillard
Browse files

first attempt at opt cc

parent 37d9acec
......@@ -22,21 +22,83 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, transform(c.body)))
L.LetC(lCnts, transform(body))
case H.AppF(fun, retC, args) =>
val f = Symbol.fresh("f")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
val newArgs = Seq(rewrite(fun), L.AtomL(0))
L.LetP(f, CPS.BlockGet, newArgs, newBody)
transformAppF(fun, retC, args)
case H.AppC(cnt, args) =>
val lArgs = args.map(rewrite)
L.AppC(cnt, lArgs)
case H.If(cond, args, thenC, elseC) =>
transformIf(cond, args, thenC, elseC)
case H.Halt(v) => L.Halt(rewrite(v))
case _ => throw new Exception("Unimplemented: " + tree.getClass.toString)
}
private def transformAppF(fun: H.Atom, retC: Symbol, args: Seq[H.Atom])(implicit knownFuns: KnownFunsMap): L.Tree = {
val fName = fun.asName.get
if (knownFuns.contains(fName)) {
val (wName, sName, fvs) = knownFuns(fName)
val newArgs = (args map rewrite) ++ (fvs map L.AtomN)
L.AppF(L.AtomN(wName), retC, newArgs)
} else {
val f = Symbol.fresh("f")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
val newArgs = Seq(rewrite(fun), L.AtomL(0))
L.LetP(f, CPS.BlockGet, newArgs, newBody)
}
}
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree)(implicit oldKnownFuns: KnownFunsMap): L.LetF = {
def funsFV(definedFuns: Seq[H.Fun], prevKnownFuns: KnownFunsMap): Map[Symbol, Seq[Symbol]] = {
type FVMap = Map[Symbol, Set[Symbol]]
def fv(e: H.Tree, fvMap: FVMap): Set[Symbol] = e match {
case H.LetP(n, prim, args, body) =>
val argsFV = fvAtomSeq(args, fvMap)
(fv(body, fvMap) - n) ++ argsFV
case H.LetC(cnts, body) =>
val cntsFVs = cnts.flatMap(c => fv(c.body, fvMap) -- c.args)
fv(body, fvMap) ++ cntsFVs
case H.LetF(funs, body) =>
val (newFvMap, funFVs) = fvFunSeq(funs, fvMap)
(fv(body, newFvMap) ++ funFVs) -- funs.map(_.name)
case H.AppC(cnt, args) =>
fvAtomSeq(args, fvMap)
case H.AppF(fun, retC, args) =>
fun.asName.toSet ++ fvAtomSeq(args, fvMap)
case H.If(_, args, _, _) =>
fvAtomSeq(args, fvMap)
case H.Halt(arg) =>
arg.asName.toSet
}
def fvAtomSeq(as: Seq[H.Atom], fvMap: FVMap): Set[Symbol] =
as.map(_.asName).filter(_.isDefined).toSet
.flatMap((n: Option[Symbol]) => fvMap.getOrElse(n.get, Set()) + n.get)
def fvFunSeq(funs: Seq[H.Fun], fvMap: FVMap): (FVMap, Set[Symbol]) = {
val fNames = funs map (_.name)
val defFvMap = fNames.map((_, Set[Symbol]())).toMap
val funFVs = funs map {f => {
val fNewFvMap = fvMap ++ (defFvMap - f.name)
fv(f.body, fNewFvMap) -- f.args
}}
(fvMap ++ (fNames zip funFVs), funFVs reduce (_ ++ _))
}
def iterate(fvMap: FVMap): FVMap =
definedFuns.map {
case H.Fun(fName, _, fArgs, fBody) =>
val newFv = (fv(fBody, fvMap - fName) - fName) -- fArgs
(fName, newFv)
}.toMap
val definedFvMap = definedFuns.map(f => (f.name, Set[Symbol]())).toMap
val initialFvMap: FVMap = definedFvMap ++ prevKnownFuns.map{ case (fName, (_, _, fvs)) => (fName, fvs.toSet)}
fixedPoint(initialFvMap)(iterate) map { case (fName, fvs) => (fName, fvs.toSeq) }
}
def bindArguments(wName: Symbol, retC: Symbol, envName: Symbol,
freeVars: Seq[Symbol], counter: Int, wArgs: Seq[L.Atom]): L.Tree = freeVars match {
......@@ -47,10 +109,13 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
bindArguments(wName, retC, envName, fvs, counter + 1, wArgs :+ L.AtomN(v)))
}
val fvs = funsFV(initialFuns, oldKnownFuns)
val definedFuns = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val wName = Symbol.fresh(fName.name + "_worker")
val sName = Symbol.fresh(fName.name + "_wrapper")
val fv = ((freeVars(fBody) - fName) -- fArgs).toList
val fv = fvs(fName)
(fName -> (wName, sName, fv))
}
val knownFuns = oldKnownFuns ++ definedFuns
......@@ -74,54 +139,6 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
L.Fun(sName, sCntName, sArgs, sBody)
}
// for each function, closes it and returns all the variables that used be free in it
// as well as the associated original function name
/*def transformFunAbs(funs: Seq[H.Fun]): Seq[(L.Fun, Seq[Symbol], Symbol)] = funs match {
case Nil => Nil
case f :: fs =>
val workerFun = Symbol.fresh("worker_function")
val envName = Symbol.fresh("env")
val newArgs = envName +: f.args
val funBody = transform(f.body)
// Get free variables for this function, then order them
val fv = ((freeVars(f.body) - f.name) -- f.args).toList
// Creates a letP
def argsBindings(freeVars: Seq[Symbol], counter: Int, accSubst: Subst[Symbol]): L.Tree = freeVars match {
case Nil =>
substitute(funBody)(accSubst)
case freeVar :: vs =>
// Bind the fresh variable v to a block get
val v = Symbol.fresh("block_variable")
L.LetP(v, CPS.BlockGet, Seq(L.AtomN(envName), L.AtomL(counter)),
argsBindings(vs, counter + 1, accSubst + (freeVar -> v)))
}
val newFunBody = argsBindings(fv, 1, subst(f.name, envName))
val newFun = L.Fun(workerFun, f.retC, newArgs, newFunBody)
(newFun, fv, f.name) +: transformFunAbs(fs)
}*/
/*def initFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], lastBody: L.Tree): L.Tree = {
def initFunHelper(remVars: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(L.Fun, Seq[Symbol], Symbol)]): L.Tree = remVars match {
case Nil => initFuns(rest, lastBody)
case v :: vs =>
val nextBody = initFunHelper(vs, counter + 1, blockAtom, rest)
val args: Seq[L.Atom] = Seq(blockAtom, L.AtomL(counter), L.AtomN(v))
L.LetP(Symbol.fresh("blockset_unused"), CPS.BlockSet, args, nextBody)
}
funsAndVars match {
case Nil => lastBody
case (workerFun, vars, originalFunName) :: rest =>
val blockAtom = L.AtomN(originalFunName)
val varInits = initFunHelper(vars, 1, blockAtom, rest)
val t1 = Symbol.fresh("blockset_unused")
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(workerFun.name))
L.LetP(t1, CPS.BlockSet, blockSetArgs, varInits)
}
}*/
def initFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], lastBody: L.Tree): L.Tree = {
def initFunHelper(fvs: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(Symbol, (Worker, Wrapper, FreeVars))]): L.Tree = fvs match {
case Nil => initFuns(rest, lastBody)
......@@ -219,33 +236,6 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
}
}
private def freeVars(e: H.Tree): Set[Symbol] = e match {
case H.LetP(n, prim, args, body) =>
val argsFV = freeVarsAtomSeq(args)
(freeVars(body) - n) ++ argsFV
case H.LetC(cnts, body) =>
freeVars(body) ++ cnts.map(freeVarsCont).reduce(_ ++ _)
case H.LetF(funs, body) =>
val funsFreeVars = funs.map(freeVarsFun).reduce(_ ++ _)
(freeVars(body) ++ funsFreeVars) -- funs.map(_.name)
case H.AppC(cnt, args) =>
freeVarsAtomSeq(args)
case H.AppF(fun, retC, args) =>
fun.asName.toSet ++ freeVarsAtomSeq(args)
case H.If(_, args, _, _) =>
freeVarsAtomSeq(args)
case H.Halt(arg) => arg.asName.toSet
}
private def freeVarsAtomSeq(a: Seq[H.Atom]): Set[Symbol] =
a.map(_.asName).filter(_.isDefined).map(_.get).toSet
private def freeVarsCont(cnt: H.Cnt): Set[Symbol] =
freeVars(cnt.body) -- cnt.args
private def freeVarsFun(fun: H.Fun): Set[Symbol] =
freeVars(fun.body) -- fun.args
private def getMaskR(numBits: Int): Either[H.Atom, L.Atom] = Right(L.AtomL((1 << numBits) -1))
private def transformLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree)(implicit knownFuns: KnownFunsMap): L.LetP = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment