package l3

import scala.collection.mutable.{ Map => MutableMap }

/**
  * Optimizer for CPS trees, shared by the high-level and low-level tree
  * modules. Concrete optimizers supply the primitive-specific knowledge
  * (purity, neutral/absorbing elements, constant evaluators) through the
  * abstract members declared at the end of this class.
  *
  * @tparam T the tree module whose trees are optimized; its names are symbols
  * @param treeModule the tree module instance whose members are imported
  */
abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
  (val treeModule: T) {
  import treeModule._

  /**
    * Optimizes `tree`: first shrinks it to a fixed point, then alternates
    * inlining and shrinking (at most 8 rounds), never letting the tree grow
    * beyond 1.5 times the size of the shrunk tree.
    */
  protected def rewrite(tree: Tree): Tree = {
    val simplifiedTree = fixedPoint(tree)(shrink)
    val maxSize = size(simplifiedTree) * 3 / 2
    fixedPoint(simplifiedTree, 8) { t => inline(t, maxSize) }
  }

  /** Number of times a name is applied and number of times it is used as a value. */
  private case class Count(applied: Int = 0, asValue: Int = 0)

  /**
    * Optimization state threaded through `shrink` and `inline`.
    *
    * @param census  usage counts of the names reachable in the tree
    * @param aSubst  atom substitution
    * @param cSubst  continuation-name substitution
    * @param eInvEnv maps a primitive application to the atom bound to its result
    * @param cEnv    continuations that are candidates for inlining, by name
    * @param fEnv    functions that are candidates for inlining, by name
    */
  private case class State(
    census: Map[Name, Count],
    aSubst: Subst[Atom] = emptySubst,
    cSubst: Subst[Name] = emptySubst,
    eInvEnv: Map[(ValuePrimitive, Seq[Atom]), Atom] = Map.empty,
    cEnv: Map[Name, Cnt] = Map.empty,
    fEnv: Map[Name, Fun] = Map.empty) {

    /** Inverse of `eInvEnv`; recomputed on every call. */
    def eEnv: Map[Atom, (ValuePrimitive, Seq[Atom])] =
      eInvEnv.map(_.swap)

    /** A name is dead when the census has no entry for it. */
    def dead(s: Name): Boolean =
      ! census.contains(s)

    /** True when the name is applied exactly once and never used as a value. */
    def appliedOnce(s: Name): Boolean =
      census.get(s).contains(Count(applied = 1, asValue = 0))

    // The target of a new substitution is resolved through the current
    // substitution first, so chains of substitutions are collapsed eagerly.
    def withASubst(from: Atom, to: Atom): State =
      copy(aSubst = aSubst + (from -> aSubst(to)))
    def withASubst(from: Name, to: Atom): State =
      withASubst(AtomN(from), to)
    def withASubst(from: Name, to: Literal): State =
      withASubst(from, AtomL(to))
    def withASubst(from: Seq[Name], to: Seq[Atom]): State =
      copy(aSubst = aSubst ++ (from.map(AtomN) zip to.map(aSubst)))

    def withCSubst(from: Name, to: Name): State =
      copy(cSubst = cSubst + (from -> cSubst(to)))

    def withExp(atom: Atom, prim: ValuePrimitive, args: Seq[Atom]): State =
      copy(eInvEnv = eInvEnv + ((prim, args) -> atom))
    def withExp(name: Name, prim: ValuePrimitive, args: Seq[Atom]): State =
      withExp(AtomN(name), prim, args)

    def withCnts(cnts: Seq[Cnt]): State =
      copy(cEnv = cEnv ++ (cnts.map(_.name) zip cnts))
    def withFuns(funs: Seq[Fun]): State =
      copy(fEnv = fEnv ++ (funs.map(_.name) zip funs))
  }

  // Shrinking optimizations

  /** Shrinks `tree` starting from a fresh state built from its census. */
  private def shrink(tree: Tree): Tree =
    shrink(tree, State(census(tree)))

  /** Applies the atom substitution of `s` to every argument. */
  private def replaceArgs(args: Seq[Atom], s: State): Seq[Atom] =
    args map { a => s.aSubst.getOrElse(a, a) }

  /** Applies the continuation substitution of `s` to `cnt`. */
  private def replaceCnt(cnt: Symbol, s: State): Symbol =
    s.cSubst.getOrElse(cnt, cnt)

  /**
    * One pass of the shrinking optimizations: dead-code elimination, common
    * sub-expression elimination, constant folding, neutral/absorbing element
    * simplification, and inlining of functions/continuations applied once.
    */
  private def shrink(tree: Tree, s: State): Tree = tree match {
    case AppC(oldCntName, args) => {
      val replacedArgs = replaceArgs(args, s)
      val cntName = replaceCnt(oldCntName, s)
      if (s.appliedOnce(cntName) && s.cEnv.contains(cntName)) {
        // Inline the continuation at its unique application site.
        val cnt = s.cEnv(cntName)
        val newState = s.withASubst(cnt.args, replacedArgs)
        shrink(cnt.body, newState)
      } else {
        AppC(cntName, replacedArgs)
      }
    }

    case AppF(oldFunAtom, oldRetC, args) => {
      val funAtom = s.aSubst.getOrElse(oldFunAtom, oldFunAtom)
      val replacedArgs = replaceArgs(args, s)
      val retC = replaceCnt(oldRetC, s)
      funAtom match {
        case AtomN(n) if s.fEnv.contains(n) && s.appliedOnce(n) =>
          // Inline the function at its unique application site.
          val fun = s.fEnv(n)
          val newState =
            s.withASubst(fun.args, replacedArgs).withCSubst(fun.retC, retC)
          shrink(fun.body, newState)
        case _ =>
          AppF(funAtom, retC, replacedArgs)
      }
    }

    case LetF(funs, body) => {
      val undeadFuns = funs.filterNot(f => s.dead(f.name))
      // Functions applied once are dropped here and inlined at their call site.
      val nonInlined = undeadFuns.filter(fun => !s.appliedOnce(fun.name))
      val newState = s.withFuns(undeadFuns)
      val shrunkFuns = nonInlined.map { f =>
        Fun(f.name, f.retC, f.args, shrink(f.body, newState))
      }
      val shrunkBody = shrink(body, newState)
      if (shrunkFuns.isEmpty) shrunkBody else LetF(shrunkFuns, shrunkBody)
    }

    case LetP(name, prim, args, body) => {
      val replacedArgs = replaceArgs(args, s)
      lazy val newState = s.withExp(name, prim, replacedArgs)
      lazy val shrunkBody = shrink(body, newState)
      lazy val noOp = LetP(name, prim, replacedArgs, shrunkBody)

      if (s.dead(name) && !impure(prim)) {
        // Dead code elimination. The binding disappears, so it must NOT be
        // recorded in the expression environment: otherwise a later identical
        // primitive application would be rewritten to this deleted name.
        shrink(body, s)
      } else if (!unstable(prim) && !impure(prim) &&
                   s.eInvEnv.contains((prim, replacedArgs))) {
        // Common sub-expression elimination. The binding disappears here too,
        // so the existing environment entry is kept as is.
        val preComputedAtom = s.eInvEnv((prim, replacedArgs))
        shrink(body, s.withASubst(name, preComputedAtom))
      } else if (prim == identity) {
        shrink(body, s.withASubst(name, replacedArgs(0)))
      } else replacedArgs match {
        // Constant folding
        case Seq(AtomL(l1), AtomL(l2))
            if vEvaluator.isDefinedAt((prim, Seq(l1, l2))) =>
          shrink(body, s.withASubst(name, vEvaluator((prim, Seq(l1, l2)))))

        // Same argument reduction
        case Seq(a1, a2) if a1 == a2 && sameArgReduce.isDefinedAt((prim, a1)) =>
          shrink(body, s.withASubst(name, sameArgReduce((prim, a1))))

        // Left Neutral
        case Seq(AtomL(l1), a2) if leftNeutral((l1, prim)) =>
          shrink(body, s.withASubst(name, a2))

        // Left Absorbing
        case Seq(AtomL(l1), a2) if leftAbsorbing((l1, prim)) =>
          shrink(body, s.withASubst(name, l1))

        // Right Neutral
        case Seq(a1, AtomL(l2)) if rightNeutral((prim, l2)) =>
          shrink(body, s.withASubst(name, a1))

        // Right Absorbing
        case Seq(a1, AtomL(l2)) if rightAbsorbing((prim, l2)) =>
          shrink(body, s.withASubst(name, l2))

        // Tag / length of a block whose allocation is known
        case Seq(a1) =>
          s.eEnv.get(a1) match {
            case Some((allocPrim, allocArgs))
                if blockAllocTag.isDefinedAt(allocPrim) =>
              if (prim == blockTag) {
                shrink(body, s.withASubst(name, AtomL(blockAllocTag(allocPrim))))
              } else if (prim == blockLength) {
                // The first argument of the allocation is used as the length.
                allocArgs.headOption.fold[Tree](noOp) { lengthAtom =>
                  shrink(body, s.withASubst(name, lengthAtom))
                }
              } else {
                noOp
              }
            case _ =>
              noOp
          }

        case _ =>
          noOp
      }
    }

    case LetC(cnts, body) => {
      val undeadConts = cnts.filterNot(cnt => s.dead(cnt.name))
      val undeadShrunkConts = undeadConts.map { cnt =>
        Cnt(cnt.name, cnt.args, shrink(cnt.body, s))
      }
      // Continuations applied once are dropped here and inlined at their
      // application site.
      val nonInlinedConts =
        undeadShrunkConts.filterNot(cnt => s.appliedOnce(cnt.name))
      val newBody = shrink(body, s.withCnts(undeadShrunkConts))
      if (nonInlinedConts.isEmpty) newBody else LetC(nonInlinedConts, newBody)
    }

    case If(cond, args, thenC, elseC) =>
      val newArgs = replaceArgs(args, s)
      // Jump directly to the selected branch; the continuation substitution
      // is applied exactly as in the non-folded case below.
      def getApp(b: Boolean): AppC =
        AppC(replaceCnt(if (b) thenC else elseC, s), Seq())

      val lits = newArgs.flatMap(_.asLiteral)
      val isAllLit = lits.length == newArgs.length

      if (isAllLit && cEvaluator.isDefinedAt((cond, lits))) {
        // Constant folding
        getApp(cEvaluator((cond, lits)))
      } else newArgs match {
        // Same argument reduction
        case Seq(x, y) if x == y =>
          getApp(sameArgReduceC(cond))
        case _ =>
          If(cond, newArgs, replaceCnt(thenC, s), replaceCnt(elseC, s))
      }

    case Halt(a) =>
      Halt(s.aSubst.getOrElse(a, a))
  }

  // (Non-shrinking) inlining

  /**
    * Inlines functions and continuations with increasingly generous size
    * limits, shrinking after each round, and returns the last tree whose size
    * does not exceed `maxSize`. If no candidate fits, `tree` is returned
    * unchanged.
    */
  private def inline(tree: Tree, maxSize: Int): Tree = {
    // Copies a tree, giving fresh names to everything it binds.
    def copyT(tree: Tree, subV: Subst[Atom], subC: Subst[Name]): Tree = {
      (tree: @unchecked) match {
        case LetP(name, prim, args, body) =>
          val name1 = name.copy()
          LetP(name1, prim, args map subV,
               copyT(body, subV + (AtomN(name) -> AtomN(name1)), subC))
        case LetC(cnts, body) =>
          val names = cnts map (_.name)
          val names1 = names map (_.copy())
          val subC1 = subC ++ (names zip names1)
          LetC(cnts map (copyC(_, subV, subC1)), copyT(body, subV, subC1))
        case LetF(funs, body) =>
          val names = funs map (_.name)
          val names1 = names map (_.copy())
          val subV1 = subV ++ ((names map AtomN) zip (names1 map AtomN))
          LetF(funs map (copyF(_, subV1, subC)), copyT(body, subV1, subC))
        case AppC(cnt, args) =>
          AppC(subC(cnt), args map subV)
        case AppF(fun, retC, args) =>
          AppF(subV(fun), subC(retC), args map subV)
        case If(cond, args, thenC, elseC) =>
          If(cond, args map subV, subC(thenC), subC(elseC))
        case Halt(arg) =>
          Halt(subV(arg))
      }
    }

    def copyC(cnt: Cnt, subV: Subst[Atom], subC: Subst[Name]): Cnt = {
      val args1 = cnt.args map (_.copy())
      val subV1 = subV ++ ((cnt.args map AtomN) zip (args1 map AtomN))
      Cnt(subC(cnt.name), args1, copyT(cnt.body, subV1, subC))
    }

    def copyF(fun: Fun, subV: Subst[Atom], subC: Subst[Name]): Fun = {
      val retC1 = fun.retC.copy()
      val subC1 = subC + (fun.retC -> retC1)
      val args1 = fun.args map (_.copy())
      val subV1 = subV ++ ((fun.args map AtomN) zip (args1 map AtomN))
      val AtomN(funName1) = subV(AtomN(fun.name))
      Fun(funName1, retC1, args1, copyT(fun.body, subV1, subC1))
    }

    // Size limits for inlined functions, one per round.
    val fibonacci = Seq(1, 2, 3, 5, 8, 13)

    val trees = LazyList.iterate((0, tree), fibonacci.length) { case (i, tree) =>
      val funLimit = fibonacci(i)
      val cntLimit = i

      def sameLen[T, U](formalArgs: Seq[T], actualArgs: Seq[U]): Boolean =
        formalArgs.length == actualArgs.length

      def inlineT(tree: Tree)(implicit s: State): Tree = tree match {
        case AppC(oldCntName, oldArgs) => {
          val cntName = replaceCnt(oldCntName, s)
          val args = replaceArgs(oldArgs, s)
          s.cEnv.get(cntName) match {
            // Same arity check as for functions below: a mismatched
            // application is left alone rather than partially substituted.
            case Some(oldCnt) if sameLen(oldCnt.args, args) =>
              val cnt = copyC(oldCnt, s.aSubst, s.cSubst)
              val newState = s.withASubst(cnt.args, args)
              inlineT(cnt.body)(newState)
            case _ =>
              AppC(cntName, args)
          }
        }

        case AppF(oldFunName, oldRetC, oldArgs) => {
          val funName = s.aSubst.getOrElse(oldFunName, oldFunName)
          val retC = replaceCnt(oldRetC, s)
          val args = replaceArgs(oldArgs, s)
          funName match {
            case AtomN(n) if s.fEnv.contains(n) => {
              val oldFun = s.fEnv(n)
              val notUsedAsValue =
                s.census.get(oldFun.name).fold(false)(_.asValue == 0)
              val nonRecursive =
                notUsedAsValue && !census(oldFun.body).contains(oldFun.name)
              if (nonRecursive && sameLen(oldFun.args, args)) {
                val fun = copyF(oldFun, s.aSubst, s.cSubst)
                val newState =
                  s.withASubst(fun.args, args).withCSubst(fun.retC, retC)
                inlineT(fun.body)(newState)
              } else {
                AppF(funName, retC, args)
              }
            }
            case _ =>
              AppF(funName, retC, args)
          }
        }

        case Halt(arg) =>
          Halt(s.aSubst.getOrElse(arg, arg))

        case If(cond, args, thenC, elseC) =>
          If(cond, replaceArgs(args, s), replaceCnt(thenC, s), replaceCnt(elseC, s))

        case LetP(name, prim, args, body) =>
          LetP(name, prim, replaceArgs(args, s), inlineT(body))

        case LetC(oldCnts, body) => {
          val cnts = oldCnts map (c => Cnt(c.name, c.args, inlineT(c.body)))
          val inlinedCnts = cnts.filter(c => size(c.body) <= cntLimit)
          val newState = s.withCnts(inlinedCnts)
          LetC(cnts, inlineT(body)(newState))
        }

        case LetF(oldFuns, body) => {
          val funs =
            oldFuns map (f => Fun(f.name, f.retC, f.args, inlineT(f.body)))
          val inlinedFuns = funs.filter(f => size(f.body) <= funLimit)
          val newState = s.withFuns(inlinedFuns)
          LetF(funs, inlineT(body)(newState))
        }
      }

      (i + 1, fixedPoint(inlineT(tree)(State(census(tree))))(shrink))
    }

    // `lastOption` rather than `last`: if even the input exceeds `maxSize`,
    // hand it back unchanged instead of throwing.
    trees
      .takeWhile { case (_, candidate) => size(candidate) <= maxSize }
      .lastOption
      .fold(tree)(_._2)
  }

  // Census computation

  /**
    * Counts, for every name reachable from the root of `tree`, how many times
    * it is applied and how many times it is used as a value. Bodies of
    * functions and continuations are only visited once their name is used.
    */
  private def census(tree: Tree): Map[Name, Count] = {
    val census = MutableMap[Name, Count]().withDefault(_ => Count())
    // Bodies of functions/continuations not yet known to be used.
    val rhs = MutableMap[Name, Tree]()

    def incAppUseN(name: Name): Unit = {
      val currCount = census(name)
      census(name) = currCount.copy(applied = currCount.applied + 1)
      rhs.remove(name).foreach(addToCensus)
    }

    def incAppUseA(atom: Atom): Unit =
      atom.asName.foreach(incAppUseN(_))

    def incValUseN(name: Name): Unit = {
      val currCount = census(name)
      census(name) = currCount.copy(asValue = currCount.asValue + 1)
      rhs.remove(name).foreach(addToCensus)
    }

    def incValUseA(atom: Atom): Unit =
      atom.asName.foreach(incValUseN(_))

    def addToCensus(tree: Tree): Unit = (tree: @unchecked) match {
      case LetP(_, _, args, body) =>
        args foreach incValUseA; addToCensus(body)
      case LetC(cnts, body) =>
        rhs ++= (cnts map { c => (c.name, c.body) }); addToCensus(body)
      case LetF(funs, body) =>
        rhs ++= (funs map { f => (f.name, f.body) }); addToCensus(body)
      case AppC(cnt, args) =>
        incAppUseN(cnt); args foreach incValUseA
      case AppF(fun, retC, args) =>
        incAppUseA(fun); incValUseN(retC); args foreach incValUseA
      case If(_, args, thenC, elseC) =>
        args foreach incValUseA; incValUseN(thenC); incValUseN(elseC)
      case Halt(arg) =>
        incValUseA(arg)
    }

    addToCensus(tree)
    census.toMap
  }

  /**
    * Size of a tree: each `LetP` and each leaf node counts for one, `LetC`
    * and `LetF` contribute only the sizes of their bodies.
    */
  private def size(tree: Tree): Int = (tree: @unchecked) match {
    case LetP(_, _, _, body) => size(body) + 1
    case LetC(cs, body) => (cs map { c => size(c.body) }).sum + size(body)
    case LetF(fs, body) => (fs map { f => size(f.body) }).sum + size(body)
    case AppC(_, _) | AppF(_, _, _) | If(_, _, _, _) | Halt(_) => 1
  }

  /** Primitives that must not be removed even when their result is unused. */
  protected val impure: ValuePrimitive => Boolean
  /** Primitives excluded from common sub-expression elimination. */
  protected val unstable: ValuePrimitive => Boolean

  /** Tag of the block allocated by a block-allocation primitive. */
  protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal]
  protected val blockTag: ValuePrimitive
  protected val blockLength: ValuePrimitive

  protected val identity: ValuePrimitive

  protected val leftNeutral: Set[(Literal, ValuePrimitive)]
  protected val rightNeutral: Set[(ValuePrimitive, Literal)]
  protected val leftAbsorbing: Set[(Literal, ValuePrimitive)]
  protected val rightAbsorbing: Set[(ValuePrimitive, Literal)]

  /** Result of applying a binary primitive to two identical arguments. */
  protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom]
  /** Outcome of a binary test applied to two identical arguments. */
  protected val sameArgReduceC: TestPrimitive => Boolean

  /** Compile-time evaluators for value and test primitives on literals. */
  protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
                                            Literal]
  protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
                                            Boolean]
}

/** Optimizer instance for `SymbolicCPSTreeModule` trees. */
object CPSOptimizerHigh extends CPSOptimizer(SymbolicCPSTreeModule)
    with (SymbolicCPSTreeModule.Tree => SymbolicCPSTreeModule.Tree) {
  import treeModule._
  import L3Primitive._

  def apply(tree: Tree): Tree =
    rewrite(tree)

  import scala.language.implicitConversions
  private[this] implicit def l3IntToLit(i: L3Int): Literal = IntLit(i)
  private[this] implicit def intToLit(i: Int): Literal = IntLit(L3Int(i))

  protected val impure: ValuePrimitive => Boolean =
    Set(ByteRead, ByteWrite, BlockSet)

  protected val unstable: ValuePrimitive => Boolean = {
    case BlockAlloc(_) | BlockGet | ByteRead => true
    case _ => false
  }

  protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = {
    case BlockAlloc(t) => t
  }
  protected val blockTag: ValuePrimitive = BlockTag
  protected val blockLength: ValuePrimitive = BlockLength

  protected val identity: ValuePrimitive = Id

  // NOTE: (-1 << 1) >> 1 evaluates to -1, i.e. the integer with all bits set.
  protected val leftNeutral: Set[(Literal, ValuePrimitive)] = Set(
    (IntLit(L3Int(0)), IntAdd),
    (IntLit(L3Int(1)), IntMul),
    (IntLit(L3Int((-1 << 1) >> 1)), IntBitwiseAnd),
    (IntLit(L3Int(0)), IntBitwiseOr),
    (IntLit(L3Int(0)), IntBitwiseXOr)
  )
  protected val rightNeutral: Set[(ValuePrimitive, Literal)] = Set(
    (IntAdd, IntLit(L3Int(0))),
    (IntSub, IntLit(L3Int(0))),
    (IntMul, IntLit(L3Int(1))),
    (IntDiv, IntLit(L3Int(1))),
    (IntShiftLeft, IntLit(L3Int(0))),
    (IntShiftRight, IntLit(L3Int(0))),
    (IntBitwiseAnd, IntLit(L3Int((-1 << 1) >> 1))),
    (IntBitwiseOr, IntLit(L3Int(0))),
    (IntBitwiseXOr, IntLit(L3Int(0)))
  )

  protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] = Set(
    (IntLit(L3Int(0)), IntMul),
    (IntLit(L3Int(0)), IntMod),
    (IntLit(L3Int(0)), IntBitwiseAnd),
    (IntLit(L3Int((-1 << 1) >> 1)), IntBitwiseOr),
    (IntLit(L3Int(0)), IntShiftLeft),
    (IntLit(L3Int(0)), IntShiftRight)
  )
  protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] = Set(
    (IntMul, IntLit(L3Int(0))),
    (IntBitwiseAnd, IntLit(L3Int(0))),
    (IntBitwiseOr, IntLit(L3Int((-1 << 1) >> 1)))
  )

  protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] = {
    case (IntBitwiseAnd | IntBitwiseOr, a) => a
    case (IntSub | IntBitwiseXOr | IntMod, _) => AtomL(IntLit(L3Int(0)))
    case (IntDiv, _) => AtomL(IntLit(L3Int(1)))
  }

  protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = {
    case IntLt => false
    case IntLe | Eq => true
  }

  // One case per primitive, so that `isDefinedAt` is accurate: primitives
  // that are not listed, and division/modulo by a literal zero, are simply
  // not folded instead of making the optimizer throw.
  protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
                                            Literal] = {
    case (IntAdd, Seq(IntLit(x), IntLit(y))) => IntLit(x + y)
    case (IntSub, Seq(IntLit(x), IntLit(y))) => IntLit(x - y)
    case (IntMul, Seq(IntLit(x), IntLit(y))) => IntLit(x * y)
    case (IntDiv, Seq(IntLit(x), IntLit(y))) if y != L3Int(0) => IntLit(x / y)
    case (IntMod, Seq(IntLit(x), IntLit(y))) if y != L3Int(0) => IntLit(x % y)
    case (IntBitwiseAnd, Seq(IntLit(x), IntLit(y))) => IntLit(x & y)
    case (IntBitwiseOr, Seq(IntLit(x), IntLit(y))) => IntLit(x | y)
    case (IntBitwiseXOr, Seq(IntLit(x), IntLit(y))) => IntLit(x ^ y)
    case (IntShiftLeft, Seq(IntLit(x), IntLit(y))) => IntLit(x << y)
    case (IntShiftRight, Seq(IntLit(x), IntLit(y))) => IntLit(x >> y)
  }

  protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
                                            Boolean] = {
    case (IntLe, Seq(IntLit(x), IntLit(y))) => x <= y
    case (IntLt, Seq(IntLit(x), IntLit(y))) => x < y
    case (Eq, Seq(l1, l2)) => l1 == l2
  }
}

/** Optimizer instance for `SymbolicCPSTreeModuleLow` trees. */
object CPSOptimizerLow extends CPSOptimizer(SymbolicCPSTreeModuleLow)
    with (SymbolicCPSTreeModuleLow.LetF => SymbolicCPSTreeModuleLow.LetF) {
  import treeModule._
  import CPSValuePrimitive._
  import CPSTestPrimitive._

  /** Optimizes `tree`, re-wrapping the result in a `LetF` if needed. */
  def apply(tree: LetF): LetF = rewrite(tree) match {
    case tree @ LetF(_, _) => tree
    case other => LetF(Seq(), other)
  }

  protected val impure: ValuePrimitive => Boolean =
    Set(BlockSet, ByteRead, ByteWrite)

  protected val unstable: ValuePrimitive => Boolean = {
    case BlockAlloc(_) | BlockGet | ByteRead => true
    case _ => false
  }

  protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = {
    case BlockAlloc(tag) => tag
  }
  protected val blockTag: ValuePrimitive = BlockTag
  protected val blockLength: ValuePrimitive = BlockLength

  protected val identity: ValuePrimitive = Id

  protected val leftNeutral: Set[(Literal, ValuePrimitive)] =
    Set((0, Add), (1, Mul), (~0, And), (0, Or), (0, XOr))
  protected val rightNeutral: Set[(ValuePrimitive, Literal)] =
    Set((Add, 0), (Sub, 0), (Mul, 1), (Div, 1),
        (ShiftLeft, 0), (ShiftRight, 0),
        (And, ~0), (Or, 0), (XOr, 0))

  protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] =
    Set((0, Mul), (0, Div),
        (0, ShiftLeft), (0, ShiftRight),
        (0, And), (~0, Or))
  protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] =
    Set((Mul, 0), (And, 0), (Or, ~0))

  protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] = {
    case (And | Or, a) => a
    case (Sub | Mod | XOr, _) => AtomL(0)
    case (Div, _) => AtomL(1)
  }

  protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = {
    case Le | Eq => true
    case Lt => false
  }

  protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
                                            Literal] = {
    case (Add, Seq(x, y)) => x + y
    case (Sub, Seq(x, y)) => x - y
    case (Mul, Seq(x, y)) => x * y
    // Division and modulo by zero are left for run time.
    case (Div, Seq(x, y)) if y.toInt != 0 => x / y
    case (Mod, Seq(x, y)) if y.toInt != 0 => x % y

    case (ShiftLeft, Seq(x, y)) => x << y
    case (ShiftRight, Seq(x, y)) => x >> y
    case (And, Seq(x, y)) => x & y
    case (Or, Seq(x, y)) => x | y
    case (XOr, Seq(x, y)) => x ^ y
  }

  protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
                                            Boolean] = {
    case (Lt, Seq(x, y)) => x < y
    case (Le, Seq(x, y)) => x <= y
    case (Eq, Seq(x, y)) => x == y
  }
}