package l3

import l3.{SymbolicCPSTreeModule => H}
import l3.{SymbolicCPSTreeModuleLow => L}
import l3.{L3Primitive => L3}
import l3.{CPSValuePrimitive => CPS}
import l3.{CPSTestPrimitive => CPST}

/**
 * Value-representation phase: translates a high-level CPS tree (`H`) into a
 * low-level CPS tree (`L`) in which every value is an explicitly tagged word,
 * and closure-converts functions.
 *
 * Tagging scheme (see `rewrite` and `transformIf`):
 *   - integers:   value << 1 | 0b1
 *   - blocks:     low two bits 0b00
 *   - characters: value << 3 | 0b110
 *   - booleans:   value << 4 | 0b1010
 *   - unit:       0b0010
 *
 * Closures are blocks allocated with tag 202: slot 0 holds the function's
 * wrapper, slots 1..n hold its free variables (see `allocClosures` and
 * `initClosures`).
 */
object CPSValueRepresenter extends (H.Tree => L.Tree) {
  private type Worker = Symbol
  private type Wrapper = Symbol
  private type FreeVars = Seq[Symbol]
  /** For every function that can be called directly: its worker, its wrapper and its free variables. */
  private type KnownFunsMap = Map[Symbol, (Worker, Wrapper, FreeVars)]

  def apply(tree: H.Tree): L.Tree = transform(tree)(Map())

  /** Translates one high-level tree node, dispatching on its kind. */
  private def transform(tree: H.Tree)(implicit knownFuns: KnownFunsMap): L.Tree = tree match {
    case H.LetP(n, prim, args, body) =>
      transformLetP(n, prim, args, body)
    case H.LetF(funs, body) =>
      transformLetF(funs, body)
    case H.LetC(cnts, body) =>
      val lCnts = cnts.map(c => L.Cnt(c.name, c.args, transform(c.body)))
      L.LetC(lCnts, transform(body))
    case H.AppF(fun, retC, args) =>
      transformAppF(fun, retC, args)
    case H.AppC(cnt, args) =>
      L.AppC(cnt, args.map(rewrite))
    case H.If(cond, args, thenC, elseC) =>
      transformIf(cond, args, thenC, elseC)
    case H.Halt(v) =>
      transformHalt(v)
    case _ =>
      throw new Exception("Unimplemented: " + tree.getClass.toString)
  }

  /**
   * Translates `halt v`. The program halts with the *untagged* value of `v`:
   * literals are untagged at compile time, names are untagged at run time by
   * dispatching on their tag (see `makeUntaggingTree`). Previously named
   * arguments were halted still tagged, inconsistently with literals.
   */
  private def transformHalt(v: H.Atom): L.Tree = v match {
    case H.AtomL(UnitLit) => L.Halt(L.AtomL(0))
    case H.AtomL(IntLit(i)) => L.Halt(L.AtomL(i.toInt))
    case H.AtomL(CharLit(c)) => L.Halt(L.AtomL(c.toInt))
    case H.AtomL(BooleanLit(b)) => L.Halt(L.AtomL(if (b) 1 else 0))
    case name @ H.AtomN(_) =>
      val haltCnt = Symbol.fresh("c-halt")
      val haltArg = Symbol.fresh("halt_arg")
      val haltCntDef = L.Cnt(haltCnt, Seq(haltArg), L.Halt(L.AtomN(haltArg)))
      L.LetC(Seq(haltCntDef), makeUntaggingTree(name, haltCnt))
  }

  /**
   * Translates a function application.
   *
   * Calls to known functions become direct calls to their worker, with the
   * function's free variables appended to the arguments. Any other callee is
   * treated as a closure: its code pointer is loaded from slot 0 and called
   * with the closure itself as first argument.
   */
  private def transformAppF(fun: H.Atom, retC: Symbol, args: Seq[H.Atom])
                           (implicit knownFuns: KnownFunsMap): L.Tree = {
    val lArgs = args.map(rewrite)
    fun match {
      case H.AtomN(fName) if knownFuns.contains(fName) =>
        val (worker, _, freeVars) = knownFuns(fName)
        L.AppF(L.AtomN(worker), retC, lArgs ++ freeVars.map(L.AtomN(_)))
      case _ =>
        // Also covers literal callees: L3 is dynamically typed, so applying a
        // literal must fail at run time rather than crash the compiler.
        val codePtr = Symbol.fresh("closure")
        val closure = rewrite(fun)
        L.LetP(codePtr, CPS.BlockGet, Seq(closure, L.AtomL(0)),
          L.AppF(L.AtomN(codePtr), retC, closure +: lArgs))
    }
  }

