CPSInterpreter.scala 8.56 KB
Newer Older
Sapphie's avatar
Sapphie committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package l3

import scala.annotation.tailrec
import scala.collection.mutable.{ Map => MutableMap }
import IO._

/**
  * A tree-based interpreter for the CPS languages.
  *
  * Concrete subclasses fix the tree module and the representation of runtime
  * values: they evaluate literals and primitives, and decide how function
  * closures are turned into values (`wrapFunV`) and back (`unwrapFunV`).
  *
  * @param treeModule the CPS tree module whose trees are interpreted
  * @param log        called on every tree just before it is evaluated
  *
  * @author Michel Schinz <Michel.Schinz@epfl.ch>
  */
sealed abstract class CPSInterpreter[M <: CPSTreeModule](
  protected val treeModule: M,
  log: M#Tree => Unit = { _ : M#Tree => () }) {

  import treeModule._

  /** Interprets `tree` starting from the empty environment.
    *
    * @return `Right` holding the integer extracted from the value given to
    *         the `Halt` node that ends evaluation, paired with `None`
    */
  def apply(tree: Tree): TerminalPhaseResult =
    Right((eval(tree, emptyEnv), None))

  /** Runtime values; concrete interpreters add their own cases. */
  protected sealed trait Value

  /** A function closure: return-continuation parameter, parameters, body and
    * the environment the function was defined in. */
  protected case class FunV(retC: Name, args: Seq[Name], body: Tree, env: Env)
      extends Value

  /** A continuation closure: parameters, body and the environment the
    * continuation was defined in. */
  protected case class CntV(args: Seq[Name], body: Tree, env: Env)
      extends Value

  /** Variable bindings. This is a partial function rather than a `Map` so
    * that scopes can be stacked with `orElse`, innermost scope first. */
  protected type Env = PartialFunction[Name, Value]
  protected val emptyEnv: Env = Map.empty

  /** Evaluates `tree` in `env` until a `Halt` node is reached. Returns the
    * integer that `extractInt` extracts from the halted value.
    *
    * Every tree is passed to `log` before it is evaluated. Continuation and
    * function applications `assume` that the number of arguments matches
    * the number of parameters.
    */
  @tailrec
  private def eval(tree: Tree, env: Env): Int = {
    def resolve(a: Atom): Value = a match {
      case AtomN(n) => env(n)
      case AtomL(l) => evalLit(l)
    }

    log(tree)

    (tree: @unchecked) match {
      case LetP(name, prim, args, body) =>
        eval(body, Map(name->evalValuePrim(prim, args map resolve)) orElse env)

      case LetC(cnts, body) =>
        // The continuations may refer to each other and to themselves, so
        // they all close over `env1`, which sees the map being filled here.
        val recEnv = MutableMap[Name, Value]()
        val env1 = recEnv orElse env
        for (Cnt(name, args, body) <- cnts)
          recEnv(name) = CntV(args, body, env1)
        eval(body, env1)

      case LetF(funs, body) =>
        // Same recursive-binding technique as for `LetC`, so that the
        // functions can be mutually recursive.
        val recEnv = MutableMap[Name, Value]()
        val env1 = recEnv orElse env
        for (Fun(name, retC, args, body) <- funs)
          recEnv(name) = wrapFunV(FunV(retC, args, body, env1))
        eval(body, env1)

      case AppC(cnt, args) =>
        val CntV(cArgs, cBody, cEnv) = env(cnt)
        assume(cArgs.length == args.length,
               s"continuation $cnt expects ${cArgs.length} argument(s), got ${args.length}")
        eval(cBody, (cArgs zip (args map resolve)).toMap orElse cEnv)

      case AppF(fun, retC, args) =>
        val FunV(fRetC, fArgs, fBody, fEnv) = unwrapFunV(resolve(fun))
        // The diagnostic is part of the assumption's (lazily built) message
        // instead of being printed to stdout before the check.
        assume(fArgs.length == args.length,
               s"function $fun expects ${fArgs.length} argument(s), got ${args.length}")
        val rArgs = args map resolve
        // Bind the callee's return-continuation parameter to the caller's
        // continuation, and its parameters to the argument values.
        val env1 = ((fRetC +: fArgs) zip (env(retC) +: rArgs)).toMap orElse fEnv
        eval(fBody, env1)

      case If(cond, args, thenC, elseC) =>
        val cnt = if (evalTestPrim(cond, args map resolve)) thenC else elseC
        // The selected continuation is entered without binding any arguments.
        val CntV(_, cBody, cEnv) = env(cnt)
        eval(cBody, cEnv)

      case Halt(name) =>
        extractInt(resolve(name))
    }
  }

  /** Extracts the integer carried by the value passed to `Halt`. */
  protected def extractInt(v: Value): Int

  /** Turns a function closure into the value bound to the function's name. */
  protected def wrapFunV(funV: FunV): Value
  /** Recovers the function closure from a value produced by `wrapFunV`. */
  protected def unwrapFunV(v: Value): FunV

  /** Value of a literal atom. */
  protected def evalLit(l: Literal): Value
  /** Result of a value primitive applied to already-evaluated arguments. */
  protected def evalValuePrim(p: ValuePrimitive, args: Seq[Value]): Value
  /** Outcome of a test primitive applied to already-evaluated arguments. */
  protected def evalTestPrim(p: TestPrimitive, args: Seq[Value]): Boolean
}

/** Interpreter for the high-level CPS language. Values are tagged blocks,
  * integers, characters, booleans and unit. A function closure is stored as
  * the only element of a block tagged `l3.BlockTag.Function.id`. */
object CPSInterpreterHigh extends CPSInterpreter(SymbolicCPSTreeModule)
    with (SymbolicCPSTreeModule.Tree => TerminalPhaseResult) {
  import treeModule._
  import L3Primitive._

  private case class BlockV(tag: L3BlockTag, contents: Array[Value])
      extends Value
  private case class IntV(value: L3Int) extends Value
  private case class CharV(value: L3Char) extends Value
  private case class BooleanV(value: Boolean) extends Value
  private case object UnitV extends Value

  /** The halted value must be an integer. */
  protected def extractInt(v: Value): Int = {
    val IntV(n) = v
    n.toInt
  }

  /** Boxes a closure into a one-element function-tagged block. */
  protected def wrapFunV(funV: FunV): Value = {
    val slots: Array[Value] = Array(funV)
    BlockV(l3.BlockTag.Function.id, slots)
  }

  /** Inverse of `wrapFunV`; the block tag is checked. */
  protected def unwrapFunV(v: Value): FunV = v match {
    case BlockV(tag, Array(closure: FunV)) if tag == l3.BlockTag.Function.id =>
      closure
  }

  protected def evalLit(l: Literal): Value = l match {
    case UnitLit       => UnitV
    case BooleanLit(b) => BooleanV(b)
    case CharLit(c)    => CharV(c)
    case IntLit(n)     => IntV(n)
  }

  protected def evalValuePrim(p: ValuePrimitive, args: Seq[Value]): Value =
    (p, args) match {
      // Integer arithmetic.
      case (IntAdd, Seq(IntV(a), IntV(b))) => IntV(a + b)
      case (IntSub, Seq(IntV(a), IntV(b))) => IntV(a - b)
      case (IntMul, Seq(IntV(a), IntV(b))) => IntV(a * b)
      case (IntDiv, Seq(IntV(a), IntV(b))) => IntV(a / b)
      case (IntMod, Seq(IntV(a), IntV(b))) => IntV(a % b)

      // Shifts and bitwise logic.
      case (IntShiftLeft, Seq(IntV(a), IntV(b)))  => IntV(a << b)
      case (IntShiftRight, Seq(IntV(a), IntV(b))) => IntV(a >> b)
      case (IntBitwiseAnd, Seq(IntV(a), IntV(b))) => IntV(a & b)
      case (IntBitwiseOr, Seq(IntV(a), IntV(b)))  => IntV(a | b)
      case (IntBitwiseXOr, Seq(IntV(a), IntV(b))) => IntV(a ^ b)

      // Conversions between integers and characters.
      case (IntToChar, Seq(IntV(n)))  => CharV(n.toInt)
      case (CharToInt, Seq(CharV(c))) => IntV(L3Int(c))

      // Blocks; freshly allocated slots hold unit.
      case (BlockAlloc(tag), Seq(IntV(size))) =>
        BlockV(tag, Array.fill(size.toInt)(UnitV))
      case (BlockTag, Seq(BlockV(tag, _)))      => IntV(L3Int(tag))
      case (BlockLength, Seq(BlockV(_, slots))) => IntV(L3Int(slots.length))
      case (BlockGet, Seq(BlockV(_, slots), IntV(i))) =>
        slots(i.toInt)
      case (BlockSet, Seq(BlockV(_, slots), IntV(i), x)) =>
        slots(i.toInt) = x
        UnitV

      // Byte-level input/output.
      case (ByteRead, Seq()) => IntV(L3Int(readByte()))
      case (ByteWrite, Seq(IntV(c))) =>
        writeByte(c.toInt)
        UnitV

      case (Id, Seq(x)) => x
    }

  protected def evalTestPrim(p: TestPrimitive, args: Seq[Value]): Boolean =
    (p, args) match {
      // Kind predicates: does the single argument have the given kind?
      case (BlockP, Seq(x)) => PartialFunction.cond(x) { case _: BlockV => true }
      case (IntP, Seq(x))   => PartialFunction.cond(x) { case _: IntV => true }
      case (CharP, Seq(x))  => PartialFunction.cond(x) { case _: CharV => true }
      case (BoolP, Seq(x))  => PartialFunction.cond(x) { case _: BooleanV => true }
      case (UnitP, Seq(x))  => x == UnitV

      case (IntLt, Seq(IntV(a), IntV(b))) => a < b
      case (IntLe, Seq(IntV(a), IntV(b))) => a <= b

      case (Eq, Seq(x, y)) => x == y
    }
}
Sapphie's avatar
Sapphie committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228

/** Interpreter for the low-level CPS language. Every value is either a raw
  * 32-bit pattern (`BitsV`) or a heap block (`BlockV`) that also carries a
  * numeric address. Function closures are used directly as values.
  *
  * @param log called on every tree just before it is evaluated
  */
class CPSInterpreterLow(log: SymbolicCPSTreeModuleLow.Tree => Unit)
    extends CPSInterpreter(SymbolicCPSTreeModuleLow, log)
    with (SymbolicCPSTreeModuleLow.Tree => TerminalPhaseResult) {
  import treeModule._
  import CPSValuePrimitive._
  import CPSTestPrimitive._
  import scala.language.implicitConversions

  /** A heap block: its address, its tag and its (mutable) contents. */
  protected case class BlockV(addr: Bits32,
                              tag: L3BlockTag,
                              contents: Array[Value])
      extends Value
  /** A raw 32-bit value. */
  protected case class BitsV(value: Bits32) extends Value

  // Address given to the next allocated block. It advances by 4 per block,
  // and nothing in this class resets it between runs.
  private var nextBlockAddr = 0

  /** Creates a block with a fresh address. */
  protected def allocBlock(tag: L3BlockTag, contents: Array[Value]): BlockV = {
    val block = BlockV(nextBlockAddr, tag, contents)
    nextBlockAddr += 4
    block
  }

  // Lets the primitives below compute directly on `Value`s: a block is viewed
  // as its address, and closures have no bit representation.
  private implicit def valueToBits(v: Value): Bits32 = v match {
    case BlockV(addr, _, _) => addr
    case BitsV(value)       => value
    case _: FunV | _: CntV  => sys.error(s"cannot convert $v to bits")
  }

  protected def extractInt(v: Value): Int = v match { case BitsV(i) => i }

  protected def wrapFunV(funV: FunV): Value = funV
  protected def unwrapFunV(v: Value): FunV = v.asInstanceOf[FunV]

  protected def evalLit(l: Literal): Value = BitsV(l)

  protected def evalValuePrim(p: ValuePrimitive, args: Seq[Value]): Value =
    (p, args) match {
      case (Add, Seq(v1, v2)) => BitsV(v1 + v2)
      case (Sub, Seq(v1, v2)) => BitsV(v1 - v2)
      case (Mul, Seq(v1, v2)) => BitsV(v1 * v2)
      case (Div, Seq(v1, v2)) => BitsV(v1 / v2)
      case (Mod, Seq(v1, v2)) => BitsV(v1 % v2)

      case (ShiftLeft, Seq(v1, v2)) => BitsV(v1 << v2)
      case (ShiftRight, Seq(v1, v2)) => BitsV(v1 >> v2)
      case (And, Seq(v1, v2)) => BitsV(v1 & v2)
      case (Or, Seq(v1, v2)) => BitsV(v1 | v2)
      case (XOr, Seq(v1, v2)) => BitsV(v1 ^ v2)

      case (ByteRead, Seq()) => BitsV(readByte())
      case (ByteWrite, Seq(c)) => writeByte(c); BitsV(0)

      // Fresh block slots hold the bit pattern 0.
      case (BlockAlloc(t), Seq(BitsV(s))) =>
        allocBlock(t, Array.fill(s)(BitsV(0)))
      case (BlockTag, Seq(BlockV(_, t, _))) => BitsV(t)
      case (BlockLength, Seq(BlockV(_, _, c))) => BitsV(c.length)
      case (BlockGet, Seq(BlockV(_, _, c), BitsV(i))) => c(i)
      case (BlockSet, Seq(BlockV(_, _, c), BitsV(i), v)) =>
        c(i) = v; BitsV(0)

      case (Id, Seq(o)) => o
    }

  protected def evalTestPrim(p: TestPrimitive, args: Seq[Value]): Boolean =
    (p, args) match {
      case (Lt, Seq(v1, v2)) => v1 < v2
      case (Le, Seq(v1, v2)) => v1 <= v2
      // `==` is defined on every `Value`, so no bit conversion happens here:
      // values are compared as values (e.g. a block is never equal to a
      // `BitsV` holding its address).
      case (Eq, Seq(v1, v2)) => v1 == v2
    }
}

/** Default low-level interpreter; logged trees are discarded. */
object CPSInterpreterLow extends CPSInterpreterLow(log = { _ => () })

/** Low-level interpreter that does not use closures directly as values.
  * Each function closure is stored as the only element of a freshly
  * allocated block tagged `BlockTag.Function.id`. Logged trees are discarded.
  */
object CPSInterpreterLowNoCC extends CPSInterpreterLow(_ => ()) {
  override protected def wrapFunV(funV: FunV): Value = {
    val slots: Array[Value] = Array(funV)
    allocBlock(BlockTag.Function.id, slots)
  }

  // Accepts any block whose single element is a closure; the tag is not
  // inspected.
  override protected def unwrapFunV(v: Value): FunV = {
    val BlockV(_, _, Array(closure: FunV)) = v
    closure
  }
}