package l3

import scala.collection.mutable.{ Map => MutableMap }

/**
 * Optimizer for CPS trees, parameterized by the tree module it works on.
 *
 * The optimizer first applies shrinking rewrites (dead-code elimination,
 * common-subexpression elimination, inlining of definitions that are applied
 * exactly once), then interleaves size-bounded general inlining with further
 * shrinking. Primitive-specific knowledge is supplied by the abstract members
 * at the end of the class.
 *
 * @param treeModule the CPS tree module whose trees are optimized
 * @tparam T the type of that module; its names must be `Symbol`s
 */
abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
  (val treeModule: T) {
  import treeModule._

  /**
   * Optimizes `tree`: shrinks it to a fixed point, then repeatedly inlines
   * while keeping the result within 3/2 of the size of the shrunk tree.
   */
  protected def rewrite(tree: Tree): Tree = {
    val simplifiedTree = fixedPoint(tree)(shrink)
    val maxSize = size(simplifiedTree) * 3 / 2
    // NOTE(review): the second argument (8) is presumably an iteration bound
    // of `fixedPoint`, which is defined outside this file.
    fixedPoint(simplifiedTree, 8) { t => inline(t, maxSize) }
  }

  /** Number of times a name is applied and number of times it is used as a value. */
  private case class Count(applied: Int = 0, asValue: Int = 0)

  /**
   * Optimization state threaded through the rewrites.
   *
   * @param census  usage counts of the names reachable in the tree
   * @param aSubst  pending substitution on atoms
   * @param cSubst  pending substitution on continuation names
   * @param eInvEnv maps an already-computed (primitive, arguments) pair to the
   *                atom holding its value
   * @param cEnv    continuations that may be inlined, by name
   * @param fEnv    functions that may be inlined, by name
   */
  private case class State(
    census: Map[Name, Count],
    aSubst: Subst[Atom] = emptySubst,
    cSubst: Subst[Name] = emptySubst,
    eInvEnv: Map[(ValuePrimitive, Seq[Atom]), Atom] = Map.empty,
    cEnv: Map[Name, Cnt] = Map.empty,
    fEnv: Map[Name, Fun] = Map.empty) {

    /** True when `s` has no entry in the census. */
    def dead(s: Name): Boolean =
      ! census.contains(s)

    /** True when `s` is applied exactly once and never used as a value. */
    def appliedOnce(s: Name): Boolean =
      census.get(s).contains(Count(applied = 1, asValue = 0))

    // The target of a new substitution is resolved through the current one,
    // so the values stored in the substitution are always fully substituted.
    def withASubst(from: Atom, to: Atom): State =
      copy(aSubst = aSubst + (from -> aSubst(to)))
    def withASubst(from: Name, to: Atom): State =
      withASubst(AtomN(from), to)
    def withASubst(from: Name, to: Literal): State =
      withASubst(from, AtomL(to))
    def withASubst(from: Seq[Name], to: Seq[Atom]): State =
      copy(aSubst = aSubst ++ (from.map(AtomN) zip to.map(aSubst)))

    def withCSubst(from: Name, to: Name): State =
      copy(cSubst = cSubst + (from -> cSubst(to)))

    def withExp(atom: Atom, prim: ValuePrimitive, args: Seq[Atom]): State =
      copy(eInvEnv = eInvEnv + ((prim, args) -> atom))
    def withExp(name: Name, prim: ValuePrimitive, args: Seq[Atom]): State =
      withExp(AtomN(name), prim, args)

    def withCnts(cnts: Seq[Cnt]): State =
      copy(cEnv = cEnv ++ (cnts.map(_.name) zip cnts))
    def withFuns(funs: Seq[Fun]): State =
      copy(fEnv = fEnv ++ (funs.map(_.name) zip funs))
  }

  // Shrinking optimizations

  /** Performs one shrinking pass over `tree`, using a freshly computed census. */
  private def shrink(tree: Tree): Tree =
    shrink(tree, State(census(tree)))

  /**
   * Performs one shrinking pass over `tree` in state `s`.
   *
   * Rewrites performed:
   *  - bindings of dead names are dropped (primitives only when not impure);
   *  - a pure, stable primitive application already present in `s.eInvEnv`
   *    is replaced by the atom that holds its value;
   *  - functions and continuations that are applied exactly once are inlined
   *    at their single application site and their definition is dropped;
   *  - pending substitutions are applied to every atom and continuation name.
   */
  private def shrink(tree: Tree, s: State): Tree = {
    // `getOrElse` keeps atoms/names without an explicit entry unchanged.
    def substA(a: Atom): Atom = s.aSubst.getOrElse(a, a)
    def substC(c: Name): Name = s.cSubst.getOrElse(c, c)

    (tree: @unchecked) match {
      case LetP(name, prim, args, body) =>
        val pure = !impure(prim)
        if (s.dead(name) && pure) {
          shrink(body, s)
        } else {
          val args1 = args map substA
          // Only pure and stable primitives may be shared: re-using the
          // result of an impure one would delete its side effect.
          val reusable = pure && !unstable(prim)
          s.eInvEnv.get((prim, args1)) match {
            case Some(known) if reusable =>
              shrink(body, s.withASubst(name, known))
            case _ =>
              val s1 = if (reusable) s.withExp(name, prim, args1) else s
              LetP(name, prim, args1, shrink(body, s1))
          }
        }

      case LetC(cnts, body) =>
        val (inlinable, kept) = cnts
          .filterNot(c => s.dead(c.name))
          .partition(c => s.appliedOnce(c.name))
        // The inlinable continuations must be visible while shrinking the
        // kept siblings too, as their single application may be located
        // there. Their bodies are shrunk at the application site.
        val s1 = s.withCnts(inlinable)
        val kept1 = kept.map(c => Cnt(c.name, c.args, shrink(c.body, s1)))
        val body1 = shrink(body, s1)
        if (kept1.isEmpty) body1 else LetC(kept1, body1)

      case LetF(funs, body) =>
        val (inlinable, kept) = funs
          .filterNot(f => s.dead(f.name))
          .partition(f => s.appliedOnce(f.name))
        // Same reasoning as for continuations above.
        val s1 = s.withFuns(inlinable)
        val kept1 =
          kept.map(f => Fun(f.name, f.retC, f.args, shrink(f.body, s1)))
        val body1 = shrink(body, s1)
        if (kept1.isEmpty) body1 else LetF(kept1, body1)

      case AppC(cnt, args) =>
        val cnt1 = substC(cnt)
        s.cEnv.get(cnt1).filter(_ => s.appliedOnce(cnt1)) match {
          case Some(target) =>
            // `withASubst` resolves `args` through the current substitution.
            shrink(target.body, s.withASubst(target.args, args))
          case None =>
            AppC(cnt1, args map substA)
        }

      case AppF(fun, retC, args) =>
        substA(fun) match {
          case AtomN(f) if s.fEnv.contains(f) && s.appliedOnce(f) =>
            val callee = s.fEnv(f)
            val calleeState = s
              .withASubst(callee.args, args)
              .withCSubst(callee.retC, retC)
            shrink(callee.body, calleeState)
          case fun1 =>
            AppF(fun1, substC(retC), args map substA)
        }

      case If(cond, args, thenC, elseC) =>
        If(cond, args map substA, substC(thenC), substC(elseC))

      case Halt(a) =>
        Halt(substA(a))
    }
  }

  // (Non-shrinking) inlining

  /**
   * Inlines small functions and continuations into `tree`, with increasingly
   * generous size limits, shrinking after every step.
   *
   * @return the last tree produced whose size does not exceed `maxSize`
   */
  private def inline(tree: Tree, maxSize: Int): Tree = {
    /** Copies `tree`, giving fresh names to everything it binds. */
    def copyT(tree: Tree, subV: Subst[Atom], subC: Subst[Name]): Tree = {
      (tree: @unchecked) match {
        case LetP(name, prim, args, body) =>
          val name1 = name.copy()
          LetP(name1, prim, args map subV,
               copyT(body, subV + (AtomN(name) -> AtomN(name1)), subC))
        case LetC(cnts, body) =>
          val names = cnts map (_.name)
          val names1 = names map (_.copy())
          val subC1 = subC ++ (names zip names1)
          LetC(cnts map (copyC(_, subV, subC1)), copyT(body, subV, subC1))
        case LetF(funs, body) =>
          val names = funs map (_.name)
          val names1 = names map (_.copy())
          val subV1 = subV ++ ((names map AtomN) zip (names1 map AtomN))
          LetF(funs map (copyF(_, subV1, subC)), copyT(body, subV1, subC))
        case AppC(cnt, args) =>
          AppC(subC(cnt), args map subV)
        case AppF(fun, retC, args) =>
          AppF(subV(fun), subC(retC), args map subV)
        case If(cond, args, thenC, elseC) =>
          If(cond, args map subV, subC(thenC), subC(elseC))
        case Halt(arg) =>
          Halt(subV(arg))
      }
    }

    /** Copies `cnt` with fresh parameters; its new name is taken from `subC`. */
    def copyC(cnt: Cnt, subV: Subst[Atom], subC: Subst[Name]): Cnt = {
      val args1 = cnt.args map (_.copy())
      val subV1 = subV ++ ((cnt.args map AtomN) zip (args1 map AtomN))
      Cnt(subC(cnt.name), args1, copyT(cnt.body, subV1, subC))
    }

    /**
     * Copies `fun` with a fresh return continuation and fresh parameters;
     * its new name is taken from `subV`.
     */
    def copyF(fun: Fun, subV: Subst[Atom], subC: Subst[Name]): Fun = {
      val retC1 = fun.retC.copy()
      val subC1 = subC + (fun.retC -> retC1)
      val args1 = fun.args map (_.copy())
      val subV1 = subV ++ ((fun.args map AtomN) zip (args1 map AtomN))
      val AtomN(funName1) = subV(AtomN(fun.name))
      Fun(funName1, retC1, args1, copyT(fun.body, subV1, subC1))
    }

    // Size limits for inlined functions, one per inlining step.
    val fibonacci = Seq(1, 2, 3, 5, 8, 13)

    val trees = LazyList.iterate((0, tree), fibonacci.length){ case (i, current) =>
      val funLimit = fibonacci(i)
      val cntLimit = i

      def sameLen[A, B](formalArgs: Seq[A], actualArgs: Seq[B]): Boolean =
        formalArgs.length == actualArgs.length

      /**
       * Replaces applications of functions (resp. continuations) whose body
       * size is at most `funLimit` (resp. `cntLimit`) by a fresh copy of
       * that body. Definitions are left in place; the shrinking pass that
       * follows removes those that became dead.
       */
      def inlineT(tree: Tree)(implicit s: State): Tree = tree match {
        case LetP(name, prim, args, body) =>
          LetP(name, prim, args, inlineT(body))

        case LetC(cnts, body) =>
          // Continuation bodies are processed without their siblings in the
          // environment, so a recursive continuation is never unfolded into
          // itself during this step.
          val cnts1 = cnts.map(c => Cnt(c.name, c.args, inlineT(c.body)))
          val candidates = cnts1.filter(c => size(c.body) <= cntLimit)
          LetC(cnts1, inlineT(body)(s.withCnts(candidates)))

        case LetF(funs, body) =>
          // Same as above, for functions.
          val funs1 =
            funs.map(f => Fun(f.name, f.retC, f.args, inlineT(f.body)))
          val candidates = funs1.filter(f => size(f.body) <= funLimit)
          LetF(funs1, inlineT(body)(s.withFuns(candidates)))

        case AppC(cnt, args) =>
          s.cEnv.get(cnt) match {
            case Some(target) if sameLen(target.args, args) =>
              val noAtoms: Subst[Atom] = emptySubst
              val noCnts: Subst[Name] = emptySubst
              copyT(target.body,
                    noAtoms ++ ((target.args map AtomN) zip args),
                    noCnts)
            case _ =>
              tree
          }

        case AppF(AtomN(f), retC, args) =>
          s.fEnv.get(f) match {
            case Some(target) if sameLen(target.args, args) =>
              val noAtoms: Subst[Atom] = emptySubst
              val noCnts: Subst[Name] = emptySubst
              copyT(target.body,
                    noAtoms ++ ((target.args map AtomN) zip args),
                    noCnts + (target.retC -> retC))
            case _ =>
              tree
          }

        case _ =>
          // Application of a literal, If or Halt: nothing to inline.
          tree
      }

      (i + 1, fixedPoint(inlineT(current)(State(census(current))))(shrink))
    }

    trees
      .takeWhile { case (_, candidate) => size(candidate) <= maxSize }
      .lastOption
      .fold(tree)(_._2)
  }

  // Census computation

  /**
   * Counts how many times each name is applied and used as a value in
   * `tree`. The body of a function or continuation is only examined once
   * its name has been used, so definitions that are never used from the
   * rest of the tree contribute nothing.
   */
  private def census(tree: Tree): Map[Name, Count] = {
    val census = MutableMap[Name, Count]().withDefault(_ => Count())
    // Bodies of the definitions whose name has not been used so far.
    val rhs = MutableMap[Name, Tree]()

    def incAppUseN(name: Name): Unit = {
      val currCount = census(name)
      census(name) = currCount.copy(applied = currCount.applied + 1)
      rhs.remove(name).foreach(addToCensus)
    }
    def incAppUseA(atom: Atom): Unit =
      atom.asName.foreach(incAppUseN(_))

    def incValUseN(name: Name): Unit = {
      val currCount = census(name)
      census(name) = currCount.copy(asValue = currCount.asValue + 1)
      rhs.remove(name).foreach(addToCensus)
    }
    def incValUseA(atom: Atom): Unit =
      atom.asName.foreach(incValUseN(_))

    def addToCensus(tree: Tree): Unit = (tree: @unchecked) match {
      case LetP(_, _, args, body) =>
        args foreach incValUseA; addToCensus(body)
      case LetC(cnts, body) =>
        rhs ++= (cnts map { c => (c.name, c.body) }); addToCensus(body)
      case LetF(funs, body) =>
        rhs ++= (funs map { f => (f.name, f.body) }); addToCensus(body)
      case AppC(cnt, args) =>
        incAppUseN(cnt); args foreach incValUseA
      case AppF(fun, retC, args) =>
        incAppUseA(fun); incValUseN(retC); args foreach incValUseA
      case If(_, args, thenC, elseC) =>
        args foreach incValUseA; incValUseN(thenC); incValUseN(elseC)
      case Halt(arg) =>
        incValUseA(arg)
    }

    addToCensus(tree)
    census.toMap
  }

  /**
   * Size of `tree`: one per primitive binding and one per terminal node,
   * function and continuation bodies included.
   */
  private def size(tree: Tree): Int = (tree: @unchecked) match {
    case LetP(_, _, _, body) => size(body) + 1
    case LetC(cs, body) => (cs map { c => size(c.body) }).sum + size(body)
    case LetF(fs, body) => (fs map { f => size(f.body) }).sum + size(body)
    case AppC(_, _) | AppF(_, _, _) | If(_, _, _, _) | Halt(_) => 1
  }

  // Primitive-specific knowledge, provided by the concrete optimizers.
  // Only `impure` and `unstable` are consulted by the code of this class.

  /** Primitives whose binding is kept even when its result is dead. */
  protected val impure: ValuePrimitive => Boolean
  /** Primitives whose result is never shared between identical applications. */
  protected val unstable: ValuePrimitive => Boolean

  protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal]
  protected val blockTag: ValuePrimitive
  protected val blockLength: ValuePrimitive

  protected val identity: ValuePrimitive

  protected val leftNeutral: Set[(Literal, ValuePrimitive)]
  protected val rightNeutral: Set[(ValuePrimitive, Literal)]
  protected val leftAbsorbing: Set[(Literal, ValuePrimitive)]
  protected val rightAbsorbing: Set[(ValuePrimitive, Literal)]

  protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom]
  protected val sameArgReduceC: TestPrimitive => Boolean

  protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
                                            Literal]
  protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
                                            Boolean]
}

/** Optimizer for CPS trees of `SymbolicCPSTreeModule`, with L3 primitives. */
object CPSOptimizerHigh extends CPSOptimizer(SymbolicCPSTreeModule)
    with (SymbolicCPSTreeModule.Tree => SymbolicCPSTreeModule.Tree) {
  import treeModule._
  import L3Primitive._

  def apply(tree: Tree): Tree =
    rewrite(tree)

  import scala.language.implicitConversions
  private[this] implicit def l3IntToLit(i: L3Int): Literal = IntLit(i)
  private[this] implicit def intToLit(i: Int): Literal = IntLit(L3Int(i))

  protected val impure: ValuePrimitive => Boolean =
    Set(ByteRead, ByteWrite, BlockSet)

  // Same primitives as in CPSOptimizerLow.
  // NOTE(review): `BlockGet` is assumed to be provided by L3Primitive.
  protected val unstable: ValuePrimitive => Boolean = {
    case BlockAlloc(_) | BlockGet | ByteRead => true
    case _ => false
  }

  protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = {
    case BlockAlloc(t) => t
  }
  protected val blockTag: ValuePrimitive = BlockTag
  protected val blockLength: ValuePrimitive = BlockLength

  protected val identity: ValuePrimitive = Id

  // `(-1 << 1) >> 1` evaluates to -1.
  protected val leftNeutral: Set[(Literal, ValuePrimitive)] = Set(
    (IntLit(L3Int(0)), IntAdd),
    (IntLit(L3Int(1)), IntMul),
    (IntLit(L3Int((-1 << 1) >> 1)), IntBitwiseAnd),
    (IntLit(L3Int(0)), IntBitwiseOr),
    (IntLit(L3Int(0)), IntBitwiseXOr)
  )
  protected val rightNeutral: Set[(ValuePrimitive, Literal)] = Set(
    (IntAdd, IntLit(L3Int(0))),
    (IntSub, IntLit(L3Int(0))),
    (IntMul, IntLit(L3Int(1))),
    (IntDiv, IntLit(L3Int(1))),
    (IntShiftLeft, IntLit(L3Int(0))),
    (IntShiftRight, IntLit(L3Int(0))),
    (IntBitwiseAnd, IntLit(L3Int((-1 << 1) >> 1))),
    (IntBitwiseOr, IntLit(L3Int(0))),
    (IntBitwiseXOr, IntLit(L3Int(0)))
  )

  protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] = Set(
    (IntLit(L3Int(0)), IntMul),
    (IntLit(L3Int(0)), IntMod),
    (IntLit(L3Int(0)), IntBitwiseAnd),
    (IntLit(L3Int((-1 << 1) >> 1)), IntBitwiseOr),
    (IntLit(L3Int(0)), IntShiftLeft),
    (IntLit(L3Int(0)), IntShiftRight)
  )
  protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] = Set(
    (IntMul, IntLit(L3Int(0))),
    (IntBitwiseAnd, IntLit(L3Int(0))),
    (IntBitwiseOr, IntLit(L3Int((-1 << 1) >> 1)))
  )

  protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] = {
    case (IntBitwiseAnd | IntBitwiseOr, a) => a
    case (IntSub | IntBitwiseXOr | IntMod, _) => AtomL(IntLit(L3Int(0)))
    case (IntDiv, _) => AtomL(IntLit(L3Int(1)))
  }

  protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = {
    case IntLt => false
    case IntLe | Eq => true
  }

  // One case per primitive, so that `isDefinedAt` is false for primitives
  // that cannot be evaluated and for divisions by a literal zero.
  protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
                                            Literal] = {
    case (IntAdd, Seq(IntLit(x), IntLit(y))) => IntLit(x + y)
    case (IntSub, Seq(IntLit(x), IntLit(y))) => IntLit(x - y)
    case (IntMul, Seq(IntLit(x), IntLit(y))) => IntLit(x * y)
    case (IntDiv, Seq(IntLit(x), IntLit(y))) if y != L3Int(0) => IntLit(x / y)
    case (IntMod, Seq(IntLit(x), IntLit(y))) if y != L3Int(0) => IntLit(x % y)
    case (IntBitwiseAnd, Seq(IntLit(x), IntLit(y))) => IntLit(x & y)
    case (IntBitwiseOr, Seq(IntLit(x), IntLit(y))) => IntLit(x | y)
    case (IntBitwiseXOr, Seq(IntLit(x), IntLit(y))) => IntLit(x ^ y)
    case (IntShiftLeft, Seq(IntLit(x), IntLit(y))) => IntLit(x << y)
    case (IntShiftRight, Seq(IntLit(x), IntLit(y))) => IntLit(x >> y)
  }

  protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
                                            Boolean] = {
    case (IntLe, Seq(IntLit(x), IntLit(y))) => x <= y
    case (IntLt, Seq(IntLit(x), IntLit(y))) => x < y
    case (Eq, Seq(l1, l2)) => l1 == l2
  }
}

/** Optimizer for CPS trees of `SymbolicCPSTreeModuleLow`, with CPS primitives. */
object CPSOptimizerLow extends CPSOptimizer(SymbolicCPSTreeModuleLow)
    with (SymbolicCPSTreeModuleLow.LetF => SymbolicCPSTreeModuleLow.LetF) {
  import treeModule._
  import CPSValuePrimitive._
  import CPSTestPrimitive._

  /** Optimizes `tree`, re-wrapping the result in a `LetF` when necessary. */
  def apply(tree: LetF): LetF = rewrite(tree) match {
    case tree @ LetF(_, _) => tree
    case other => LetF(Seq(), other)
  }

  protected val impure: ValuePrimitive => Boolean =
    Set(BlockSet, ByteRead, ByteWrite)

  protected val unstable: ValuePrimitive => Boolean = {
    case BlockAlloc(_) | BlockGet | ByteRead => true
    case _ => false
  }

  protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = {
    case BlockAlloc(tag) => tag
  }
  protected val blockTag: ValuePrimitive = BlockTag
  protected val blockLength: ValuePrimitive = BlockLength

  protected val identity: ValuePrimitive = Id

  protected val leftNeutral: Set[(Literal, ValuePrimitive)] =
    Set((0, Add), (1, Mul), (~0, And), (0, Or), (0, XOr))
  protected val rightNeutral: Set[(ValuePrimitive, Literal)] =
    Set((Add, 0), (Sub, 0), (Mul, 1), (Div, 1),
        (ShiftLeft, 0), (ShiftRight, 0),
        (And, ~0), (Or, 0), (XOr, 0))

  protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] =
    Set((0, Mul), (0, Div),
        (0, ShiftLeft), (0, ShiftRight),
        (0, And), (~0, Or))
  protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] =
    Set((Mul, 0), (And, 0), (Or, ~0))

  protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] = {
    case (And | Or, a) => a
    case (Sub | Mod | XOr, _) => AtomL(0)
    case (Div, _) => AtomL(1)
  }

  protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = {
    case Le | Eq => true
    case Lt => false
  }

  protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
                                            Literal] = {
    case (Add, Seq(x, y)) => x + y
    case (Sub, Seq(x, y)) => x - y
    case (Mul, Seq(x, y)) => x * y
    case (Div, Seq(x, y)) if y.toInt != 0 => x / y
    case (Mod, Seq(x, y)) if y.toInt != 0 => x % y

    case (ShiftLeft, Seq(x, y)) => x << y
    case (ShiftRight, Seq(x, y)) => x >> y
    case (And, Seq(x, y)) => x & y
    case (Or, Seq(x, y)) => x | y
    case (XOr, Seq(x, y)) => x ^ y
  }

  protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
                                            Boolean] = {
    case (Lt, Seq(x, y)) => x < y
    case (Le, Seq(x, y)) => x <= y
    case (Eq, Seq(x, y)) => x == y
  }
}