Skip to content
Snippets Groups Projects
Commit f868bab7 authored by Sapphie's avatar Sapphie
Browse files

Import optimisation skeleton

parent 146e5f46
No related branches found
No related tags found
No related merge requests found
package l3
import scala.collection.mutable.{ Map => MutableMap }
abstract class CPSOptimizer[T <: CPSTreeModule { type Name = Symbol }]
(val treeModule: T) {
import treeModule._
protected def rewrite(tree: Tree): Tree = {
val simplifiedTree = fixedPoint(tree)(shrink)
val maxSize = size(simplifiedTree) * 3 / 2
fixedPoint(simplifiedTree, 8) { t => inline(t, maxSize) }
}
private case class Count(applied: Int = 0, asValue: Int = 0)
private case class State(
census: Map[Name, Count],
aSubst: Subst[Atom] = emptySubst,
cSubst: Subst[Name] = emptySubst,
eInvEnv: Map[(ValuePrimitive, Seq[Atom]), Atom] = Map.empty,
cEnv: Map[Name, Cnt] = Map.empty,
fEnv: Map[Name, Fun] = Map.empty) {
def dead(s: Name): Boolean =
! census.contains(s)
def appliedOnce(s: Name): Boolean =
census.get(s).contains(Count(applied = 1, asValue = 0))
def withASubst(from: Atom, to: Atom): State =
copy(aSubst = aSubst + (from -> aSubst(to)))
def withASubst(from: Name, to: Atom): State =
withASubst(AtomN(from), to)
def withASubst(from: Name, to: Literal): State =
withASubst(from, AtomL(to))
def withASubst(from: Seq[Name], to: Seq[Atom]): State =
copy(aSubst = aSubst ++ (from.map(AtomN) zip to.map(aSubst)))
def withCSubst(from: Name, to: Name): State =
copy(cSubst = cSubst + (from -> cSubst(to)))
def withExp(atom: Atom, prim: ValuePrimitive, args: Seq[Atom]): State =
copy(eInvEnv = eInvEnv + ((prim, args) -> atom))
def withExp(name: Name, prim: ValuePrimitive, args: Seq[Atom]): State =
withExp(AtomN(name), prim, args)
def withCnts(cnts: Seq[Cnt]): State =
copy(cEnv = cEnv ++ (cnts.map(_.name) zip cnts))
def withFuns(funs: Seq[Fun]): State =
copy(fEnv = fEnv ++ (funs.map(_.name) zip funs))
}
// Shrinking optimizations
private def shrink(tree: Tree): Tree =
shrink(tree, State(census(tree)))
private def shrink(tree: Tree, s: State): Tree = ???
// (Non-shrinking) inlining
private def inline(tree: Tree, maxSize: Int): Tree = {
def copyT(tree: Tree, subV: Subst[Atom], subC: Subst[Name]): Tree = {
(tree: @unchecked) match {
case LetP(name, prim, args, body) =>
val name1 = name.copy()
LetP(name1, prim, args map subV,
copyT(body, subV + (AtomN(name) -> AtomN(name1)), subC))
case LetC(cnts, body) =>
val names = cnts map (_.name)
val names1 = names map (_.copy())
val subC1 = subC ++ (names zip names1)
LetC(cnts map (copyC(_, subV, subC1)), copyT(body, subV, subC1))
case LetF(funs, body) =>
val names = funs map (_.name)
val names1 = names map (_.copy())
val subV1 = subV ++ ((names map AtomN) zip (names1 map AtomN))
LetF(funs map (copyF(_, subV1, subC)), copyT(body, subV1, subC))
case AppC(cnt, args) =>
AppC(subC(cnt), args map subV)
case AppF(fun, retC, args) =>
AppF(subV(fun), subC(retC), args map subV)
case If(cond, args, thenC, elseC) =>
If(cond, args map subV, subC(thenC), subC(elseC))
case Halt(arg) =>
Halt(subV(arg))
}
}
def copyC(cnt: Cnt, subV: Subst[Atom], subC: Subst[Name]): Cnt = {
val args1 = cnt.args map (_.copy())
val subV1 = subV ++ ((cnt.args map AtomN) zip (args1 map AtomN))
Cnt(subC(cnt.name), args1, copyT(cnt.body, subV1, subC))
}
def copyF(fun: Fun, subV: Subst[Atom], subC: Subst[Name]): Fun = {
val retC1 = fun.retC.copy()
val subC1 = subC + (fun.retC -> retC1)
val args1 = fun.args map (_.copy())
val subV1 = subV ++ ((fun.args map AtomN) zip (args1 map AtomN))
val AtomN(funName1) = subV(AtomN(fun.name))
Fun(funName1, retC1, args1, copyT(fun.body, subV1, subC1))
}
val fibonacci = Seq(1, 2, 3, 5, 8, 13)
val trees = LazyList.iterate((0, tree), fibonacci.length){ case (i, tree) =>
val funLimit = fibonacci(i)
val cntLimit = i
def sameLen[T,U](formalArgs: Seq[T], actualArgs: Seq[U]): Boolean =
formalArgs.length == actualArgs.length
def inlineT(tree: Tree)(implicit s: State): Tree = ???
(i + 1, fixedPoint(inlineT(tree)(State(census(tree))))(shrink))
}
trees.takeWhile{ case (_, tree) => size(tree) <= maxSize }.last._2
}
// Census computation
private def census(tree: Tree): Map[Name, Count] = {
val census = MutableMap[Name, Count]().withDefault(_ => Count())
val rhs = MutableMap[Name, Tree]()
def incAppUseN(name: Name): Unit = {
val currCount = census(name)
census(name) = currCount.copy(applied = currCount.applied + 1)
rhs.remove(name).foreach(addToCensus)
}
def incAppUseA(atom: Atom): Unit =
atom.asName.foreach(incAppUseN(_))
def incValUseN(name: Name): Unit = {
val currCount = census(name)
census(name) = currCount.copy(asValue = currCount.asValue + 1)
rhs.remove(name).foreach(addToCensus)
}
def incValUseA(atom: Atom): Unit =
atom.asName.foreach(incValUseN(_))
def addToCensus(tree: Tree): Unit = (tree: @unchecked) match {
case LetP(_, _, args, body) =>
args foreach incValUseA; addToCensus(body)
case LetC(cnts, body) =>
rhs ++= (cnts map { c => (c.name, c.body) }); addToCensus(body)
case LetF(funs, body) =>
rhs ++= (funs map { f => (f.name, f.body) }); addToCensus(body)
case AppC(cnt, args) =>
incAppUseN(cnt); args foreach incValUseA
case AppF(fun, retC, args) =>
incAppUseA(fun); incValUseN(retC); args foreach incValUseA
case If(_, args, thenC, elseC) =>
args foreach incValUseA; incValUseN(thenC); incValUseN(elseC)
case Halt(arg) =>
incValUseA(arg)
}
addToCensus(tree)
census.toMap
}
private def size(tree: Tree): Int = (tree: @unchecked) match {
case LetP(_, _, _, body) => size(body) + 1
case LetC(cs, body) => (cs map { c => size(c.body) }).sum + size(body)
case LetF(fs, body) => (fs map { f => size(f.body) }).sum + size(body)
case AppC(_, _) | AppF(_, _, _) | If(_, _, _, _) | Halt(_) => 1
}
protected val impure: ValuePrimitive => Boolean
protected val unstable: ValuePrimitive => Boolean
protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal]
protected val blockTag: ValuePrimitive
protected val blockLength: ValuePrimitive
protected val identity: ValuePrimitive
protected val leftNeutral: Set[(Literal, ValuePrimitive)]
protected val rightNeutral: Set[(ValuePrimitive, Literal)]
protected val leftAbsorbing: Set[(Literal, ValuePrimitive)]
protected val rightAbsorbing: Set[(ValuePrimitive, Literal)]
protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom]
protected val sameArgReduceC: TestPrimitive => Boolean
protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
Literal]
protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
Boolean]
}
object CPSOptimizerHigh extends CPSOptimizer(SymbolicCPSTreeModule)
with (SymbolicCPSTreeModule.Tree => SymbolicCPSTreeModule.Tree) {
import treeModule._
import L3Primitive._
def apply(tree: Tree): Tree =
rewrite(tree)
import scala.language.implicitConversions
private[this] implicit def l3IntToLit(i: L3Int): Literal = IntLit(i)
private[this] implicit def intToLit(i: Int): Literal = IntLit(L3Int(i))
protected val impure: ValuePrimitive => Boolean = ???
protected val unstable: ValuePrimitive => Boolean = ???
protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = ???
protected val blockTag: ValuePrimitive = ???
protected val blockLength: ValuePrimitive = ???
protected val identity: ValuePrimitive = ???
protected val leftNeutral: Set[(Literal, ValuePrimitive)] = ???
protected val rightNeutral: Set[(ValuePrimitive, Literal)] = ???
protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] = ???
protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] = ???
protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] =
???
protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = ???
protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
Literal] = ???
protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
Boolean] = ???
}
object CPSOptimizerLow extends CPSOptimizer(SymbolicCPSTreeModuleLow)
with (SymbolicCPSTreeModuleLow.LetF => SymbolicCPSTreeModuleLow.LetF) {
import treeModule._
import CPSValuePrimitive._
import CPSTestPrimitive._
def apply(tree: LetF): LetF = rewrite(tree) match {
case tree @ LetF(_, _) => tree
case other => LetF(Seq(), other)
}
protected val impure: ValuePrimitive => Boolean =
Set(BlockSet, ByteRead, ByteWrite)
protected val unstable: ValuePrimitive => Boolean = {
case BlockAlloc(_) | BlockGet | ByteRead => true
case _ => false
}
protected val blockAllocTag: PartialFunction[ValuePrimitive, Literal] = {
case BlockAlloc(tag) => tag
}
protected val blockTag: ValuePrimitive = BlockTag
protected val blockLength: ValuePrimitive = BlockLength
protected val identity: ValuePrimitive = Id
protected val leftNeutral: Set[(Literal, ValuePrimitive)] =
Set((0, Add), (1, Mul), (~0, And), (0, Or), (0, XOr))
protected val rightNeutral: Set[(ValuePrimitive, Literal)] =
Set((Add, 0), (Sub, 0), (Mul, 1), (Div, 1),
(ShiftLeft, 0), (ShiftRight, 0),
(And, ~0), (Or, 0), (XOr, 0))
protected val leftAbsorbing: Set[(Literal, ValuePrimitive)] =
Set((0, Mul), (0, Div),
(0, ShiftLeft), (0, ShiftRight),
(0, And), (~0, Or))
protected val rightAbsorbing: Set[(ValuePrimitive, Literal)] =
Set((Mul, 0), (And, 0), (Or, ~0))
protected val sameArgReduce: PartialFunction[(ValuePrimitive, Atom), Atom] = {
case (And | Or, a) => a
case (Sub | Mod | XOr, _) => AtomL(0)
case (Div, _) => AtomL(1)
}
protected val sameArgReduceC: PartialFunction[TestPrimitive, Boolean] = {
case Le | Eq => true
case Lt => false
}
protected val vEvaluator: PartialFunction[(ValuePrimitive, Seq[Literal]),
Literal] = {
case (Add, Seq(x, y)) => x + y
case (Sub, Seq(x, y)) => x - y
case (Mul, Seq(x, y)) => x * y
case (Div, Seq(x, y)) if y.toInt != 0 => x / y
case (Mod, Seq(x, y)) if y.toInt != 0 => x % y
case (ShiftLeft, Seq(x, y)) => x << y
case (ShiftRight, Seq(x, y)) => x >> y
case (And, Seq(x, y)) => x & y
case (Or, Seq(x, y)) => x | y
case (XOr, Seq(x, y)) => x ^ y
}
protected val cEvaluator: PartialFunction[(TestPrimitive, Seq[Literal]),
Boolean] = {
case (Lt, Seq(x, y)) => x < y
case (Le, Seq(x, y)) => x <= y
case (Eq, Seq(x, y)) => x == y
}
}
package l3;
import scala.collection.mutable.{ Map => MutableMap }
import SymbolicCPSTreeModuleLow._
final class Statistics {
private[this] var funCount = 0
private[this] var cntCount = 0
private[this] val nodes = MutableMap[Class[_ <: Tree], Int]()
private[this] val tPrims = MutableMap[Class[_ <: l3.CPSTestPrimitive], Int]()
private[this] val vPrims = MutableMap[Class[_ <: l3.CPSValuePrimitive], Int]()
private[this] def inc[K](m: MutableMap[K, Int], k: K): Unit =
m.put(k, m.getOrElse(k, 0) + 1)
def nodeCount(cls: Class[_ <: Tree]): Int =
nodes.getOrElse(cls, 0)
def testPrimitiveCount(cls: Class[_ <: CPSTestPrimitive]): Int =
tPrims.getOrElse(cls, 0)
def valuePrimitiveCount(cls: Class[_ <: CPSValuePrimitive]): Int =
vPrims.getOrElse(cls, 0)
def functionsCount = funCount
def continuationsCount = cntCount
def log(tree: Tree): Unit = {
inc(nodes, tree.getClass)
tree match {
case LetP(_, p, _, _) =>
inc(vPrims, p.getClass)
case LetF(fs, _) =>
funCount += fs.length
case LetC(cs, _) =>
cntCount += cs.length
case If(p, _, _, _) =>
inc(tPrims, p.getClass)
case _ =>
// nothing to do
}
}
override def toString: String = {
val sb = new StringBuilder()
for ((label, map) <- Seq(("Nodes", nodes),
("Value primitives", vPrims),
("Test primitives", tPrims))
if (map.nonEmpty)) {
sb ++= (label + "\n" + "=" * label.length + "\n")
map.toList
.map { case (c, o) => c.getSimpleName.replace("$","") -> o }
.sortBy { case (_, o) => -o }
.foreach { case (c, o) => sb ++= "%,9d %s\n".format(o, c) }
sb.append("\n")
}
sb ++= " Functions defined: %,7d\n".format(funCount)
sb ++= "Continuations defined: %,7d\n".format(cntCount)
sb.toString
}
}
......@@ -57,3 +57,7 @@ object ExamplesTests2 extends TestSuite with ExamplesTests {
object ExamplesTests3 extends TestSuite with ExamplesTests {
val backEnd = L3Tester.backEnd3
}
object ExamplesTests4 extends TestSuite with ExamplesTests {
val backEnd = L3Tester.backEnd4
}
......@@ -56,4 +56,13 @@ object L3Tester {
andThen CPSHoister
andThen CPSInterpreterLow
)
val backEnd4 = (
CL3ToCPSTranslator
andThen CPSOptimizerHigh
andThen CPSValueRepresenter
andThen CPSHoister
andThen CPSOptimizerLow
andThen CPSInterpreterLow
)
}
......@@ -93,3 +93,7 @@ object SyntheticTests2 extends TestSuite with SyntheticTests {
object SyntheticTests3 extends TestSuite with SyntheticTests {
val backEnd = L3Tester.backEnd3
}
object SyntheticTests4 extends TestSuite with SyntheticTests {
val backEnd = L3Tester.backEnd4
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment