Commit f868bab7 authored by Sapphie's avatar Sapphie
Browse files

Import optimisation skeleton

parent 146e5f46
package l3
import scala.collection.mutable.{ Map => MutableMap }
abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
(val treeModule: T) {
import treeModule._
protected def rewrite(tree: Tree): Tree = {
val simplifiedTree = fixedPoint(tree)(shrink)
val maxSize = size(simplifiedTree) * 3 / 2
fixedPoint(simplifiedTree, 8) { t => inline(t, maxSize) }
}
private case class Count(applied: Int = 0, asValue: Int = 0)
private case class State(
census: Map[Name, Count],
aSubst: Subst[Atom] = emptySubst,
cSubst: Subst[Name] = emptySubst,
eInvEnv: Map[(ValuePrimitive, Seq[Atom]), Atom] = Map.empty,
cEnv: Map[Name, Cnt] = Map.empty,
fEnv: Map[Name, Fun] = Map.empty) {
def dead(s: Name): Boolean =
! census.contains(s)
def appliedOnce(s: Name): Boolean =
census.get(s).contains(Count(applied = 1, asValue = 0))
def withASubst(from: Atom, to: Atom): State =
copy(aSubst = aSubst + (from -> aSubst(to)))
def withASubst(from: Name, to: Atom): State =
withASubst(AtomN(from), to)
def withASubst(from: Name, to: Literal): State =
withASubst(from, AtomL(to))
def withASubst(from: Seq[Name], to: Seq[Atom]): State =
copy(aSubst = aSubst ++ (from.map(AtomN) zip to.map(aSubst)))
def withCSubst(from: Name, to: Name): State =
copy(cSubst = cSubst + (from -> cSubst(to)))
def withExp(atom: Atom, prim: ValuePrimitive, args: Seq[Atom]): State =
copy(eInvEnv = eInvEnv + ((prim, args) -> atom))
def withExp(name: Name, prim: ValuePrimitive, args: Seq[Atom]): State =
withExp(AtomN(name), prim, args)
def withCnts(cnts: Seq[Cnt]): State =
copy(cEnv = cEnv ++ (cnts.map(_.name) zip cnts))
def withFuns(funs: Seq[Fun]): State =
copy(fEnv = fEnv ++ (funs.map(_.name) zip funs))
}
// Shrinking optimizations
private def shrink(tree: Tree): Tree =
shrink(tree, State(census(tree)))
private def shrink(tree: Tree, s: State): Tree = ???
// (Non-shrinking) inlining
private def inline(tree: Tree, maxSize: Int): Tree = {
def copyT(tree: Tree, subV: Subst[Atom], subC: Subst[Name]): Tree = {
(tree: @unchecked) match {
case LetP(name, prim, args, body) =>
val name1 = name.copy()
LetP(name1, prim, args map subV,
copyT(body, subV + (AtomN(name) -> AtomN(name1)), subC))
case LetC(cnts, body) =>
val names = cnts map (_.name)
val names1 = names map (_.copy())
val subC1 = subC ++ (names zip names1)
LetC(cnts map (copyC(_, subV, subC1)), copyT(body, subV, subC1))
case LetF(funs, body) =>
val names = funs map (_.name)
val names1 = names map (_.copy())
val subV1 = subV ++ ((names map AtomN) zip (names1 map AtomN))
LetF(funs map (copyF(_, subV1, subC)), copyT(body, subV1, subC))
case AppC(cnt, args) =>
AppC(subC(cnt), args map subV)
case AppF(fun, retC, args) =>
AppF(subV(fun), subC(retC), args map subV)
case If(cond, args, thenC, elseC) =>
If(cond, args map subV, subC(thenC), subC(elseC))
case Halt(arg) =>
Halt(subV(arg))
}
}
def copyC(cnt: Cnt, subV: Subst[Atom], subC: Subst[Name]): Cnt = {
val args1 = cnt.args map (_.copy())
val subV1 = subV ++ ((cnt.args map AtomN) zip (args1 map AtomN))
Cnt(subC(cnt.name), args1, copyT(cnt.body, subV1, subC))
}
def copyF(fun: Fun, subV: Subst[Atom], subC: Subst[Name]): Fun = {
val retC1 = fun.retC.copy()
val subC1 = subC + (fun.retC -> retC1)
val args1 = fun.args map (_.copy())
val subV1 = subV ++ ((fun.args map AtomN) zip (args1 map AtomN))
val AtomN(funName1) = subV(AtomN(fun.name))
Fun(funName1, retC1, args1, copyT(fun.body, subV1, subC1))
}
val fibonacci = Seq(1, 2, 3, 5, 8, 13)
val trees = LazyList.iterate((0, tree), fibonacci.length){ case (i, tree) =>
val funLimit = fibonacci(i)
val cntLimit = i
def sameLen[T,U](formalArgs: Seq[T], actualArgs: Seq[U]): Boolean =
formalArgs.length == actualArgs.length
def inlineT(tree: Tree)(implicit s: State): Tree = ???
(i + 1, fixedPoint(inlineT(tree)(State(census(tree))))(shrink))
}
trees.takeWhile{ case (_, tree) => size(tree) <= maxSize }.last._2
}
// Census computation
private def census(tree: Tree): Map[Name, Count] = {
val census = MutableMap[Name, Count]().withDefault(_ => Count())
val rhs = MutableMap[Name, Tree]()
def incAppUseN(name: Name): Unit = {
val currCount = census(name)
census(name) = currCount.copy(applied = currCount.applied + 1)
rhs.remove(name).foreach(addToCensus)
}
def incAppUseA(atom: Atom): Unit =
atom.asName.foreach(incAppUseN(_))
def incValUseN(name: Name): Unit = {
val currCount = census(name)
census(name) = currCount.copy(asValue = currCount.asValue + 1)
rhs.remove(name).foreach(addToCensus)
}
def incValUseA(atom: Atom): Unit =
atom.asName.foreach(incValUseN(_))
def addToCensus(tree: Tree): Unit = (tree: @unchecked) match {
case LetP(_, _, args, body) =>
args foreach incValUseA; addToCensus(body)
case LetC(cnts, body) =>
rhs ++= (cnts map { c => (c.name, c.body) }); addToCensus(body)
case LetF(funs, body) =>
rhs ++= (funs map { f => (f.name, f.body) }); addToCensus(body)
case AppC(cnt, args) =>
incAppUseN(cnt); args foreach incValUseA
case AppF(fun, retC, args) =>
incAppUseA(fun); incValUseN(retC); args foreach incValUseA
case If(_, args, thenC, elseC) =>
args foreach incValUseA; incValUseN(thenC); incValUseN(elseC)
case Halt(arg) =>
incValUseA(arg)
}
addToCensus(tree)
census.toMap
}
private def size(tree: Tree): Int = (tree: @unchecked) match {
case LetP(_, _, _, body) => size(body) + 1
case LetC(cs, body) => (cs map { c => size(c.body) }).sum + size(body)
case LetF(fs, body) => (fs map { f => size(f.body) }).sum + size(body)
case AppC(_, _) | AppF(_, _, _) | If(_, _, _, _) | Halt(_) => 1
}
protected val impure: ValuePrimitive => Boolean
protected val unstable: ValuePrimitive => Boolean
protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal]
protected val blockTag: ValuePrimitive
protected val blockLength: ValuePrimitive
protected val identity: ValuePrimitive
protected val leftNeutral: Set[(Literal, ValuePrimitive)]
protected val rightNeutral: Set[(ValuePrimitive, Literal)]
protected val leftAbsorbing: Set[(Literal, ValuePrimitive)]
protected val rightAbsorbing: Set[(ValuePrimitive, Literal)]
protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom]
protected val sameArgReduceC: TestPrimitive => Boolean
protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
Literal]
protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
Boolean]
}
object CPSOptimizerHigh extends CPSOptimizer(SymbolicCPSTreeModule)
with (SymbolicCPSTreeModule.Tree => SymbolicCPSTreeModule.Tree) {
import treeModule._
import L3Primitive._
def apply(tree: Tree): Tree =
rewrite(tree)
import scala.language.implicitConversions
private[this] implicit def l3IntToLit(i: L3Int): Literal = IntLit(i)
private[this] implicit def intToLit(i: Int): Literal = IntLit(L3Int(i))
protected val impure: ValuePrimitive => Boolean = ???
protected val unstable: ValuePrimitive => Boolean = ???
protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = ???
protected val blockTag: ValuePrimitive = ???
protected val blockLength: ValuePrimitive = ???
protected val identity: ValuePrimitive = ???
protected val leftNeutral: Set[(Literal, ValuePrimitive)] = ???
protected val rightNeutral: Set[(ValuePrimitive, Literal)] = ???
protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] = ???
protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] = ???
protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] =
???
protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = ???
protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
Literal] = ???
protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
Boolean] = ???
}
object CPSOptimizerLow extends CPSOptimizer(SymbolicCPSTreeModuleLow)
with (SymbolicCPSTreeModuleLow.LetF => SymbolicCPSTreeModuleLow.LetF) {
import treeModule._
import CPSValuePrimitive._
import CPSTestPrimitive._
def apply(tree: LetF): LetF = rewrite(tree) match {
case tree @ LetF(_, _) => tree
case other => LetF(Seq(), other)
}
protected val impure: ValuePrimitive => Boolean =
Set(BlockSet, ByteRead, ByteWrite)
protected val unstable: ValuePrimitive => Boolean = {
case BlockAlloc(_) | BlockGet | ByteRead => true
case _ => false
}
protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = {
case BlockAlloc(tag) => tag
}
protected val blockTag: ValuePrimitive = BlockTag
protected val blockLength: ValuePrimitive = BlockLength
protected val identity: ValuePrimitive = Id
protected val leftNeutral: Set[(Literal, ValuePrimitive)] =
Set((0, Add), (1, Mul), (~0, And), (0, Or), (0, XOr))
protected val rightNeutral: Set[(ValuePrimitive, Literal)] =
Set((Add, 0), (Sub, 0), (Mul, 1), (Div, 1),
(ShiftLeft, 0), (ShiftRight, 0),
(And, ~0), (Or, 0), (XOr, 0))
protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] =
Set((0, Mul), (0, Div),
(0, ShiftLeft), (0, ShiftRight),
(0, And), (~0, Or))
protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] =
Set((Mul, 0), (And, 0), (Or, ~0))
protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] = {
case (And | Or, a) => a
case (Sub | Mod | XOr, _) => AtomL(0)
case (Div, _) => AtomL(1)
}
protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = {
case Le | Eq => true
case Lt => false
}
protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
Literal] = {
case (Add, Seq(x, y)) => x + y
case (Sub, Seq(x, y)) => x - y
case (Mul, Seq(x, y)) => x * y
case (Div, Seq(x, y)) if y.toInt != 0 => x / y
case (Mod, Seq(x, y)) if y.toInt != 0 => x % y
case (ShiftLeft, Seq(x, y)) => x << y
case (ShiftRight, Seq(x, y)) => x >> y
case (And, Seq(x, y)) => x & y
case (Or, Seq(x, y)) => x | y
case (XOr, Seq(x, y)) => x ^ y
}
protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
Boolean] = {
case (Lt, Seq(x, y)) => x < y
case (Le, Seq(x, y)) => x <= y
case (Eq, Seq(x, y)) => x == y
}
}
package l3;
import scala.collection.mutable.{ Map => MutableMap }
import SymbolicCPSTreeModuleLow._
final class Statistics {
private[this] var funCount = 0
private[this] var cntCount = 0
private[this] val nodes = MutableMap[Class[_ <: Tree], Int]()
private[this] val tPrims = MutableMap[Class[_ <: l3.CPSTestPrimitive], Int]()
private[this] val vPrims = MutableMap[Class[_ <: l3.CPSValuePrimitive], Int]()
private[this] def inc[K](m: MutableMap[K, Int], k: K): Unit =
m.put(k, m.getOrElse(k, 0) + 1)
def nodeCount(cls: Class[_ <: Tree]): Int =
nodes.getOrElse(cls, 0)
def testPrimitiveCount(cls: Class[_ <: CPSTestPrimitive]): Int =
tPrims.getOrElse(cls, 0)
def valuePrimitiveCount(cls: Class[_ <: CPSValuePrimitive]): Int =
vPrims.getOrElse(cls, 0)
def functionsCount = funCount
def continuationsCount = cntCount
def log(tree: Tree): Unit = {
inc(nodes, tree.getClass)
tree match {
case LetP(_, p, _, _) =>
inc(vPrims, p.getClass)
case LetF(fs, _) =>
funCount += fs.length
case LetC(cs, _) =>
cntCount += cs.length
case If(p, _, _, _) =>
inc(tPrims, p.getClass)
case _ =>
// nothing to do
}
}
override def toString: String = {
val sb = new StringBuilder()
for ((label, map) <- Seq(("Nodes", nodes),
("Value primitives", vPrims),
("Test primitives", tPrims))
if (map.nonEmpty)) {
sb ++= (label + "\n" + "=" * label.length + "\n")
map.toList
.map { case (c, o) => c.getSimpleName.replace("$","") -> o }
.sortBy { case (_, o) => -o }
.foreach { case (c, o) => sb ++= "%,9d %s\n".format(o, c) }
sb.append("\n")
}
sb ++= " Functions defined: %,7d\n".format(funCount)
sb ++= "Continuations defined: %,7d\n".format(cntCount)
sb.toString
}
}
......@@ -57,3 +57,7 @@ object ExamplesTests2 extends TestSuite with ExamplesTests {
object ExamplesTests3 extends TestSuite with ExamplesTests {
val backEnd = L3Tester.backEnd3
}
object ExamplesTests4 extends TestSuite with ExamplesTests {
val backEnd = L3Tester.backEnd4
}
......@@ -56,4 +56,13 @@ object L3Tester {
andThen CPSHoister
andThen CPSInterpreterLow
)
val backEnd4 = (
CL3ToCPSTranslator
andThen CPSOptimizerHigh
andThen CPSValueRepresenter
andThen CPSHoister
andThen CPSOptimizerLow
andThen CPSInterpreterLow
)
}
......@@ -93,3 +93,7 @@ object SyntheticTests2 extends TestSuite with SyntheticTests {
object SyntheticTests3 extends TestSuite with SyntheticTests {
val backEnd = L3Tester.backEnd3
}
object SyntheticTests4 extends TestSuite with SyntheticTests {
val backEnd = L3Tester.backEnd4
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment