Skip to content
Snippets Groups Projects
Commit c15ee6be authored by Luca Bataillard's avatar Luca Bataillard
Browse files

Merge branch 'improvedCC' into 'master'

Improved cc

See merge request !2
parents f89f9e4e 3844c623
No related branches found
No related tags found
1 merge request!2Improved cc
......@@ -62,6 +62,11 @@ sealed abstract class CPSInterpreter[M <: CPSTreeModule](
case AppF(fun, retC, args) =>
val FunV(fRetC, fArgs, fBody, fEnv) = unwrapFunV(resolve(fun))
if (fArgs.length != args.length) {
println(fun)
println(fArgs.length)
println(args.length)
}
assume(fArgs.length == args.length)
val rArgs = args map resolve
val env1 = ((fRetC +: fArgs) zip (env(retC) +: rArgs)).toMap orElse fEnv
......
......@@ -7,91 +7,166 @@ import l3.{CPSValuePrimitive => CPS}
import l3.{CPSTestPrimitive => CPST}
object CPSValueRepresenter extends (H.Tree => L.Tree) {
def apply(tree: H.Tree): L.Tree = tree match {
case H.LetP(n, prim, args, body) => applyLetP(n, prim, args, body)
private type Worker = Symbol
private type Wrapper = Symbol
private type FreeVars = Seq[Symbol]
private type KnownFunsMap = Map[Symbol, (Worker, Wrapper, FreeVars)]
def apply(tree: H.Tree): L.Tree = transform(tree)(Map())
private def transform(tree: H.Tree)(implicit knownFuns: KnownFunsMap): L.Tree = tree match {
case H.LetP(n, prim, args, body) => transformLetP(n, prim, args, body)
case H.LetF(funs, body) =>
transformLetF(funs, body)
case H.LetC(cnts, body) =>
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, apply(c.body)))
L.LetC(lCnts, apply(body))
val lCnts = cnts.map(c => L.Cnt(c.name, c.args, transform(c.body)))
L.LetC(lCnts, transform(body))
case H.AppF(fun, retC, args) =>
val f = Symbol.fresh("f")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
val newArgs = Seq(rewrite(fun), L.AtomL(0))
L.LetP(f, CPS.BlockGet, newArgs, newBody)
transformAppF(fun, retC, args)
case H.AppC(cnt, args) =>
val lArgs = args.map(rewrite)
L.AppC(cnt, lArgs)
case H.If(cond, args, thenC, elseC) =>
transformIf(cond, args, thenC, elseC)
case H.Halt(v) => L.Halt(rewrite(v))
case _ => throw new Exception("Unimplemented: " + tree.getClass.toString)
}
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree): L.LetF = {
// for each function, closes it and returns all the variables that used be free in it
// as well as the associated original function name
def transformFunAbs(funs: Seq[H.Fun]): Seq[(L.Fun, Seq[Symbol], Symbol)] = funs match {
case Nil => Nil
case f :: fs =>
val workerFun = Symbol.fresh("worker_function")
val envName = Symbol.fresh("env")
val newArgs = envName +: f.args
val funBody = apply(f.body)
// Get free variables for this function, then order them
val fv = ((freeVars(f.body) - f.name) -- f.args).toList
// Creates a letP
def argsBindings(freeVars: Seq[Symbol], counter: Int, accSubst: Subst[Symbol]): L.Tree = freeVars match {
case Nil =>
substitute(funBody)(accSubst)
case freeVar :: vs =>
// Bind the fresh variable v to a block get
val v = Symbol.fresh("block_variable")
L.LetP(v, CPS.BlockGet, Seq(L.AtomN(envName), L.AtomL(counter)),
argsBindings(vs, counter + 1, accSubst + (freeVar -> v)))
private def transformAppF(fun: H.Atom, retC: Symbol, args: Seq[H.Atom])(implicit knownFuns: KnownFunsMap): L.Tree = {
val fName = fun.asName.get
if (knownFuns.contains(fName)) {
val (wName, sName, fvs) = knownFuns(fName)
val newArgs = (args map rewrite) ++ (fvs map L.AtomN)
val res = L.AppF(L.AtomN(wName), retC, newArgs)
res
} else {
val f = Symbol.fresh("closure")
val newBody = L.AppF(L.AtomN(f), retC, rewrite(fun) +: args.map(rewrite))
val newArgs = Seq(rewrite(fun), L.AtomL(0))
val res = L.LetP(f, CPS.BlockGet, newArgs, newBody)
println("aaaaaaaaaaaa")
println(fName)
println(res)
res
}
}
private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree)(implicit oldKnownFuns: KnownFunsMap): L.LetF = {
def funsFV(definedFuns: Seq[H.Fun], prevKnownFuns: KnownFunsMap): Map[Symbol, Seq[Symbol]] = {
type FVMap = Map[Symbol, Set[Symbol]]
def fv(e: H.Tree, fvMap: FVMap): Set[Symbol] = e match {
case H.LetP(n, prim, args, body) =>
val argsFV = fvAtomSeq(args, fvMap)
(fv(body, fvMap) - n) ++ argsFV
case H.LetC(cnts, body) =>
val cntsFVs = cnts.flatMap(c => fv(c.body, fvMap) -- c.args)
fv(body, fvMap) ++ cntsFVs
case H.LetF(funs, body) =>
val funsFVs = funs.flatMap(f => (fv(f.body, fvMap) -- f.args))
(fv(body, fvMap) ++ funsFVs) -- funs.map(_.name)
case H.AppC(cnt, args) =>
fvAtomSeq(args, fvMap)
case H.AppF(fun, retC, args) =>
fvAtomSeq(args, fvMap) ++ fvMap.getOrElse(fun.asName.get, Set(fun.asName.get))
case H.If(_, args, _, _) =>
fvAtomSeq(args, fvMap)
case H.Halt(arg) =>
arg.asName.toSet
}
def fvAtomSeq(as: Seq[H.Atom], fvMap: FVMap): Set[Symbol] =
as.map(_.asName).filter(_.isDefined).toSet
.map((n: Option[Symbol]) => n.get)
def iterate(fvMap: FVMap): FVMap =
definedFuns.foldLeft (fvMap) { case (acc, H.Fun(fName, _, fArgs, fBody)) =>
val newFv = (fv(fBody, acc)) -- fArgs
val newBinding = (fName, newFv)
acc + newBinding
}
val definedFvMap = definedFuns.map(f => (f.name, Set[Symbol]())).toMap
val initialFvMap: FVMap = definedFvMap ++ prevKnownFuns.map{ case (fName, (_, _, fvs)) => (fName, fvs.toSet)}
fixedPoint(initialFvMap)(iterate) map { case (fName, fvs) => (fName, fvs.toSeq) }
}
val newFunBody = argsBindings(fv, 1, subst(f.name, envName))
val newFun = L.Fun(workerFun, f.retC, newArgs, newFunBody)
(newFun, fv, f.name) +: transformFunAbs(fs)
def bindArguments(wName: Symbol, retC: Symbol, envName: Symbol,
freeVars: Seq[Symbol], counter: Int, wArgs: Seq[L.Atom]): L.Tree = freeVars match {
case Nil => L.AppF(L.AtomN(wName), retC, wArgs)
case fv :: fvs =>
val v = Symbol.fresh("binding_" + fv.name)
L.LetP(v, CPS.BlockGet, Seq(L.AtomN(envName), L.AtomL(counter)),
bindArguments(wName, retC, envName, fvs, counter + 1, wArgs :+ L.AtomN(v)))
}
def initFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], lastBody: L.Tree): L.Tree = {
def initFunHelper(remVars: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(L.Fun, Seq[Symbol], Symbol)]): L.Tree = remVars match {
val fvs = funsFV(initialFuns, oldKnownFuns)
val definedFuns = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val wName = Symbol.fresh(fName.name + "_worker")
val sName = Symbol.fresh(fName.name + "_wrapper")
val fv = fvs(fName)
(fName -> (wName, sName, fv))
}
println("eeeee")
println(initialFuns.map(_.name))
println(definedFuns)
println()
val knownFuns = oldKnownFuns ++ definedFuns
val workers = initialFuns map { case H.Fun(fName, fRetC, fArgs, fBody) =>
val (wName, _, fvs) = knownFuns(fName)
val us = fvs.map(f => Symbol.fresh("fv_" + f.name))
val wBody = substitute(transform(fBody)(knownFuns))((fvs zip us).toMap)
L.Fun(wName, fRetC, fArgs ++ us, wBody)
}
val wrappers = initialFuns map { case H.Fun(fName, _, fArgs, fBody) =>
val (wName, sName, fvs) = knownFuns(fName)
val sCntName = Symbol.fresh("c_wrapper")
val envName = Symbol.fresh("env")
val sArgs = fArgs map (f => Symbol.fresh("n_" + f.name))
val sBody = bindArguments(wName, sCntName, envName, fvs, 1, sArgs map (L.AtomN(_)))
L.Fun(sName, sCntName, envName +: sArgs, sBody)
}
def initFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], lastBody: L.Tree): L.Tree = {
def initFunHelper(fvs: Seq[Symbol], counter: Int, blockAtom: L.Atom, rest: Seq[(Symbol, (Worker, Wrapper, FreeVars))]): L.Tree = fvs match {
case Nil => initFuns(rest, lastBody)
case v :: vs =>
val nextBody = initFunHelper(vs, counter + 1, blockAtom, rest)
val args: Seq[L.Atom] = Seq(blockAtom, L.AtomL(counter), L.AtomN(v))
case fv :: fvs =>
val nextBody = initFunHelper(fvs, counter + 1, blockAtom, rest)
val args: Seq[L.Atom] = Seq(blockAtom, L.AtomL(counter), L.AtomN(fv))
L.LetP(Symbol.fresh("blockset_unused"), CPS.BlockSet, args, nextBody)
}
}
funsAndVars match {
case Nil => lastBody
case (workerFun, vars, originalFunName) :: rest =>
val blockAtom = L.AtomN(originalFunName)
val varInits = initFunHelper(vars, 1, blockAtom, rest)
case (fName, (worker, wrapper, fvs)) :: rest =>
val blockAtom = L.AtomN(fName)
val varInits = initFunHelper(fvs, 1, blockAtom, rest)
val t1 = Symbol.fresh("blockset_unused")
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(workerFun.name))
val blockSetArgs = Seq(blockAtom, L.AtomL(0), L.AtomN(wrapper))
L.LetP(t1, CPS.BlockSet, blockSetArgs, varInits)
}
}
def allocFuns(funsAndVars: Seq[(L.Fun, Seq[Symbol], Symbol)], closureInits: L.Tree): L.Tree =
funsAndVars.foldRight(closureInits) { case ((worker, vars, fName), prevBody) =>
L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(vars.length + 1)), prevBody)
def allocFuns(funsAndVars: Seq[(Symbol, (Worker, Wrapper, FreeVars))], closureInits: L.Tree): L.Tree =
funsAndVars.foldRight(closureInits) { case ((fName, (worker, wrapper, fvs)), prevBody) =>
L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(fvs.length + 1)), prevBody)
}
val funsAndVars = transformFunAbs(initialFuns)
val lastBody = apply(body)
val closureInits = initFuns(funsAndVars, lastBody)
val closureAllocsInits = allocFuns(funsAndVars, closureInits)
val lastBody = transform(body)(knownFuns)
val closureInits = initFuns(definedFuns, lastBody)
val closureAllocsInits = allocFuns(definedFuns, closureInits)
val res = L.LetF(funsAndVars.unzip3._1, closureAllocsInits)
res
L.LetF(workers ++ wrappers, closureAllocsInits)
}
// Substitutes _free_ variables in `tree`
......@@ -158,36 +233,9 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
}
}
private def freeVars(e: H.Tree): Set[Symbol] = e match {
case H.LetP(n, prim, args, body) =>
val argsFV = freeVarsAtomSeq(args)
(freeVars(body) - n) ++ argsFV
case H.LetC(cnts, body) =>
freeVars(body) ++ cnts.map(freeVarsCont).reduce(_ ++ _)
case H.LetF(funs, body) =>
val funsFreeVars = funs.map(freeVarsFun).reduce(_ ++ _)
(freeVars(body) ++ funsFreeVars) -- funs.map(_.name)
case H.AppC(cnt, args) =>
freeVarsAtomSeq(args)
case H.AppF(fun, retC, args) =>
fun.asName.toSet ++ freeVarsAtomSeq(args)
case H.If(_, args, _, _) =>
freeVarsAtomSeq(args)
case H.Halt(arg) => arg.asName.toSet
}
private def freeVarsAtomSeq(a: Seq[H.Atom]): Set[Symbol] =
a.map(_.asName).filter(_.isDefined).map(_.get).toSet
private def freeVarsCont(cnt: H.Cnt): Set[Symbol] =
freeVars(cnt.body) -- cnt.args
private def freeVarsFun(fun: H.Fun): Set[Symbol] =
freeVars(fun.body) -- fun.args
private def getMaskR(numBits: Int): Either[H.Atom, L.Atom] = Right(L.AtomL((1 << numBits) -1))
private def applyLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree): L.LetP = {
private def transformLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree)(implicit knownFuns: KnownFunsMap): L.LetP = {
val lAtomOne: Either[H.Atom, L.Atom] = Right(L.AtomL(1))
lazy val x = args(0)
......@@ -202,7 +250,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
tempLetP(op, Seq(Right(x1), Right(y1))) { truDiv =>
// Retag the result
tempLetP(CPS.ShiftLeft, Seq(Right(truDiv), lAtomOne)) { shiftedRes =>
L.LetP(n, CPS.Add, Seq(shiftedRes, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(shiftedRes, L.AtomL(1)), transform(body))
}
}
}
......@@ -212,17 +260,17 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
prim match {
case L3.IntAdd =>
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.Add, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.Add, Seq(x1, rewrite(y)), transform(body))
}
case L3.IntSub =>
tempLetP(CPS.Add, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.Sub, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.Sub, Seq(x1, rewrite(y)), transform(body))
}
case L3.IntMul =>
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.Mul, Seq(Right(x1), Right(y1))) { z =>
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), transform(body))
}
}
}
......@@ -235,7 +283,7 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(x1), Right(y1))) { z =>
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(z, L.AtomL(1)), transform(body))
}
}
}
......@@ -243,67 +291,72 @@ object CPSValueRepresenter extends (H.Tree => L.Tree) {
case L3.IntShiftRight =>
tempLetP(CPS.ShiftRight, Seq(Left(y), lAtomOne)) { y1 =>
tempLetP(CPS.ShiftRight, Seq(Left(x), Right(y1))) { z =>
L.LetP(n, CPS.Or, Seq(z, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Or, Seq(z, L.AtomL(1)), transform(body))
}
}
case L3.IntBitwiseAnd =>
L.LetP(n, CPS.And, Seq(rewrite(x), rewrite(y)), apply(body))
L.LetP(n, CPS.And, Seq(rewrite(x), rewrite(y)), transform(body))
case L3.IntBitwiseOr =>
L.LetP(n, CPS.Or, Seq(rewrite(x), rewrite(y)), apply(body))
L.LetP(n, CPS.Or, Seq(rewrite(x), rewrite(y)), transform(body))
case L3.IntBitwiseXOr =>
tempLetP(CPS.XOr, Seq(Left(x), lAtomOne)) { x1 =>
L.LetP(n, CPS.XOr, Seq(x1, rewrite(y)), apply(body))
L.LetP(n, CPS.XOr, Seq(x1, rewrite(y)), transform(body))
}
case L3.BlockAlloc(tag) =>
tempLetP(CPS.ShiftRight, Seq(Left(x), lAtomOne)) { t1 =>
L.LetP(n, CPS.BlockAlloc(tag), Seq(t1), apply(body))
L.LetP(n, CPS.BlockAlloc(tag), Seq(t1), transform(body))
}
case L3.BlockTag =>
tempLetP(CPS.BlockTag, args map (Left(_))) { t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.BlockLength =>
tempLetP(CPS.BlockLength, args map (Left(_))) { t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.BlockSet =>
val block = rewrite(x)
val value = rewrite(z)
tempLetP(CPS.ShiftRight, Seq(Left(y), Right(L.AtomL(1)))){idx =>
L.LetP(n, CPS.BlockSet, Seq(block, idx, value), apply(body))
L.LetP(n, CPS.BlockSet, Seq(block, idx, value), transform(body))
}
case L3.BlockGet =>
val block = rewrite(x)
tempLetP(CPS.ShiftRight, Seq(Left(y), Right(L.AtomL(1)))){
idx => L.LetP(n, CPS.BlockGet, Seq(block, idx), apply(body))
idx => L.LetP(n, CPS.BlockGet, Seq(block, idx), transform(body))
}
case L3.ByteRead =>
tempLetP(CPS.ByteRead, Seq()){ t1 =>
tempLetP(CPS.ShiftLeft, Seq(Right(t1), lAtomOne)) { t2 =>
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), apply(body))
L.LetP(n, CPS.Add, Seq(t2, L.AtomL(1)), transform(body))
}
}
case L3.ByteWrite =>
tempLetP(CPS.ShiftRight, Seq(Left(x), lAtomOne)) { t1 =>
L.LetP(n, CPS.ByteWrite, Seq(t1), apply(body))
L.LetP(n, CPS.ByteWrite, Seq(t1), transform(body))
}
case L3.CharToInt =>
L.LetP(n, CPS.ShiftRight, Seq(rewrite(x), L.AtomL(2)), apply(body))
L.LetP(n, CPS.ShiftRight, Seq(rewrite(x), L.AtomL(2)), transform(body))
case L3.Id =>
L.LetP(n, CPS.Id, Seq(rewrite(x)), apply(body))
val newKnownFuns = x match {
case H.AtomN(xName) if knownFuns contains xName =>
knownFuns.updated(n, knownFuns(xName))
case _ => knownFuns
}
L.LetP(n, CPS.Id, Seq(rewrite(x)), transform(body)(newKnownFuns))
case L3.IntToChar =>
tempLetP(CPS.ShiftLeft, Seq(Left(x), Right(L.AtomL(2)))){ t1 =>
L.LetP(n, CPS.Add, Seq(t1, L.AtomL(2)), apply(body))
L.LetP(n, CPS.Add, Seq(t1, L.AtomL(2)), transform(body))
}
case _ => throw new Exception("Unreachable code (unary letP) " + prim.getClass)
......
......@@ -17,10 +17,12 @@ object Main {
andThen CPSValueRepresenter
andThen treePrinter("---------- After value representation")
andThen treeChecker
andThen treePrinter("---------- After hoisting")
andThen CPSHoister
andThen CPSInterpreterLow
)
val basePath = Paths.get(".").toAbsolutePath
Either.cond(! args.isEmpty, args.toIndexedSeq, "no input file given")
.flatMap(L3FileReader.readFilesExpandingModules(basePath, _))
......
......@@ -2,31 +2,39 @@
;; Test the "fun" expression
(@byte-write 73)
;;(@byte-write 73)
((fun (b) (@byte-write b)) 65)
((fun (b)
(@byte-write b)
(@byte-write (@+ b 1)))
66)
(@byte-write ((fun (x) x) 68))
;;((fun (b) (@byte-write b)) 65)
;;((fun (b)
;; (@byte-write b)
;; (@byte-write (@+ b 1)))
;; 66)
;;(@byte-write ((fun (x) x) 68))
(let ((compose (fun (f g)
(fun (x) (f (g x)))))
(succ (fun (x) (@+ x 1)))
(twice (fun (x) (@+ x x))))
(@byte-write ((compose succ twice) 34)))
;;(let ((compose (fun (f g)
;; (fun (x) (f (g x)))))
;; (succ (fun (x) (@+ x 1)))
;; (twice (fun (x) (@+ x x))))
;; (@byte-write ((compose succ twice) 34)))
((fun (x y z) #u)
(@byte-write 70)
(@byte-write 71)
(@byte-write 72))
(let
(
(compose (fun (f)
(fun (x) (f x)))
)
)
((compose 1)))
(let* ((fact (fun (self x)
(if (@= 0 x)
1
(@* x (self self (@- x 1))))))
(fix (fun (f x)
(f f x))))
(if (@= (fix fact 5) 120)
(@byte-write 73)))
;;((fun (x y z) #u)
;; (@byte-write 70)
;; (@byte-write 71)
;; (@byte-write 72))
;;(let* ((fact (fun (self x)
;; (if (@= 0 x)
;; 1
;; (@* x (self self (@- x 1))))))
;; (fix (fun (f x)
;; (f f x))))
;; (if (@= (fix fact 5) 120)
;; (@byte-write 73)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment