// Interpreter.scala
package amyc
package interpreter

import utils._
import ast.SymbolicTreeModule._
import ast.Identifier
import analyzer.SymbolTable

// An interpreter for Amy programs, implemented in Scala
object Interpreter extends Pipeline[(Program, SymbolTable), Unit] {

  // A class that represents a value computed by interpreting an expression.
  // The as* accessors are unchecked downcasts: they throw a ClassCastException
  // when the value is not of the requested kind.
  abstract class Value {
    def asInt: Int = this.asInstanceOf[IntValue].i
    def asBoolean: Boolean = this.asInstanceOf[BooleanValue].b
    def asString: String = this.asInstanceOf[StringValue].s

    // Renders the value the way it is printed to the user.
    override def toString: String = this match {
      case IntValue(i) => i.toString
      case BooleanValue(b) => b.toString
      case StringValue(s) => s
      case UnitValue => "()"
      case CaseClassValue(constructor, args) =>
        constructor.name + "(" + args.map(_.toString).mkString(", ") + ")"
    }
  }
  case class IntValue(i: Int) extends Value
  case class BooleanValue(b: Boolean) extends Value
  case class StringValue(s: String) extends Value
  case object UnitValue extends Value
  case class CaseClassValue(constructor: Identifier, args: List[Value]) extends Value

  // Interprets the optional top-level expression of every module of the program, in order.
  def run(ctx: Context)(v: (Program, SymbolTable)): Unit = {
    val (program, table) = v

    // These built-in functions do not have an Amy implementation in the program,
    // instead their implementation is encoded in this map
    val builtIns: Map[(String, String), (List[Value]) => Value] = Map(
      ("Std", "printInt")    -> { args => println(args.head.asInt); UnitValue },
      ("Std", "printString") -> { args => println(args.head.asString); UnitValue },
      ("Std", "readString")  -> { args => StringValue(scala.io.StdIn.readLine()) },
      ("Std", "readInt")     -> { args =>
        val input = scala.io.StdIn.readLine()
        try {
          IntValue(input.toInt)
        } catch {
          case ne: NumberFormatException =>
            ctx.reporter.fatal(s"""Could not parse "$input" to Int""")
        }
      },
      ("Std", "intToString")   -> { args => StringValue(args.head.asInt.toString) },
      ("Std", "digitToString") -> { args => StringValue(args.head.asInt.toString) }
    )

    // Utility functions to interface with the symbol table.
    def isConstructor(name: Identifier) = table.getConstructor(name).isDefined
    def findFunctionOwner(functionName: Identifier) = table.getFunction(functionName).get.owner.name
    def findFunction(owner: String, name: String) = {
      program.modules.find(_.name.name == owner).get.defs.collectFirst {
        case fd@FunDef(fn, _, _, _) if fn.name == name => fd
      }.get
    }

    // Interprets an expression, using the evaluations of the local variables contained in 'locals'.
    // Runtime failures (division by zero, match failure, Amy `error`) are reported
    // through ctx.reporter.fatal.
    def interpret(expr: Expr)(implicit locals: Map[Identifier, Value]): Value = {
      expr match {
        case Variable(name) =>
          locals(name)
        case IntLiteral(i) =>
          IntValue(i)
        case BooleanLiteral(b) =>
          BooleanValue(b)
        case StringLiteral(s) =>
          StringValue(s)
        case UnitLiteral() =>
          UnitValue
        case Plus(lhs, rhs) =>
          IntValue(interpret(lhs).asInt + interpret(rhs).asInt)
        case Minus(lhs, rhs) =>
          IntValue(interpret(lhs).asInt - interpret(rhs).asInt)
        case Times(lhs, rhs) =>
          IntValue(interpret(lhs).asInt * interpret(rhs).asInt)
        case Div(lhs, rhs) =>
          val dividend = interpret(lhs).asInt
          val divisor = interpret(rhs).asInt
          // Report a proper interpreter error instead of leaking a JVM ArithmeticException.
          if (divisor == 0) ctx.reporter.fatal("Division by zero")
          else IntValue(dividend / divisor)
        case Mod(lhs, rhs) =>
          val dividend = interpret(lhs).asInt
          val divisor = interpret(rhs).asInt
          if (divisor == 0) ctx.reporter.fatal("Division by zero")
          else IntValue(dividend % divisor)
        case LessThan(lhs, rhs) =>
          BooleanValue(interpret(lhs).asInt < interpret(rhs).asInt)
        case LessEquals(lhs, rhs) =>
          BooleanValue(interpret(lhs).asInt <= interpret(rhs).asInt)
        case And(lhs, rhs) =>
          // Scala's && is short-circuiting: rhs is only interpreted when lhs is true.
          BooleanValue(interpret(lhs).asBoolean && interpret(rhs).asBoolean)
        case Or(lhs, rhs) =>
          // Scala's || is short-circuiting: rhs is only interpreted when lhs is false.
          BooleanValue(interpret(lhs).asBoolean || interpret(rhs).asBoolean)
        case Equals(lhs, rhs) =>
          // Amy equality: Int, Boolean and Unit are compared by value,
          // everything else (strings and case class instances) by reference.
          (interpret(lhs), interpret(rhs)) match {
            case (IntValue(i1), IntValue(i2)) => BooleanValue(i1 == i2)
            case (BooleanValue(b1), BooleanValue(b2)) => BooleanValue(b1 == b2)
            case (UnitValue, UnitValue) => BooleanValue(true)
            case (v1, v2) => BooleanValue(v1 eq v2)
          }
        case Concat(lhs, rhs) =>
          StringValue(interpret(lhs).asString + interpret(rhs).asString)
        case Not(e) =>
          BooleanValue(!interpret(e).asBoolean)
        case Neg(e) =>
          IntValue(-interpret(e).asInt)
        case Call(qname, args) =>
          // Arguments are always evaluated left to right in the caller's environment.
          val argValues = args.map(arg => interpret(arg))
          if (isConstructor(qname)) {
            CaseClassValue(qname, argValues)
          } else {
            val owner = findFunctionOwner(qname)
            builtIns.get((owner, qname.name)) match {
              case Some(builtIn) =>
                builtIn(argValues)
              case None =>
                // NOTE(review): assumes FunDef exposes `params` and `body`, and that each
                // parameter definition exposes `name` — confirm against the AST definitions.
                val fd = findFunction(owner, qname.name)
                // The callee starts from a fresh environment holding only its parameters:
                // the caller's locals must not be visible inside the function body.
                interpret(fd.body)(fd.params.map(_.name).zip(argValues).toMap)
            }
          }
        case Sequence(e1, e2) =>
          interpret(e1)
          interpret(e2)
        case Let(df, value, body) =>
          // NOTE(review): assumes the definition `df` exposes the bound identifier as `name`.
          interpret(body)(locals + (df.name -> interpret(value)))
        case Ite(cond, thenn, elze) =>
          if (interpret(cond).asBoolean) interpret(thenn) else interpret(elze)
        case Match(scrut, cases) =>
          val evS = interpret(scrut)

          // Returns a list of pairs id -> value,
          // where id has been bound to value within the pattern.
          // Returns None when the pattern fails to match.
          // Note: Only works on well typed patterns (which have been ensured by the type checker).
          def matchesPattern(v: Value, pat: Pattern): Option[List[(Identifier, Value)]] = {
            ((v, pat): @unchecked) match {
              case (_, WildcardPattern()) =>
                Some(Nil)
              case (_, IdPattern(name)) =>
                Some(List(name -> v))
              case (IntValue(i1), LiteralPattern(IntLiteral(i2))) =>
                if (i1 == i2) Some(Nil) else None
              case (BooleanValue(b1), LiteralPattern(BooleanLiteral(b2))) =>
                if (b1 == b2) Some(Nil) else None
              case (StringValue(_), LiteralPattern(StringLiteral(_))) =>
                // Strings are compared by reference, so a value never equals a pattern literal.
                None
              case (UnitValue, LiteralPattern(UnitLiteral())) =>
                Some(Nil)
              case (CaseClassValue(con1, realArgs), CaseClassPattern(con2, formalArgs)) =>
                if (con1 != con2) {
                  None
                } else {
                  // Match the sub-patterns left to right, accumulating their bindings;
                  // once one sub-pattern fails the remaining ones are not examined.
                  realArgs.zip(formalArgs).foldLeft(Option(List.empty[(Identifier, Value)])) {
                    case (Some(bound), (arg, subPat)) => matchesPattern(arg, subPat).map(bound ++ _)
                    case (None, _) => None
                  }
                }
            }
          }

          // Main "loop" of the implementation: Go through every case,
          // check if the pattern matches, and if so return the evaluation of the case expression.
          // The LazyList makes sure patterns after the first matching one are never tried.
          cases.to(LazyList).map { case MatchCase(pat, rhs) =>
            (rhs, matchesPattern(evS, pat))
          }.find(_._2.isDefined) match {
            case Some((rhs, Some(moreLocals))) =>
              interpret(rhs)(locals ++ moreLocals)
            case _ =>
              // No case matched
              ctx.reporter.fatal(s"Match error: ${evS.toString}@${scrut.position}")
          }

        case Error(msg) =>
          ctx.reporter.fatal(interpret(msg).asString)
      }
    }

    for {
      m <- program.modules
      e <- m.optExpr
    } {
      interpret(e)(Map())
    }
  }
}