  /**
   * Builds a tree that untags the value named by `v` according to its run-time
   * tag and passes the untagged value to `haltCont`, a one-argument continuation.
   *
   * Integers, characters, booleans and unit are shifted right by their tag
   * width (1, 3, 4 and 2 bits respectively). Any other value (e.g. a block) is
   * passed to `haltCont` unchanged.
   *
   * The dispatch chain is int -> char -> bool -> unit -> raw. The continuations
   * take no parameters because they are targets of `If`. They are bound in
   * nested `LetC`s so that each one is in scope before it is referenced.
   */
  def makeUntaggingTree(v: H.AtomN, haltCont: Symbol): L.Tree = {
    // Zero-argument continuation: shifts `v` right by `nBitsShift` and halts with the result.
    def mkUntagCont(name: String, nBitsShift: Int): L.Cnt = {
      val contName = Symbol.fresh("c-" + name)
      val untagged = Symbol.fresh(name + "_arg_untagged")
      L.Cnt(contName, Seq(),
        L.LetP(untagged, CPS.ShiftRight, Seq(rewrite(v), L.AtomL(nBitsShift)),
          L.AppC(haltCont, Seq(L.AtomN(untagged)))))
    }

    // Zero-argument continuation wrapping one step of the type dispatch.
    def mkCheckCont(name: String, body: L.Tree): L.Cnt =
      L.Cnt(Symbol.fresh(name), Seq(), body)

    val untagIntCont = mkUntagCont("int_untag", 1)
    val untagCharCont = mkUntagCont("char_untag", 3)
    val untagBoolCont = mkUntagCont("bool_untag", 4)
    val untagUnitCont = mkUntagCont("unit_untag", 2)
    // Not an immediate value: halt with it as-is.
    val rawCont = L.Cnt(Symbol.fresh("c-raw"), Seq(), L.AppC(haltCont, Seq(rewrite(v))))

    // Unit? untag it, otherwise halt with the raw value.
    val unitCheckCont = mkCheckCont("unit_check",
      transformIf(L3.UnitP, Seq(v), untagUnitCont.name, rawCont.name))
    // Boolean? untag it, otherwise check for unit.
    val boolCheckCont = mkCheckCont("bool_check",
      transformIf(L3.BoolP, Seq(v), untagBoolCont.name, unitCheckCont.name))
    // Character? untag it, otherwise check for boolean.
    val charCheckCont = mkCheckCont("char_check",
      transformIf(L3.CharP, Seq(v), untagCharCont.name, boolCheckCont.name))

    L.LetC(Seq(untagIntCont, untagCharCont, untagBoolCont, untagUnitCont, rawCont),
      L.LetC(Seq(unitCheckCont),
        L.LetC(Seq(boolCheckCont),
          L.LetC(Seq(charCheckCont),
            // Integer? untag it, otherwise check for character.
            transformIf(L3.IntP, Seq(v), untagIntCont.name, charCheckCont.name)))))
  }

  /**
   * Closure-converts a group of mutually recursive functions.
   *
   * Each function `f` yields:
   *   - a worker, taking f's arguments followed by f's free variables;
   *   - a wrapper, taking the closure followed by f's arguments. It loads the
   *     free variables from closure slots 1..n and tail-calls the worker;
   *   - a closure block named `f`, allocated and then initialised before `body`.
   *
   * Inside the group and in `body`, all these functions are known, so calls
   * to them become direct worker calls.
   */
  private def transformLetF(initialFuns: Seq[H.Fun], body: H.Tree)
                           (implicit oldKnownFuns: KnownFunsMap): L.LetF = {
    val fvs = funsFV(initialFuns, oldKnownFuns)

    val definedFuns: Seq[(Symbol, (Worker, Wrapper, FreeVars))] = initialFuns.map {
      case H.Fun(fName, _, _, _) =>
        val wName = Symbol.fresh(fName.name + "_worker")
        val sName = Symbol.fresh(fName.name + "_wrapper")
        fName -> ((wName, sName, fvs(fName)))
    }
    val knownFuns: KnownFunsMap = oldKnownFuns ++ definedFuns

    val workers = initialFuns.map { case H.Fun(fName, fRetC, fArgs, fBody) =>
      val (wName, _, freeVars) = knownFuns(fName)
      // Free variables become fresh worker parameters; rename their uses in the body.
      val freshParams = freeVars.map(fv => Symbol.fresh("fv_" + fv.name))
      val wBody = substitute(transform(fBody)(knownFuns))((freeVars zip freshParams).toMap)
      L.Fun(wName, fRetC, fArgs ++ freshParams, wBody)
    }

    val wrappers = initialFuns.map { case H.Fun(fName, _, fArgs, _) =>
      val (wName, sName, freeVars) = knownFuns(fName)
      val sRetC = Symbol.fresh("c_wrapper")
      val env = Symbol.fresh("env")
      val sArgs = fArgs.map(a => Symbol.fresh("n_" + a.name))
      L.Fun(sName, sRetC, env +: sArgs,
        wrapperBody(wName, sRetC, env, freeVars, sArgs.map(L.AtomN(_))))
    }

    val lastBody = transform(body)(knownFuns)
    L.LetF(workers ++ wrappers, allocClosures(definedFuns, initClosures(definedFuns, lastBody)))
  }

  /**
   * Computes the free variables of each function in `definedFuns`, as a fixed point.
   *
   * A call to a function listed in `prevKnownFuns` or `definedFuns` contributes
   * that function's free variables rather than its name, because such calls
   * are compiled as direct worker calls that pass those variables.
   */
  private def funsFV(definedFuns: Seq[H.Fun], prevKnownFuns: KnownFunsMap): Map[Symbol, Seq[Symbol]] = {
    type FVMap = Map[Symbol, Set[Symbol]]

    def fvAtoms(as: Seq[H.Atom]): Set[Symbol] = as.flatMap(_.asName).toSet

    def fv(e: H.Tree, fvMap: FVMap): Set[Symbol] = e match {
      case H.LetP(n, L3.Id, Seq(H.AtomN(x)), body) if fvMap.contains(x) =>
        // transformLetP makes `n` an alias of the known function `x`, so calls
        // through `n` pass x's free variables: they are free here as well.
        (fv(body, fvMap + (n -> fvMap(x))) - n) + x
      case H.LetP(n, _, args, body) =>
        (fv(body, fvMap) - n) ++ fvAtoms(args)
      case H.LetC(cnts, body) =>
        fv(body, fvMap) ++ cnts.flatMap(c => fv(c.body, fvMap) -- c.args)
      case H.LetF(funs, body) =>
        val funsFVs = funs.flatMap(f => fv(f.body, fvMap) -- f.args)
        (fv(body, fvMap) ++ funsFVs) -- funs.map(_.name)
      case H.AppC(_, args) =>
        fvAtoms(args)
      case H.AppF(fun, _, args) =>
        val funFV = fun.asName.fold(Set.empty[Symbol])(f => fvMap.getOrElse(f, Set(f)))
        fvAtoms(args) ++ funFV
      case H.If(_, args, _, _) =>
        fvAtoms(args)
      case H.Halt(arg) =>
        arg.asName.toSet
    }

    def iterate(fvMap: FVMap): FVMap =
      definedFuns.foldLeft(fvMap) { case (acc, H.Fun(fName, _, fArgs, fBody)) =>
        acc + (fName -> (fv(fBody, acc) -- fArgs))
      }

    val definedFvMap: FVMap = definedFuns.map(f => f.name -> Set.empty[Symbol]).toMap
    val initialFvMap: FVMap =
      definedFvMap ++ prevKnownFuns.map { case (fName, (_, _, fvs)) => fName -> fvs.toSet }

    // toList: give every free-variable list a concrete, immutable, stable order.
    fixedPoint(initialFvMap)(iterate).map { case (fName, fvs) => fName -> fvs.toList }
  }

