Commit c15ee6be authored by Luca Bataillard's avatar Luca Bataillard
Browse files

Merge branch 'improvedCC' into 'master'

Improved cc

See merge request !2
parents f89f9e4e 3844c623
......@@ -62,6 +62,11 @@ sealed abstract class CPSInterpreter[M <: CPSTreeModule](
case AppF(fun, retC, args) =>
val FunV(fRetC, fArgs, fBody, fEnv) = unwrapFunV(resolve(fun))
if (fArgs.length != args.length) {
println(fun)
println(fArgs.length)
println(args.length)
}
assume(fArgs.length == args.length)
val rArgs = args map resolve
val env1 = ((fRetC +: fArgs) zip (env(retC) +: rArgs)).toMap orElse fEnv
......
......@@ -7,91 +7,166 @@ import l3.{CPSValuePrimitive => CPS}
import l3.{CPSTestPrimitive => CPST}
object CPSValueRepresenter extends (H.Tree => L.Tree) {
def apply(tree: H.Tree): L.Tree = tree match {
case H.LetP(n, prim, args, body) => applyLetP(n, prim, args, body)
private type Worker = Symbol
private type Wrapper = Symbol
private type FreeVars = Seq[Symbol]
private type KnownFunsMap = Map[Symbol, (Worker, Wrapper, FreeVars)]
def apply(tree: H.Tree): L.Tree = transform(tree)(Map())
private def transform(tree: H.Tree)(implicit knownFuns: KnownFunsMap): L.Tree = tree match {
case H.LetP(n, prim, args, body) => transformLetP(n, prim, args, body)
case H.LetF(funs, body) =>
transformLetF(funs, body)
case H.LetC(cnts, body) =>
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, apply(c.body)))
L.LetC(lCnts, apply(body))
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, transform(c.body)))
L.LetC(lCnts, transform(body))
case H.AppF(fun, retC, args) =>
val f = Symbol.fresh("f")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
val newArgs = Seq(rewrite(fun), L.AtomL(0))
L.LetP(f, CPS.BlockGet, newArgs, newBody)
transformAppF(fun, retC, args)
case H.AppC(cnt, args) =>
val lArgs = args.map(rewrite)
L.AppC(cnt, lArgs)
case H.If(cond, args, thenC, elseC) =>
transformIf(cond, args, thenC, elseC)
case H.Halt(v) => L.Halt(rewrite(v))
case _ => throw new Exception("Unimplemented: " + tree.getClass.toString)
}
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree): L.LetF = {
// for each function, closes it and returns all the variables that used be free in it
// as well as the associated original function name
def transformFunAbs(funs: Seq[H.Fun]): Seq[(L.Fun, Seq[Symbol], Symbol)] = funs match {
case Nil => Nil
case f :: fs =>
val workerFun = Symbol.fresh("worker_function")
val envName = Symbol.fresh("env")
val newArgs = envName +: f.args
val funBody = apply(f.body)
// Get free variables for this function, then order them
val fv = ((freeVars(f.body) - f.name) -- f.args).toList
// Creates a letP
def argsBindings(freeVars: Seq[Symbol], counter: Int, accSubst: Subst[Symbol]): L.Tree = freeVars match {
case Nil =>
substitute(funBody)(accSubst)
case freeVar :: vs =>
// Bind the fresh variable v to a block get
val v = Symbol.fresh("block_variable")
L.LetP(v, CPS.BlockGet, Seq(L.AtomN(envName), L.AtomL(counter)),
argsBindings(vs, counter + 1, accSubst + (freeVar -> v)))
private def transformAppF(fun: H.Atom, retC: Symbol, args: Seq[H.Atom])(implicit knownFuns: KnownFunsMap): L.Tree = {
val fName = fun.asName.get
if (knownFuns.contains(fName)) {
val (wName, sName, fvs) = knownFuns(fName)
val newArgs = (args map rewrite) ++ (fvs map L.AtomN)
val res = L.AppF(L.AtomN(wName), retC, newArgs)
res
} else {
val f = Symbol.fresh("closure")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
val newArgs = Seq(rewrite(fun), L.AtomL(0))
val res = L.LetP(f, CPS.BlockGet, newArgs, newBody)
println("aaaaaaaaaaaa")
println(fName)
println(res)
res
}
}
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree)(implicit oldKnownFuns: KnownFunsMap): L.LetF = {
def funsFV(definedFuns: Seq[H.Fun], prevKnownFuns: KnownFunsMap): Map[Symbol, Seq[Symbol]] = {
type FVMap = Map[Symbol, Set[Symbol]]
def fv(e: H.Tree, fvMap: FVMap): Set[Symbol] = e match {
case H.LetP(n, prim, args, body) =>
val argsFV = fvAtomSeq(args, fvMap)
(fv(body, fvMap) - n) ++ argsFV
case H.LetC(cnts, body) =>
val cntsFVs = cnts.flatMap(c => fv(c.body, fvMap) -- c.args)
fv(body, fvMap) ++ cntsFVs
case H.LetF(funs, body) =>
val funsFVs = funs.flatMap(f => (fv(f.body, fvMap) -- f.args))
(fv(body, fvMap) ++ funsFVs) -- funs.map(_.name)
case H.AppC(cnt, args) =>
fvAtomSeq(args, fvMap)
case H.AppF(fun, retC, args) =>
fvAtomSeq(args, fvMap) ++ fvMap.getOrElse(fun.asName.get, Set(fun.asName.get))
case H.If(_, args, _, _) =>
fvAtomSeq(args, fvMap)
case H.Halt(arg) =>
arg.asName.toSet
}
def fvAtomSeq(as: Seq[H.Atom], fvMap: FVMap): Set[Symbol] =
as.map(_.asName).filter(_.isDefined).toSet
.map((n: Option[Symbol]) => n.get)
def iterate(fvMap: FVMap): FVMap =
definedFuns.foldLeft (fvMap) { case (acc, H.Fun(fName, _, fArgs, fBody)) =>
val newFv = (fv(fBody, acc)) -- fArgs
val newBinding = (fName, newFv)
acc + newBinding
}
val definedFvMap = definedFuns.map(f => (f.name, Set[Symbol]())).toMap
val initialFvMap: FVMap = definedFvMap ++ prevKnownFuns.map{ case (fName, (_, _, fvs)) => (fName, fvs.toSet)}
fixedPoint(initialFvMap)(iterate) map { case (fName, fvs) => (fName, fvs.toSeq) }
}
val newFunBody = argsBindings(fv, 1, subst(f.name, envName))
val newFun = L.Fun(workerFun, f.retC, newArgs, newFunBody)
(newFun, fv, f.name) +: transformFunAbs(fs)
def bindArguments(wName: Symbol, retC: Symbol, envName: Symbol,
freeVars: Seq[Symbol], counter: Int, wArgs: Seq[L.Atom]): L.Tree = freeVars match {
case Nil => L.AppF(L.AtomN(wName), retC, wArgs)
case fv :: fvs =>
val v = Symbol.fresh("binding_" + fv.name)
L.LetP(v, CPS.BlockGet, Seq(L.AtomN(envName), L.AtomL(counter)),
bindArguments(wName, retC, envName, fvs, counter + 1, wArgs :+ L.AtomN(v)))
}
def initFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], lastBody: L.Tree): L.Tree = {
def initFunHelper(remVars: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(L.Fun, Seq[Symbol], Symbol)]): L.Tree = remVars match {
val fvs = funsFV(initialFuns, oldKnownFuns)
val definedFuns = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val wName = Symbol.fresh(fName.name + "_worker")
val sName = Symbol.fresh(fName.name + "_wrapper")
val fv = fvs(fName)
(fName -> (wName, sName, fv))
}
println("eeeee")
println(initialFuns.map(_.name))
println(definedFuns)
println()
val knownFuns = oldKnownFuns ++ definedFuns
val workers = initialFuns map { case H.Fun(fName, fRetC, fArgs, fBody) =>
val (wName, _, fvs) = knownFuns(fName)
val us = fvs.map(f => Symbol.fresh("fv_" + f.name))
val wBody = substitute(transform(fBody)(knownFuns))((fvs zip us).toMap)
L.Fun(wName, fRetC, fArgs ++ us, wBody)
}
val wrappers = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val (wName, sName, fvs) = knownFuns(fName)
val sCntName = Symbol.fresh("c_wrapper")
val envName = Symbol.fresh("env")
val sArgs = fArgs map (f => Symbol.fresh("n_" + f.name))
val sBody = bindArguments(wName, sCntName, envName, fvs, 1, sArgs map (L.AtomN(_)))
L.Fun(sName, sCntName, envName +: sArgs, sBody)
}
def initFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], lastBody: L.Tree): L.Tree = {
def initFunHelper(fvs: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(Symbol, (Worker, Wrapper, FreeVars))]): L.Tree = fvs match {
case Nil => initFuns(rest, lastBody)
case v :: vs =>
val nextBody = initFunHelper(vs, counter + 1, blockAtom, rest)
val args: Seq[L.Atom] = Seq(blockAtom, L.AtomL(counter), L.AtomN(v))
case fv :: fvs =>
val nextBody = initFunHelper(fvs, counter + 1, blockAtom, rest)
val args: Seq[L.Atom] = Seq(blockAtom, L.AtomL(counter), L.AtomN(fv))
L.LetP(Symbol.fresh("blockset_unused"), CPS.BlockSet, args, nextBody)
}
}
funsAndVars match {
case Nil => lastBody
case (workerFun, vars, originalFunName) :: rest =>
val blockAtom = L.AtomN(originalFunName)
val varInits = initFunHelper(vars, 1, blockAtom, rest)
case (fName, (worker, wrapper, fvs)) :: rest =>
val blockAtom = L.AtomN(fName)
val varInits = initFunHelper(fvs, 1, blockAtom, rest)
val t1 = Symbol.fresh("blockset_unused")
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(workerFun.name))
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(wrapper))
L.LetP(t1, CPS.BlockSet, blockSetArgs, varInits)
}
}
def allocFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], closureInits: L.Tree): L.Tree =
funsAndVars.foldRight(closureInits) { case ((worker, vars, fName), prevBody) =>
L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(vars.length + 1)), prevBody)
def allocFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], closureInits: L.Tree): L.Tree =
funsAndVars.foldRight(closureInits) { case ((fName, (worker, wrapper, fvs)), prevBody) =>
L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(fvs.length + 1)), prevBody)
}
val funsAndVars = transformFunAbs(initialFuns)
val lastBody = apply(body)
val closureInits = initFuns(funsAndVars, lastBody)
val closureAllocsInits = allocFuns(funsAndVars, closureInits)
val lastBody = transform(body)(knownFuns)
val closureInits = initFuns(definedFuns, lastBody)
val closureAllocsInits = allocFuns(definedFuns, closureInits)
val res = L.LetF(funsAndVars.unzip3._1, closureAllocsInits)
res
L.LetF(workers ++ wrappers, closureAllocsInits)
}
// Substitutes _free_ variables in `tree`
......@@ -158,36 +233,9 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
}
}
private def freeVars(e: H.Tree): Set[Symbol] = e match {
case H.LetP(n, prim, args, body) =>
val argsFV = freeVarsAtomSeq(args)
(freeVars(body) - n) ++ argsFV
case H.LetC(cnts, body) =>
freeVars(body) ++ cnts.map(freeVarsCont).reduce(_ ++ _)
case H.LetF(funs, body) =>
val funsFreeVars = funs.map(freeVarsFun).reduce(_ ++ _)
(freeVars(body) ++ funsFreeVars) -- funs.map(_.name)
case H.AppC(cnt, args) =>
freeVarsAtomSeq(args)
case H.AppF(fun, retC, args) =>
fun.asName.toSet ++ freeVarsAtomSeq(args)
case H.If(_, args, _, _) =>
freeVarsAtomSeq(args)
case H.Halt(arg) => arg.asName.toSet
}
private def freeVarsAtomSeq(a: Seq[H.Atom]): Set[Symbol] =
a.map(_.asName).filter(_.isDefined).map(_.get).toSet
private def freeVarsCont(cnt: H.Cnt): Set[Symbol] =
freeVars(cnt.body) -- cnt.args
private def freeVarsFun(fun: H.Fun): Set[Symbol] =
freeVars(fun.body) -- fun.args
private def getMaskR(numBits: Int): Either[H.Atom, L.Atom] = Right(L.AtomL((1 << numBits) -1))
private def applyLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree): L.LetP = {
private def transformLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree)(implicit knownFuns: KnownFunsMap): L.LetP = {
val lAtomOne: Either[H.Atom, L.Atom] = Right(L.AtomL(1))
lazy val x = args(0)
......@@ -202,7 +250,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
tempLetP(op, Seq(Right(x1), Right(y1))) { truDiv =>
// Retag the result
tempLetP(CPS.ShiftLeft, Seq(Right(truDiv), lAtomOne)) { shiftedRes =>
L.LetP(n, CPS.Add, Seq(shiftedRes, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(shiftedRes, L.AtomL(1)), transform(body))
}
}
}
......@@ -212,17 +260,17 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
prim match {
case L3.IntAdd =>
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.Add, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.Add, Seq(x1, rewrite(y)), transform(body))
}
case L3.IntSub =>
tempLetP(CPS.Add, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.Sub, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.Sub, Seq(x1, rewrite(y)), transform(body))
}
case L3.IntMul =>
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.Mul, Seq(Right(x1), Right(y1))) { z =>
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), transform(body))
}
}
}
......@@ -235,7 +283,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(x1), Right(y1))) { z =>
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), transform(body))
}
}
}
......@@ -243,67 +291,72 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case L3.IntShiftRight =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.ShiftRight, Seq(Left(x), Right(y1))) { z =>
L.LetP(n, CPS.Or, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Or, Seq(z, L.AtomL(1)), transform(body))
}
}
case L3.IntBitwiseAnd =>
L.LetP(n, CPS.And, Seq(rewrite(x), rewrite(y)), apply(body))
L.LetP(n, CPS.And, Seq(rewrite(x), rewrite(y)), transform(body))
case L3.IntBitwiseOr =>
L.LetP(n, CPS.Or, Seq(rewrite(x), rewrite(y)), apply(body))
L.LetP(n, CPS.Or, Seq(rewrite(x), rewrite(y)), transform(body))
case L3.IntBitwiseXOr =>
tempLetP(CPS.XOr, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.XOr, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.XOr, Seq(x1, rewrite(y)), transform(body))
}
case L3.BlockAlloc(tag) =>
tempLetP(CPS.ShiftRight, Seq(Left(x), lAtomOne)) { t1 =>
L.LetP(n, CPS.BlockAlloc(tag), Seq(t1), apply(body))
L.LetP(n, CPS.BlockAlloc(tag), Seq(t1), transform(body))
}
case L3.BlockTag =>
tempLetP(CPS.BlockTag, args map (Left(_))) { t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.BlockLength =>
tempLetP(CPS.BlockLength, args map (Left(_))) { t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.BlockSet =>
val block = rewrite(x)
val value = rewrite(z)
tempLetP(CPS.ShiftRight, Seq(Left(y), Right(L.AtomL(1)))){idx =>
L.LetP(n, CPS.BlockSet, Seq(block, idx, value), apply(body))
L.LetP(n, CPS.BlockSet, Seq(block, idx, value), transform(body))
}
case L3.BlockGet =>
val block = rewrite(x)
tempLetP(CPS.ShiftRight, Seq(Left(y), Right(L.AtomL(1)))){
idx => L.LetP(n, CPS.BlockGet, Seq(block, idx), apply(body))
idx => L.LetP(n, CPS.BlockGet, Seq(block, idx), transform(body))
}
case L3.ByteRead =>
tempLetP(CPS.ByteRead, Seq()){ t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.ByteWrite =>
tempLetP(CPS.ShiftRight, Seq(Left(x), lAtomOne)) { t1 =>
L.LetP(n, CPS.ByteWrite, Seq(t1), apply(body))
L.LetP(n, CPS.ByteWrite, Seq(t1), transform(body))
}
case L3.CharToInt =>
L.LetP(n, CPS.ShiftRight, Seq(rewrite(x), L.AtomL(2)), apply(body))
L.LetP(n, CPS.ShiftRight, Seq(rewrite(x), L.AtomL(2)), transform(body))
case L3.Id =>
L.LetP(n, CPS.Id, Seq(rewrite(x)), apply(body))
val newKnownFuns = x match {
case H.AtomN(xName) if knownFuns contains xName =>
knownFuns.updated(n, knownFuns(xName))
case _ => knownFuns
}
L.LetP(n, CPS.Id, Seq(rewrite(x)), transform(body)(newKnownFuns))
case L3.IntToChar =>
tempLetP(CPS.ShiftLeft, Seq(Left(x), Right(L.AtomL(2)))){ t1 =>
L.LetP(n, CPS.Add, Seq(t1, L.AtomL(2)), apply(body))
L.LetP(n, CPS.Add, Seq(t1, L.AtomL(2)), transform(body))
}
case _ => throw new Exception("Unreachable code (unary letP) " + prim.getClass)
......
......@@ -17,10 +17,12 @@ object Main {
andThen CPSValueRepresenter
andThen treePrinter("---------- After value representation")
andThen treeChecker
andThen treePrinter("---------- After hoisting")
andThen CPSHoister
andThen CPSInterpreterLow
)
val basePath = Paths.get(".").toAbsolutePath
Either.cond(! args.isEmpty, args.toIndexedSeq, "no input file given")
.flatMap(L3FileReader.readFilesExpandingModules(basePath, _))
......
......@@ -2,31 +2,39 @@
;; Test the "fun" expression
(@byte-write 73)
;;(@byte-write 73)
((fun (b) (@byte-write b)) 65)
((fun (b)
(@byte-write b)
(@byte-write (@+ b 1)))
66)
(@byte-write ((fun (x) x) 68))
;;((fun (b) (@byte-write b)) 65)
;;((fun (b)
;; (@byte-write b)
;; (@byte-write (@+ b 1)))
;; 66)
;;(@byte-write ((fun (x) x) 68))
(let ((compose (fun (f g)
(fun (x) (f (g x)))))
(succ (fun (x) (@+ x 1)))
(twice (fun (x) (@+ x x))))
(@byte-write ((compose succ twice) 34)))
;;(let ((compose (fun (f g)
;; (fun (x) (f (g x)))))
;; (succ (fun (x) (@+ x 1)))
;; (twice (fun (x) (@+ x x))))
;; (@byte-write ((compose succ twice) 34)))
((fun (x y z) #u)
(@byte-write 70)
(@byte-write 71)
(@byte-write 72))
(let
(
(compose (fun (f)
(fun (x) (f x)))
)
)
((compose 1)))
(let* ((fact (fun (self x)
(if (@= 0 x)
1
(@* x (self self (@- x 1))))))
(fix (fun (f x)
(f f x))))
(if (@= (fix fact 5) 120)
(@byte-write 73)))
;;((fun (x y z) #u)
;; (@byte-write 70)
;; (@byte-write 71)
;; (@byte-write 72))
;;(let* ((fact (fun (self x)
;; (if (@= 0 x)
;; 1
;; (@* x (self self (@- x 1))))))
;; (fix (fun (f x)
;; (f f x))))
;; (if (@= (fix fact 5) 120)
;; (@byte-write 73)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment