Commit 33e844bf authored by Sapphie's avatar Sapphie
Browse files

Import memory management skeleton

parent 418c23de
This source diff could not be displayed because it is too large. You can view the blob instead.
package l3
import java.io.PrintWriter
import java.nio.file.Path
import java.nio.file.Files.newBufferedWriter
import scala.util.Using.{resource => using}
import PCRelativeASMInstructionModule._
/**
* Assembly program writer. Dumps a program to a textual file, in
* which each line is composed of an encoded instruction represented
* as a 32-bit hexadecimal value, followed by a textual
* representation of the instruction.
*
* @author Michel Schinz <Michel.Schinz@epfl.ch>
*/
object ASMFileWriter extends (String => (Program => TerminalPhaseResult)) {
def apply(fileName: String): (Program => TerminalPhaseResult) = { program =>
using(new PrintWriter(newBufferedWriter(Path.of(fileName)))) { outWriter =>
for ((instr, index) <- program.zipWithIndex) {
val address = index * 4
val target = targetAddress(address, instr)
.map(a => f" [target address: ${a}%04x]")
.getOrElse("")
outWriter.println(
f"${encode(instr)}%08x | ${address}%04x: ${instr}%-20s${target}%s"
.trim)
}
}
Right((0, Some(s"Wrote assembly program to ${fileName}")))
}
private object Opcode extends Enumeration {
val ADD, SUB, MUL, DIV, MOD = Value
val LSL, LSR, AND, OR, XOR = Value
val JLT, JLE, JEQ, JNE, JI = Value
val CALL_NI, CALL_ND, CALL_TI, CALL_TD, RET, HALT = Value
val LDLO, LDHI, MOVE = Value
val RALO, BALO, BSIZ, BTAG, BGET, BSET = Value
val BREA, BWRI = Value
}
private def targetAddress(instrAddr: Bits32,
instr: Instruction): Option[Bits32] =
instr match {
case JLT(_, _, d) => Some(instrAddr + 4 * d)
case JLE(_, _, d) => Some(instrAddr + 4 * d)
case JEQ(_, _, d) => Some(instrAddr + 4 * d)
case JNE(_, _, d) => Some(instrAddr + 4 * d)
case JI(d) => Some(instrAddr + 4 * d)
case CALL_ND(d) => Some(instrAddr + 4 * d)
case CALL_TD(d) => Some(instrAddr + 4 * d)
case _ => None
}
private def encode(instr: Instruction): Int = instr match {
case ADD(a, b, c) => packRRR(Opcode.ADD, a, b, c)
case SUB(a, b, c) => packRRR(Opcode.SUB, a, b, c)
case MUL(a, b, c) => packRRR(Opcode.MUL, a, b, c)
case DIV(a, b, c) => packRRR(Opcode.DIV, a, b, c)
case MOD(a, b, c) => packRRR(Opcode.MOD, a, b, c)
case LSL(a, b, c) => packRRR(Opcode.LSL, a, b, c)
case LSR(a, b, c) => packRRR(Opcode.LSR, a, b, c)
case AND(a, b, c) => packRRR(Opcode.AND, a, b, c)
case OR(a, b, c) => packRRR(Opcode.OR, a, b, c)
case XOR(a, b, c) => packRRR(Opcode.XOR, a, b, c)
case JLT(a, b, d) => packRRD(Opcode.JLT, a, b, d)
case JLE(a, b, d) => packRRD(Opcode.JLE, a, b, d)
case JEQ(a, b, d) => packRRD(Opcode.JEQ, a, b, d)
case JNE(a, b, d) => packRRD(Opcode.JNE, a, b, d)
case JI(d) => packD(Opcode.JI, d)
case CALL_NI(r) => packR(Opcode.CALL_NI, r)
case CALL_ND(d) => packD(Opcode.CALL_ND, d)
case CALL_TI(r) => packR(Opcode.CALL_TI, r)
case CALL_TD(d) => packD(Opcode.CALL_TD, d)
case RET(r) => packR(Opcode.RET, r)
case HALT(r) => packR(Opcode.HALT, r)
case LDLO(a, s) =>
pack(encOp(Opcode.LDLO), encSInt(s, 19), encReg(a))
case LDHI(a, u) =>
pack(encOp(Opcode.LDHI), pad(3), encUInt(u, 16), encReg(a))
case MOVE(a, b) => packRR(Opcode.MOVE, a, b)
case RALO(l, o) =>
pack(encOp(Opcode.RALO), pad(11), encUInt(o, 8), encUInt(l, 8))
case BALO(a, b, t) =>
pack(encOp(Opcode.BALO), pad(3), encUInt(t, 8), encReg(b), encReg(a))
case BSIZ(a, b) => packRR(Opcode.BSIZ, a, b)
case BTAG(a, b) => packRR(Opcode.BTAG, a, b)
case BGET(a, b, c) => packRRR(Opcode.BGET, a, b, c)
case BSET(a, b, c) => packRRR(Opcode.BSET, a, b, c)
case BREA(a) => packR(Opcode.BREA, a)
case BWRI(a) => packR(Opcode.BWRI, a)
}
private type BitField = (Int, Int)
private def packD(opcode: Opcode.Value, d: Int): Int =
pack(encOp(opcode), encSInt(d, 27))
private def packR(opcode: Opcode.Value, a: ASMRegister): Int =
pack(encOp(opcode), pad(19), encReg(a))
private def packRR(opcode: Opcode.Value,
a: ASMRegister, b: ASMRegister): Int =
pack(encOp(opcode), pad(11), encReg(b), encReg(a))
private def packRRR(opcode: Opcode.Value,
a: ASMRegister, b: ASMRegister, c: ASMRegister): Int =
pack(encOp(opcode), pad(3), encReg(c), encReg(b), encReg(a))
private def packRRD(opcode: Opcode.Value,
a: ASMRegister, b: ASMRegister, d: Int): Int =
pack(encOp(opcode), encSInt(d, 11), encReg(b), encReg(a))
private def encOp(opcode: Opcode.Value): BitField =
encUInt(opcode.id, 5)
private def encReg(r: ASMRegister): BitField = (r: @unchecked) match {
case ASMRegister(ASMRegisterFile.Cb, i) => encUInt(0 * 32 + i, 8)
case ASMRegister(ASMRegisterFile.Ib, i) => encUInt(1 * 32 + (i - 4), 8)
case ASMRegister(ASMRegisterFile.Lb, i) => encUInt(2 * 32 + i, 8)
case ASMRegister(ASMRegisterFile.Ob, i) => encUInt(7 * 32 + (i - 4), 8)
}
private def encUInt(i: Int, len: Int): BitField = {
require(0 <= i && i < (1 << len))
(i, len)
}
private def encSInt(i: Int, len: Int): BitField = {
require(-(1 << (len - 1)) <= i && i < (1 << (len - 1)))
(i & ((1 << len) - 1), len)
}
private def pad(len: Int): BitField =
encUInt(0, len)
private def pack(values: BitField*): Int = {
var packed: Int = 0
for ((value, length) <- values)
packed = (packed << length) | value
packed
}
}
package l3
/**
* A module for ASM instructions.
*
* @author Michel Schinz <Michel.Schinz@epfl.ch>
*/
trait ASMInstructionModule {
type Label
type Constant
sealed abstract class Instruction
case class ADD(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class SUB(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class MUL(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class DIV(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class MOD(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class LSL(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class LSR(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class AND(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class OR(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class XOR(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class JLT(a: ASMRegister, b: ASMRegister, d: Constant)
extends Instruction
case class JLE(a: ASMRegister, b: ASMRegister, d: Constant)
extends Instruction
case class JEQ(a: ASMRegister, b: ASMRegister, d: Constant)
extends Instruction
case class JNE(a: ASMRegister, b: ASMRegister, d: Constant)
extends Instruction
case class JI(d: Label) extends Instruction
case class CALL_NI(r: ASMRegister) extends Instruction
case class CALL_ND(d: Label) extends Instruction
case class CALL_TI(r: ASMRegister) extends Instruction
case class CALL_TD(d: Label) extends Instruction
case class RET(r: ASMRegister) extends Instruction
case class HALT(r: ASMRegister) extends Instruction
case class LDLO(a: ASMRegister, s: Constant) extends Instruction
case class LDHI(a: ASMRegister, u: Int) extends Instruction
case class MOVE(a: ASMRegister, b: ASMRegister) extends Instruction
case class RALO(l: Int, o: Int) extends Instruction
case class BALO(a: ASMRegister, b: ASMRegister, t: L3BlockTag)
extends Instruction
case class BSIZ(a: ASMRegister, b: ASMRegister) extends Instruction
case class BTAG(a: ASMRegister, b: ASMRegister) extends Instruction
case class BGET(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class BSET(a: ASMRegister, b: ASMRegister, c: ASMRegister)
extends Instruction
case class BREA(a: ASMRegister) extends Instruction
case class BWRI(a: ASMRegister) extends Instruction
sealed case class LabeledInstruction(labels: Set[Label],
instruction: Instruction) {
override def toString: String =
labels.map(l => s"$l:\n").mkString + " " + instruction
}
def nl(i: Instruction): LabeledInstruction =
LabeledInstruction(Set.empty, i)
def labeled(labels: Set[Label], code: LabeledProgram): LabeledProgram =
code match {
case Seq(LabeledInstruction(labels1, i1), rest @ _*) =>
LabeledInstruction(labels1 ++ labels, i1) +: rest
}
def labeled(label: Label, code: LabeledProgram): LabeledProgram =
labeled(Set(label), code)
type Program = Seq[Instruction]
type LabeledProgram = Seq[LabeledInstruction]
}
/**
* A module for ASM instructions labeled explicitly by a symbol.
*/
object LabeledASMInstructionModule extends ASMInstructionModule {
type Label = Symbol
sealed trait Constant {
override def toString: String = this match {
case LabelC(l) => l.toString
case IntC(v) => v.toString
}
}
case class LabelC(l: Label) extends Constant
case class IntC(v: Int) extends Constant
}
/**
* A module for ASM instructions labeled implicitly by their
* position.
*/
object PCRelativeASMInstructionModule extends ASMInstructionModule {
type Label = Int
type Constant = Int
}
package l3
import PCRelativeASMInstructionModule._
import IO._
/**
* An interpreter for the ASM language.
*
* @author Michel Schinz <Michel.Schinz@epfl.ch>
*/
object ASMInterpreter extends (Seq[Instruction] => TerminalPhaseResult) {
def apply(program: Seq[Instruction]): TerminalPhaseResult =
try {
Right((interpret(program.toArray), None))
} catch {
case e: EvalError => Left(e.msg)
}
private class EvalError(val msg: String) extends Exception
private def interpret(program: Array[Instruction]): Int = {
import scala.language.implicitConversions
import ASMRegisterFile.{ Ib, Lb, Ob }
var PC: Bits32 = 0
def error(msg: String): Nothing =
throw new EvalError(s"${msg} (at PC = ${PC})")
implicit def bitsToValue(i: Bits32): Value = BitsV(i)
implicit def valueToBits(v: Value): Bits32 = v match {
case BitsV(i) => i
case BlockV(a, _, _) => a
case _ => error(s"expected bits, found $v")
}
implicit def valueToBlock(v: Value): BlockV = v match {
case b: BlockV => b
case _ => error(s"expected block, found $v")
}
trait Value
case class BitsV(value: Bits32) extends Value
case class BlockV(addr: Bits32, tag: L3BlockTag, contents: Array[Value])
extends Value
case object UndefV extends Value
var nextBlockAddr = 0
def allocBlock(tag: L3BlockTag, size: Bits32): BlockV = {
nextBlockAddr += 4
BlockV(nextBlockAddr, tag, Array.fill(size)(UndefV))
}
val I0 = ASMRegisterFile.in(0)
val I1 = ASMRegisterFile.in(1)
val I2 = ASMRegisterFile.in(2)
val I3 = ASMRegisterFile.in(3)
val O0 = ASMRegisterFile.out(0)
val O1 = ASMRegisterFile.out(1)
val O2 = ASMRegisterFile.out(2)
val O3 = ASMRegisterFile.out(3)
val O4 = ASMRegisterFile.out(4)
object R {
private val Cb: Value = {
val constBlock = allocBlock(BlockTag.RegisterFrame.id, 32)
for (i <- 0 until 32)
constBlock.contents(i) = i
constBlock
}
private var Ib: Value = UndefV
private var Lb: Value = UndefV
private var Ob: Value = UndefV
private def checkedContents(r: ASMRegister): Array[Value] = {
val contents = this(r.base).contents
if (0 <= r.index && r.index <= contents.length)
contents
else
error(s"unmapped register: $r")
}
def apply(r: ASMBaseRegister): Value = (r: @unchecked) match {
case ASMRegisterFile.Cb => Cb
case ASMRegisterFile.Ib => Ib
case ASMRegisterFile.Lb => Lb
case ASMRegisterFile.Ob => Ob
}
def apply(r: ASMRegister): Value =
checkedContents(r)(r.index)
def update(r: ASMRegister, v: Value): Unit =
checkedContents(r)(r.index) = v
def update(r: ASMBaseRegister, v: Value): Unit = (r: @unchecked) match {
case ASMRegisterFile.Ib => Ib = v
case ASMRegisterFile.Lb => Lb = v
case ASMRegisterFile.Ob => Ob = v
}
}
def call(targetPc: Bits32,
savedIb: Value,
savedLb: Value,
savedOb: Value,
retPc: Bits32): Unit = {
// save caller context
R(O0) = savedIb
R(O1) = savedLb
R(O2) = savedOb
R(O3) = retPc
// initialize callee context
R(Ib) = R(Ob)
R(Lb) = UndefV
R(Ob) = UndefV
PC = targetPc
}
def ret(retValue: Value): Unit = {
PC = R(I3)
R(Ob) = R(I2)
R(Lb) = R(I1)
R(Ib) = R(I0)
R(O4) = retValue
}
while (true) {
program(PC) match {
case ADD(a, b, c) =>
R(a) = R(b) + R(c)
PC += 1
case SUB(a, b, c) =>
R(a) = R(b) - R(c)
PC += 1
case MUL(a, b, c) =>
R(a) = R(b) * R(c)
PC += 1
case DIV(a, b, c) =>
R(a) = R(b) / R(c)
PC += 1
case MOD(a, b, c) =>
R(a) = R(b) % R(c)
PC += 1
case LSL(a, b, c) =>
R(a) = R(b) << R(c)
PC += 1
case LSR(a, b, c) =>
R(a) = R(b) >> R(c)
PC += 1
case AND(a, b, c) =>
R(a) = R(b) & R(c)
PC += 1
case OR(a, b, c) =>
R(a) = R(b) | R(c)
PC += 1
case XOR(a, b, c) =>
R(a) = R(b) ^ R(c)
PC += 1
case JLT(a, b, d) =>
PC += (if (R(a) < R(b)) d else 1)
case JLE(a, b, d) =>
PC += (if (R(a) <= R(b)) d else 1)
case JEQ(a, b, d) =>
PC += (if (R(a) == R(b)) d else 1)
case JNE(a, b, d) =>
PC += (if (R(a) != R(b)) d else 1)
case JI(d) =>
PC += d
case CALL_NI(a) =>
call(R(a) >> 2, R(Ib), R(Lb), R(Ob), PC + 1)
case CALL_ND(d) =>
call(PC + d, R(Ib), R(Lb), R(Ob), PC + 1)
case CALL_TI(a) =>
call(R(a) >> 2, R(I0), R(I1), R(I2), R(I3))
case CALL_TD(d) =>
call(PC + d, R(I0), R(I1), R(I2), R(I3))
case RET(a) =>
ret(R(a))
case HALT(a) =>
return R(a)
case LDLO(a, s) =>
R(a) = s
PC += 1
case LDHI(a, u) =>
R(a) = (u << 16) | (R(a) & 0xFFFF)
PC += 1
case MOVE(a, b) =>
R(a) = R(b)
PC += 1
case RALO(l, o) =>
if (l > 0) R(Lb) = allocBlock(BlockTag.RegisterFrame.id, l)
if (o > 0) R(Ob) = allocBlock(BlockTag.RegisterFrame.id, o)
PC += 1
case BALO(a, b, t) =>
R(a) = allocBlock(t, R(b))
PC += 1
case BSIZ(a, b) =>
R(a) = R(b).contents.length
PC += 1
case BTAG(a, b) =>
R(a) = R(b).tag
PC += 1
case BGET(a, b, c) =>
R(a) = R(b).contents(R(c))
PC += 1
case BSET(a, b, c) =>
R(b).contents(R(c)) = R(a)
PC += 1
case BREA(a) =>
R(a) = readByte()
PC += 1
case BWRI(a) =>
writeByte(R(a))
PC += 1
}
}
0 // should not be needed
}
}
package l3
import l3.{ LabeledASMInstructionModule => L }
import l3.{ PCRelativeASMInstructionModule => R }
/**
* Label resolution for the ASM language. Translates a program in
* which addresses are represented as symbolic labels to one where
* they are represented as PC-relative (or absolute, in some cases)
* addresses.
*
* @author Michel Schinz <Michel.Schinz@epfl.ch>
*/
object ASMLabelResolver extends (L.LabeledProgram => R.Program) {
def apply(labeledProgram: L.LabeledProgram): R.Program =
resolve(fixedPoint(labeledProgram)(expand))
private def expand(program: L.LabeledProgram): L.LabeledProgram = {
val indexedProgram = program.zipWithIndex
val labelAddr = labelMap(indexedProgram)
(for ((labeledInstr, addr) <- indexedProgram) yield {
def delta(l: L.Label): Int = labelAddr(l) - addr
val expanded = labeledInstr.instruction match {
case L.JLT(a, b, L.LabelC(l)) if !fitsInNSignedBits(11)(delta(l)) =>
Seq(L.nl(L.JLE(b, a, L.IntC(2))), L.nl(L.JI(l)))
case L.JLE(a, b, L.LabelC(l)) if !fitsInNSignedBits(11)(delta(l)) =>
Seq(L.nl(L.JLT(b, a, L.IntC(2))), L.nl(L.JI(l)))
case L.JEQ(a, b, L.LabelC(l)) if !fitsInNSignedBits(11)(delta(l)) =>
Seq(L.nl(L.JNE(a, b, L.IntC(2))), L.nl(L.JI(l)))
case L.JNE(a, b, L.LabelC(l)) if !fitsInNSignedBits(11)(delta(l)) =>
Seq(L.nl(L.JEQ(a, b, L.IntC(2))), L.nl(L.JI(l)))
// TODO: LDLO
case _ => Seq(labeledInstr)
}
L.labeled(labeledInstr.labels, expanded)
}).flatten
}
private def resolve(program: L.LabeledProgram): R.Program = {
val indexedProgram = program.zipWithIndex
val labelAddr = labelMap(indexedProgram)
for ((labeledInstr, addr) <- indexedProgram) yield {
def delta(l: L.Label): Int = labelAddr(l) - addr
labeledInstr.instruction match {
case L.ADD(a, b, c) => R.ADD(a, b, c)
case L.SUB(a, b, c) => R.SUB(a, b, c)
case L.MUL(a, b, c) => R.MUL(a, b, c)
case L.DIV(a, b, c) => R.DIV(a, b, c)
case L.MOD(a, b, c) => R.MOD(a, b, c)
case L.LSL(a, b, c) => R.LSL(a, b, c)