  /**
   * Body of a wrapper. It loads each free variable from the closure `env`
   * (slots 1..n, in order) and tail-calls `worker` with `args` followed by
   * those free variables, returning to `retC`.
   */
  private def wrapperBody(worker: Symbol, retC: Symbol, env: Symbol,
                          freeVars: Seq[Symbol], args: Seq[L.Atom]): L.Tree = {
    val bound = freeVars.map(fv => Symbol.fresh("binding_" + fv.name))
    val call: L.Tree = L.AppF(L.AtomN(worker), retC, args ++ bound.map(L.AtomN(_)))
    bound.zipWithIndex.foldRight(call) { case ((b, idx), rest) =>
      L.LetP(b, CPS.BlockGet, Seq(L.AtomN(env), L.AtomL(idx + 1)), rest)
    }
  }

  /**
   * Wraps `body` in the initialisation of each closure: slot 0 receives the
   * wrapper, slots 1..n the free variables.
   */
  private def initClosures(funs: Seq[(Symbol, (Worker, Wrapper, FreeVars))], body: L.Tree): L.Tree =
    funs.foldRight(body) { case ((fName, (_, wrapper, freeVars)), rest) =>
      val closure = L.AtomN(fName)
      val slots: Seq[L.Atom] = L.AtomN(wrapper) +: freeVars.map(L.AtomN(_))
      slots.zipWithIndex.foldRight(rest) { case ((value, idx), inner) =>
        L.LetP(Symbol.fresh("blockset_unused"), CPS.BlockSet,
          Seq(closure, L.AtomL(idx), value), inner)
      }
    }

  /**
   * Wraps `body` in the allocation of one closure block per function. Each
   * block has tag 202 and size 1 + number of free variables.
   */
  private def allocClosures(funs: Seq[(Symbol, (Worker, Wrapper, FreeVars))], body: L.Tree): L.Tree =
    funs.foldRight(body) { case ((fName, (_, _, freeVars)), rest) =>
      L.LetP(fName, CPS.BlockAlloc(202), Seq(L.AtomL(freeVars.length + 1)), rest)
    }

  /**
   * Renames free variables of `tree` according to `subst`.
   *
   * `subst` should only contain variables that are free in `tree`. Binders are
   * not renamed, so a mapped name that is also bound inside `tree` would be
   * renamed incorrectly.
   */
  def substitute(tree: L.Tree)(implicit subst: Subst[Symbol]): L.Tree = {
    def substituteAtom(atom: L.Atom): L.Atom = atom match {
      case L.AtomN(n) => L.AtomN(subst.getOrElse(n, n))
      case L.AtomL(_) => atom
    }

    def substituteArgs(args: Seq[L.Atom]): Seq[L.Atom] = args.map(substituteAtom)

    tree match {
      case L.LetP(name, prim, args, body) =>
        L.LetP(name, prim, substituteArgs(args), substitute(body))
      case L.AppC(cnt, args) =>
        L.AppC(cnt, substituteArgs(args))
      case L.AppF(fun, retC, args) =>
        L.AppF(substituteAtom(fun), retC, substituteArgs(args))
      case L.Halt(arg) =>
        L.Halt(substituteAtom(arg))
      case L.If(cond, args, thenC, elseC) =>
        L.If(cond, substituteArgs(args), thenC, elseC)
      case L.LetC(cnts, body) =>
        val newCnts = cnts.map(cnt => L.Cnt(cnt.name, cnt.args, substitute(cnt.body)))
        L.LetC(newCnts, substitute(body))
      case L.LetF(funs, body) =>
        val newFuns = funs.map(fun => L.Fun(fun.name, fun.retC, fun.args, substitute(fun.body)))
        L.LetF(newFuns, substitute(body))
    }
  }

  /**
   * Translates a conditional.
   *
   * Type tests mask the low tag bits of their single argument and compare them
   * with the expected tag. Comparisons operate directly on tagged integers:
   * tagging (2a + 1) preserves equality and order.
   */
  private def transformIf(cond: L3TestPrimitive, args: Seq[H.Atom], thenC: H.Name, elseC: H.Name): L.Tree = {
    def maskAndCheck(numBits: Int, target: Bits32): L.LetP = {
      val Seq(x) = args
      tempLetP(CPS.And, Seq(Left(x), getMaskR(numBits))) { masked =>
        L.If(CPST.Eq, Seq(masked, L.AtomL(target)), thenC, elseC)
      }
    }

    cond match {
      case L3.BlockP => maskAndCheck(2, 0x0)
      case L3.IntP => maskAndCheck(1, 0x1)
      case L3.BoolP => maskAndCheck(4, 0xa)
      case L3.UnitP => maskAndCheck(4, 0x2)
      case L3.CharP => maskAndCheck(3, 0x6)
      case L3.Eq => L.If(CPST.Eq, args.map(rewrite), thenC, elseC)
      case L3.IntLe => L.If(CPST.Le, args.map(rewrite), thenC, elseC)
      case L3.IntLt => L.If(CPST.Lt, args.map(rewrite), thenC, elseC)
    }
  }

  /** Literal mask selecting the `numBits` low bits, wrapped for `tempLetP`. */
  private def getMaskR(numBits: Int): Either[H.Atom, L.Atom] = Right(L.AtomL((1 << numBits) - 1))

  /**
   * Translates a primitive binding `n = prim(args)` on tagged values, followed
   * by `body`. Where possible, arithmetic works on tagged operands directly,
   * which avoids a full untag/retag round trip.
   */
  private def transformLetP(n: H.Name, prim: L3, args: Seq[H.Atom], body: H.Tree)
                           (implicit knownFuns: KnownFunsMap): L.LetP = {
    val lAtomOne: Either[H.Atom, L.Atom] = Right(L.AtomL(1))
    // Lazy: only the operands a primitive actually has are accessed.
    lazy val x = args(0)
    lazy val y = args(1)
    lazy val z = args(2)

    // Binds `n` to `op(lArgs)` and continues with the translated body.
    def bind(op: CPS, lArgs: L.Atom*): L.LetP = L.LetP(n, op, lArgs, transform(body))

    // Untags the integer `a` (arithmetic shift right by 1) and passes it to `k`.
    def untagInt(a: H.Atom)(k: L.Atom => L.LetP): L.LetP =
      tempLetP(CPS.ShiftRight, Seq(Left(a), lAtomOne))(k)

    // Tags the raw integer `raw` (shift left by 1, add 1) and binds it to `n`.
    def bindTaggedInt(raw: L.Atom): L.LetP =
      tempLetP(CPS.ShiftLeft, Seq(Right(raw), lAtomOne)) { shifted =>
        bind(CPS.Add, shifted, L.AtomL(1))
      }

    prim match {
      case L3.IntAdd =>
        // (2a + 1 - 1) + (2b + 1) = 2(a + b) + 1
        tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 => bind(CPS.Add, x1, rewrite(y)) }
      case L3.IntSub =>
        // (2a + 1 + 1) - (2b + 1) = 2(a - b) + 1
        tempLetP(CPS.Add, Seq(Left(x), lAtomOne)) { x1 => bind(CPS.Sub, x1, rewrite(y)) }
      case L3.IntMul =>
        // (2a) * b + 1 = 2ab + 1
        tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
          untagInt(y) { y1 =>
            tempLetP(CPS.Mul, Seq(Right(x1), Right(y1))) { prod => bind(CPS.Add, prod, L.AtomL(1)) }
          }
        }
      case L3.IntDiv =>
        untagInt(x) { x1 =>
          untagInt(y) { y1 =>
            tempLetP(CPS.Div, Seq(Right(x1), Right(y1)))(bindTaggedInt)
          }
        }
      case L3.IntMod =>
        // Clear both tag bits: (2a) mod (2b) = 2(a mod b); then restore the tag bit.
        tempLetP(CPS.XOr, Seq(Left(x), lAtomOne)) { x1 =>
          tempLetP(CPS.XOr, Seq(Left(y), lAtomOne)) { y1 =>
            tempLetP(CPS.Mod, Seq(Right(x1), Right(y1))) { rem => bind(CPS.XOr, rem, L.AtomL(1)) }
          }
        }
      case L3.IntShiftLeft =>
        // (2a) << b + 1 = 2(a << b) + 1
        tempLetP(CPS.Sub, Seq(Left(x), lAtomOne)) { x1 =>
          untagInt(y) { y1 =>
            tempLetP(CPS.ShiftLeft, Seq(Right(x1), Right(y1))) { shifted =>
              bind(CPS.Add, shifted, L.AtomL(1))
            }
          }
        }
      case L3.IntShiftRight =>
        // (2a + 1) >> b carries the bits of a >> b above bit 0; OR-ing 1 restores the tag.
        untagInt(y) { y1 =>
          tempLetP(CPS.ShiftRight, Seq(Left(x), Right(y1))) { shifted =>
            bind(CPS.Or, shifted, L.AtomL(1))
          }
        }
      case L3.IntBitwiseAnd =>
        bind(CPS.And, rewrite(x), rewrite(y))
      case L3.IntBitwiseOr =>
        bind(CPS.Or, rewrite(x), rewrite(y))
      case L3.IntBitwiseXOr =>
        // Clear one tag bit so the XOr keeps the other one.
        tempLetP(CPS.XOr, Seq(Left(x), lAtomOne)) { x1 => bind(CPS.XOr, x1, rewrite(y)) }
      case L3.BlockAlloc(tag) =>
        untagInt(x) { size => bind(CPS.BlockAlloc(tag), size) }
      case L3.BlockTag =>
        tempLetP(CPS.BlockTag, args.map(Left(_)))(bindTaggedInt)
      case L3.BlockLength =>
        tempLetP(CPS.BlockLength, args.map(Left(_)))(bindTaggedInt)
      case L3.BlockSet =>
        untagInt(y) { idx => bind(CPS.BlockSet, rewrite(x), idx, rewrite(z)) }
      case L3.BlockGet =>
        untagInt(y) { idx => bind(CPS.BlockGet, rewrite(x), idx) }
      case L3.ByteRead =>
        tempLetP(CPS.ByteRead, Seq())(bindTaggedInt)
      case L3.ByteWrite =>
        untagInt(x) { byte => bind(CPS.ByteWrite, byte) }
      case L3.CharToInt =>
        // (c << 3 | 0b110) >> 2 = c << 1 | 0b1
        bind(CPS.ShiftRight, rewrite(x), L.AtomL(2))
      case L3.Id =>
        // If `x` is a known function, `n` becomes an alias for it, so calls
        // through `n` are also compiled as direct worker calls.
        val aliasedKnownFuns = x match {
          case H.AtomN(xName) if knownFuns.contains(xName) => knownFuns.updated(n, knownFuns(xName))
          case _ => knownFuns
        }
        L.LetP(n, CPS.Id, Seq(rewrite(x)), transform(body)(aliasedKnownFuns))
      case L3.IntToChar =>
        // (2a + 1) << 2 + 2 = a << 3 | 0b110
        tempLetP(CPS.ShiftLeft, Seq(Left(x), Right(L.AtomL(2)))) { shifted =>
          bind(CPS.Add, shifted, L.AtomL(2))
        }
      case _ =>
        throw new Exception("Unreachable code (unary letP) " + prim.getClass)
    }
  }

  /**
   * Binds the result of `p(args)` to a fresh name and passes that name to
   * `mkBody`, which builds the rest of the tree. `Left` atoms are rewritten
   * with `rewrite` (literals get tagged). `Right` atoms are used as-is, which
   * is how untagged raw values and literal masks are passed.
   */
  private def tempLetP[A <: L.Tree](p: CPS, args: Seq[Either[H.Atom, L.Atom]])
                                   (mkBody: L.Atom => A): L.LetP = {
    val lArgs = args.map {
      case Left(hAtom) => rewrite(hAtom)
      case Right(lAtom) => lAtom
    }
    val tmpName = Symbol.fresh("x")
    val inner: A = mkBody(L.AtomN(tmpName))
    L.LetP(tmpName, p, lArgs, inner)
  }

  /** Translates an atom: names are kept, literals are encoded with their type tag. */
  private def rewrite(a: H.Atom): L.Atom = a match {
    case H.AtomN(n) => L.AtomN(n)
    case H.AtomL(IntLit(i)) => L.AtomL((i.toInt << 1) | 0x1)
    case H.AtomL(CharLit(c)) => L.AtomL((c.toInt << 3) | 0x6) // ...110
    case H.AtomL(BooleanLit(b)) => L.AtomL(if (b) 0x1a else 0x0a) // 1_1010 / 0_1010
    case H.AtomL(UnitLit) => L.AtomL(2) // 0010
  }
}