// Skip to content
// Snippets Groups Projects
// CPSInterpreter.scala 8.56 KiB
package l3

import scala.annotation.tailrec
import scala.collection.mutable.{ Map => MutableMap }
import IO._

/**
  * A tree-based interpreter for the CPS languages.
  *
  * The evaluation loop is shared by all CPS languages; concrete subclasses
  * decide how run-time data values are represented by implementing the
  * abstract hooks at the end of the class (literals, primitives, boxing of
  * functions, and conversion of the value given to `Halt`).
  *
  * @author Michel Schinz <Michel.Schinz@epfl.ch>
  *
  * @param treeModule the CPS tree module whose trees are interpreted
  * @param log        called on every tree node just before it is evaluated;
  *                   does nothing by default
  */
sealed abstract class CPSInterpreter[M <: CPSTreeModule](
  protected val treeModule: M,
  log: M#Tree => Unit = { _ : M#Tree => () }) {

  import treeModule._

  /** Interprets the whole program `tree` and wraps the integer passed to its
    * final `Halt` as `Right((result, None))`.
    */
  def apply(tree: Tree): TerminalPhaseResult =
    Right((eval(tree, emptyEnv), None))

  /** Run-time values. Subclasses add their own data values. */
  protected sealed trait Value
  /** A function closure: return-continuation parameter, parameters, body and
    * defining environment. */
  protected case class FunV(retC: Name, args: Seq[Name], body: Tree, env: Env)
      extends Value
  /** A continuation closure: parameters, body and defining environment. */
  protected case class CntV(args: Seq[Name], body: Tree, env: Env)
      extends Value

  /** Environments map names to values; looking up an unbound name throws. */
  protected type Env = PartialFunction[Name, Value]
  protected val emptyEnv: Env = Map.empty

  /** Evaluates `tree` in `env` and returns the integer given to `Halt`.
    * Every case ends in a tail call to `eval`, so evaluation runs in constant
    * stack space (checked by `@tailrec`).
    */
  @tailrec
  private def eval(tree: Tree, env: Env): Int = {
    def resolve(a: Atom): Value = a match {
      case AtomN(n) => env(n)
      case AtomL(l) => evalLit(l)
    }

    log(tree)

    (tree: @unchecked) match {
      case LetP(name, prim, args, body) =>
        eval(body, Map(name -> evalValuePrim(prim, args map resolve)) orElse env)

      case LetC(cnts, body) =>
        // All continuations of the group share one environment, which is
        // filled in after creation so that they can refer to each other.
        val recEnv = MutableMap[Name, Value]()
        val env1 = recEnv orElse env
        for (Cnt(name, args, cBody) <- cnts)
          recEnv(name) = CntV(args, cBody, env1)
        eval(body, env1)

      case LetF(funs, body) =>
        // Same knot-tying as for continuations: functions may be mutually
        // recursive.
        val recEnv = MutableMap[Name, Value]()
        val env1 = recEnv orElse env
        for (Fun(name, retC, args, fBody) <- funs)
          recEnv(name) = wrapFunV(FunV(retC, args, fBody, env1))
        eval(body, env1)

      case AppC(cnt, args) =>
        val CntV(cArgs, cBody, cEnv) = env(cnt)
        assume(cArgs.length == args.length,
               s"continuation $cnt expects ${cArgs.length} argument(s), got ${args.length}")
        eval(cBody, (cArgs zip (args map resolve)).toMap orElse cEnv)

      case AppF(fun, retC, args) =>
        val FunV(fRetC, fArgs, fBody, fEnv) = unwrapFunV(resolve(fun))
        // Report arity mismatches through the assumption failure instead of
        // printing to stdout, which would mix with the program's own output.
        assume(fArgs.length == args.length,
               s"function $fun expects ${fArgs.length} argument(s), got ${args.length}")
        val rArgs = args map resolve
        // The callee's return-continuation parameter is bound to the
        // caller's continuation value.
        val env1 = ((fRetC +: fArgs) zip (env(retC) +: rArgs)).toMap orElse fEnv
        eval(fBody, env1)

      case If(cond, args, thenC, elseC) =>
        val cnt = if (evalTestPrim(cond, args map resolve)) thenC else elseC
        // Branch targets are invoked without arguments, so they must take none.
        val CntV(cArgs, cBody, cEnv) = env(cnt)
        assume(cArgs.isEmpty,
               s"continuation $cnt used as if branch must not take arguments")
        eval(cBody, cEnv)

      case Halt(name) =>
        extractInt(resolve(name))
    }
  }

  /** Converts the value given to `Halt` into the program's integer result. */
  protected def extractInt(v: Value): Int

  /** Turns a function closure into the value bound to the function's name. */
  protected def wrapFunV(funV: FunV): Value
  /** Inverse of `wrapFunV`: recovers the closure from a function value. */
  protected def unwrapFunV(v: Value): FunV

  /** Evaluates a literal. */
  protected def evalLit(l: Literal): Value
  /** Applies a value primitive to already-evaluated arguments. */
  protected def evalValuePrim(p: ValuePrimitive, args: Seq[Value]): Value
  /** Applies a test primitive to already-evaluated arguments. */
  protected def evalTestPrim(p: TestPrimitive, args: Seq[Value]): Boolean
}

/** Interpreter for the high-level CPS language. Data values are tagged Scala
  * objects (blocks, integers, characters, booleans and unit), and primitives
  * are the L3 primitives.
  */
object CPSInterpreterHigh extends CPSInterpreter(SymbolicCPSTreeModule)
    with (SymbolicCPSTreeModule.Tree => TerminalPhaseResult) {
  import treeModule._
  import L3Primitive._

  // Data values of the high-level language.
  private case class BlockV(tag: L3BlockTag, contents: Array[Value])
      extends Value
  private case class IntV(value: L3Int) extends Value
  private case class CharV(value: L3Char) extends Value
  private case class BooleanV(value: Boolean) extends Value
  private case object UnitV extends Value

  // Only integers may be passed to Halt; anything else is a MatchError.
  protected def extractInt(v: Value): Int = {
    val IntV(result) = v
    result.toInt
  }

  // Functions are boxed in a one-element block carrying the function tag.
  protected def wrapFunV(funV: FunV): Value =
    BlockV(l3.BlockTag.Function.id, Array(funV))
  protected def unwrapFunV(v: Value): FunV = v match {
    case BlockV(tag, Array(f: FunV)) if tag == l3.BlockTag.Function.id => f
  }

  protected def evalLit(l: Literal): Value = l match {
    case UnitLit => UnitV
    case BooleanLit(b) => BooleanV(b)
    case CharLit(c) => CharV(c)
    case IntLit(i) => IntV(i)
  }

  protected def evalValuePrim(p: ValuePrimitive, args: Seq[Value]): Value =
    (p, args) match {
      // Blocks (newly allocated blocks are filled with unit).
      case (BlockAlloc(tag), Seq(IntV(size))) =>
        BlockV(tag, Array.fill(size.toInt)(UnitV))
      case (BlockTag, Seq(BlockV(tag, _))) => IntV(L3Int(tag))
      case (BlockLength, Seq(BlockV(_, slots))) => IntV(L3Int(slots.length))
      case (BlockGet, Seq(BlockV(_, slots), IntV(idx))) => slots(idx.toInt)
      case (BlockSet, Seq(BlockV(_, slots), IntV(idx), newValue)) =>
        slots(idx.toInt) = newValue
        UnitV

      // Integer arithmetic and conversion.
      case (IntAdd, Seq(IntV(x), IntV(y))) => IntV(x + y)
      case (IntSub, Seq(IntV(x), IntV(y))) => IntV(x - y)
      case (IntMul, Seq(IntV(x), IntV(y))) => IntV(x * y)
      case (IntDiv, Seq(IntV(x), IntV(y))) => IntV(x / y)
      case (IntMod, Seq(IntV(x), IntV(y))) => IntV(x % y)
      case (IntToChar, Seq(IntV(x))) => CharV(x.toInt)

      // Integer bit manipulation.
      case (IntShiftLeft, Seq(IntV(x), IntV(y))) => IntV(x << y)
      case (IntShiftRight, Seq(IntV(x), IntV(y))) => IntV(x >> y)
      case (IntBitwiseAnd, Seq(IntV(x), IntV(y))) => IntV(x & y)
      case (IntBitwiseOr, Seq(IntV(x), IntV(y))) => IntV(x | y)
      case (IntBitwiseXOr, Seq(IntV(x), IntV(y))) => IntV(x ^ y)

      // Input/output and character conversion.
      case (ByteRead, Seq()) => IntV(L3Int(readByte()))
      case (ByteWrite, Seq(IntV(b))) =>
        writeByte(b.toInt)
        UnitV
      case (CharToInt, Seq(CharV(c))) => IntV(L3Int(c))

      case (Id, Seq(v)) => v
    }

  protected def evalTestPrim(p: TestPrimitive, args: Seq[Value]): Boolean =
    (p, args) match {
      // Type predicates: true iff the single argument has the given kind.
      case (BlockP, Seq(a)) => a match { case BlockV(_, _) => true; case _ => false }
      case (IntP, Seq(a)) => a match { case IntV(_) => true; case _ => false }
      case (CharP, Seq(a)) => a match { case CharV(_) => true; case _ => false }
      case (BoolP, Seq(a)) => a match { case BooleanV(_) => true; case _ => false }
      case (UnitP, Seq(a)) => UnitV == a

      case (IntLt, Seq(IntV(x), IntV(y))) => x < y
      case (IntLe, Seq(IntV(x), IntV(y))) => x <= y

      // Case-class equality: blocks compare their contents arrays by
      // reference, so distinct blocks are never equal.
      case (Eq, Seq(x, y)) => x == y
    }
}

/** Interpreter for the low-level CPS language. Data values are either raw
  * bit patterns (`BitsV`) or heap blocks (`BlockV`) identified by an address,
  * and primitives work on bits.
  *
  * @param log called on every tree node just before it is evaluated
  */
class CPSInterpreterLow(log: SymbolicCPSTreeModuleLow.Tree => Unit)
    extends CPSInterpreter(SymbolicCPSTreeModuleLow, log)
    with (SymbolicCPSTreeModuleLow.Tree => TerminalPhaseResult) {
  import treeModule._
  import CPSValuePrimitive._
  import CPSTestPrimitive._
  import scala.language.implicitConversions

  /** A heap block: its address, tag and mutable contents. */
  protected case class BlockV(addr: Bits32,
                              tag: L3BlockTag,
                              contents: Array[Value])
      extends Value
  /** A raw bit pattern, as produced by literals and arithmetic. */
  protected case class BitsV(value: Bits32) extends Value

  // Address handed to the next block. It only grows (by 4 per block), so
  // addresses are never reused.
  private var nextBlockAddr = 0

  /** Allocates a block with the given tag and contents at a fresh address. */
  protected def allocBlock(tag: L3BlockTag, contents: Array[Value]): BlockV = {
    val addr = nextBlockAddr
    nextBlockAddr = addr + 4
    BlockV(addr, tag, contents)
  }

  // Lets primitives use any data value as bits: a block stands for its
  // address. Closures have no bit representation.
  private implicit def valueToBits(v: Value): Bits32 = v match {
    case BitsV(bits)           => bits
    case BlockV(address, _, _) => address
    case _: FunV | _: CntV     => sys.error(s"cannot convert $v to bits")
  }

  // Only raw bits may be passed to Halt; anything else is a MatchError.
  protected def extractInt(v: Value): Int = {
    val BitsV(result) = v
    result
  }

  // Functions are not boxed at this level.
  protected def wrapFunV(funV: FunV): Value = funV
  protected def unwrapFunV(v: Value): FunV = v.asInstanceOf[FunV]

  protected def evalLit(l: Literal): Value = BitsV(l)

  protected def evalValuePrim(p: ValuePrimitive, args: Seq[Value]): Value =
    (p, args) match {
      // Arithmetic (operands converted with valueToBits).
      case (Add, Seq(a, b)) => BitsV(a + b)
      case (Sub, Seq(a, b)) => BitsV(a - b)
      case (Mul, Seq(a, b)) => BitsV(a * b)
      case (Div, Seq(a, b)) => BitsV(a / b)
      case (Mod, Seq(a, b)) => BitsV(a % b)

      // Bit manipulation.
      case (ShiftLeft, Seq(a, b)) => BitsV(a << b)
      case (ShiftRight, Seq(a, b)) => BitsV(a >> b)
      case (And, Seq(a, b)) => BitsV(a & b)
      case (Or, Seq(a, b)) => BitsV(a | b)
      case (XOr, Seq(a, b)) => BitsV(a ^ b)

      // Input/output.
      case (ByteRead, Seq()) => BitsV(readByte())
      case (ByteWrite, Seq(b)) =>
        writeByte(b)
        BitsV(0)

      // Blocks (newly allocated blocks are filled with zero bits).
      case (BlockAlloc(tag), Seq(BitsV(size))) =>
        allocBlock(tag, Array.fill(size)(BitsV(0)))
      case (BlockTag, Seq(BlockV(_, tag, _))) => BitsV(tag)
      case (BlockLength, Seq(BlockV(_, _, slots))) => BitsV(slots.length)
      case (BlockGet, Seq(BlockV(_, _, slots), BitsV(idx))) => slots(idx)
      case (BlockSet, Seq(BlockV(_, _, slots), BitsV(idx), newValue)) =>
        slots(idx) = newValue
        BitsV(0)

      case (Id, Seq(v)) => v
    }

  protected def evalTestPrim(p: TestPrimitive, args: Seq[Value]): Boolean =
    (p, args) match {
      case (Lt, Seq(a, b)) => a < b
      case (Le, Seq(a, b)) => a <= b
      // No bits conversion here: `==` on Value is case-class equality, so a
      // block is not equal to a BitsV holding its address.
      case (Eq, Seq(a, b)) => a == b
    }
}

/** Low-level CPS interpreter that does not log the trees it evaluates. */
object CPSInterpreterLow
    extends CPSInterpreterLow((_: SymbolicCPSTreeModuleLow.Tree) => ())

/** Low-level CPS interpreter (no logging) in which function values are boxed
  * in heap blocks tagged `BlockTag.Function`, as in `CPSInterpreterHigh`.
  * NOTE(review): the name suggests this is meant for programs that have not
  * been closure-converted — confirm.
  */
object CPSInterpreterLowNoCC extends CPSInterpreterLow(_ => ()) {
  // Boxes a function in a freshly allocated one-element block with the
  // function tag.
  override protected def wrapFunV(funV: FunV): Value =
    allocBlock(BlockTag.Function.id, Array(funV))

  // Unboxes a value produced by `wrapFunV`. Like the high-level interpreter,
  // it checks the block tag, so a one-element block with any other tag is
  // rejected (MatchError) rather than silently treated as a function.
  override protected def unwrapFunV(v: Value): FunV = v match {
    case BlockV(_, tag, Array(funV: FunV)) if tag == BlockTag.Function.id => funV
  }
}