Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • shchen/cs320
  • raveendr/cs320
  • mwojnaro/cs320
3 results
Show changes
Showing
with 2692 additions and 0 deletions
// sbt build definition for the calculator playground project.
// Scala compiler version used to build the project.
scalaVersion := "3.5.2"
version := "1.0.0"
organization := "ch.epfl.lara"
organizationName := "LARA"
name := "calculator"
// ScalaTest is needed only by the test sources, hence the "test" configuration.
libraryDependencies ++= Seq("org.scalatest" %% "scalatest" % "3.2.10" % "test")
\ No newline at end of file
File added
sbt.version=1.10.7
/* Copyright 2020 EPFL, Lausanne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package calculator
import scallion.*
import silex.*
// Tokens produced by the calculator lexer.
sealed trait Token
// An integer literal, e.g. 42.
case class NumberToken(value: Int) extends Token
// One of the operator characters: - + / * !
case class OperatorToken(operator: Char) extends Token
// "(" when isOpen is true, ")" otherwise.
case class ParenthesisToken(isOpen: Boolean) extends Token
// A run of whitespace; filtered out before parsing (see CalcLexer.apply).
case object SpaceToken extends Token
// Input that matched no lexing rule; carries the offending text.
case class UnknownToken(content: String) extends Token
// Lexer for the calculator language, built with the Silex/Scallion lexing
// DSL. Converts a raw input string into an iterator of Tokens, dropping
// whitespace and mapping unrecognized input to UnknownToken.
object CalcLexer extends Lexers with CharLexers {
// Positions are not tracked in this small example.
type Position = Unit
type Token = calculator.Token
val lexer = Lexer(
// Operators
oneOf("-+/*!")
|> { cs => OperatorToken(cs.head) },
// Parentheses
elem('(') |> ParenthesisToken(true),
elem(')') |> ParenthesisToken(false),
// Spaces
many1(whiteSpace) |> SpaceToken,
// Numbers
{
elem('0') |
nonZero ~ many(digit)
}
|> { cs => NumberToken(cs.mkString.toInt) }
) onError {
// Any input the rules above reject becomes an UnknownToken.
(cs, _) => UnknownToken(cs.mkString)
}
// Tokenizes the given string; SpaceTokens are filtered out so the parser
// never sees them.
def apply(it: String): Iterator[Token] = {
val source = Source.fromString(it, NoPositioner)
val tokens = lexer(source)
tokens.filter((token: Token) => token != SpaceToken)
}
}
// Token kinds abstract over concrete token values: the parser is driven by
// kinds, not by the tokens themselves. The text is used for display only.
sealed abstract class TokenKind(text: String) {
override def toString = text
}
case object NumberClass extends TokenKind("<number>")
// One kind per operator character, so the grammar can select operators.
case class OperatorClass(op: Char) extends TokenKind(op.toString)
case class ParenthesisClass(isOpen: Boolean) extends TokenKind(if (isOpen) "(" else ")")
// Fallback kind for tokens the grammar never accepts (e.g. UnknownToken).
case object OtherClass extends TokenKind("?")
// Abstract syntax tree of calculator expressions.
sealed abstract class Expr
// An integer literal.
case class LitExpr(value: Int) extends Expr
// A binary operation, e.g. BinaryExpr('+', l, r) for l + r.
case class BinaryExpr(op: Char, left: Expr, right: Expr) extends Expr
// A unary (postfix) operation; only '!' is produced by the parser.
case class UnaryExpr(op: Char, inner: Expr) extends Expr
// LL(1) parser for calculator expressions, built with Scallion parser
// combinators. The grammar is the classic layered expression grammar:
//   expr   ::= term (('+' | '-') term)*
//   term   ::= factor (('*' | '/') factor)*
//   factor ::= basic '!'?
//   basic  ::= <number> | '(' expr ')'
object CalcParser extends Parsers {
type Token = calculator.Token
type Kind = calculator.TokenKind
import Implicits._
// Maps each concrete token to the kind the grammar is written against.
override def getKind(token: Token): TokenKind = token match {
case NumberToken(_) => NumberClass
case OperatorToken(c) => OperatorClass(c)
case ParenthesisToken(o) => ParenthesisClass(o)
case _ => OtherClass
}
// A number token, turned into a literal expression.
val number: Syntax[Expr] = accept(NumberClass) {
case NumberToken(n) => LitExpr(n)
}
// Accepts the given binary operator and produces its character.
def binOp(char: Char): Syntax[Char] = accept(OperatorClass(char)) {
case _ => char
}
val plus = binOp('+')
val minus = binOp('-')
val times = binOp('*')
val div = binOp('/')
// Postfix factorial operator.
val fac: Syntax[Char] = accept(OperatorClass('!')) {
case _ => '!'
}
def parens(isOpen: Boolean) = elem(ParenthesisClass(isOpen))
val open = parens(true)
val close = parens(false)
// Top-level syntax; foldLeft makes '+'/'-' left-associative.
lazy val expr: Syntax[Expr] = recursive {
(term ~ moreTerms).map {
case first ~ opNexts => opNexts.foldLeft(first) {
case (acc, op ~ next) => BinaryExpr(op, acc, next)
}
}
}
// A product of factors; foldLeft makes '*'/'/' left-associative.
lazy val term: Syntax[Expr] = (factor ~ moreFactors).map {
case first ~ opNexts => opNexts.foldLeft(first) {
case (acc, op ~ next) => BinaryExpr(op, acc, next)
}
}
// Zero or more ('+' | '-') term continuations (right-recursive to stay LL(1)).
lazy val moreTerms: Syntax[Seq[Char ~ Expr]] = recursive {
epsilon(Seq.empty[Char ~ Expr]) |
((plus | minus) ~ term ~ moreTerms).map {
case op ~ t ~ ots => (op ~ t) +: ots
}
}
// A basic expression optionally followed by '!'.
lazy val factor: Syntax[Expr] = (basic ~ fac.opt).map {
case e ~ None => e
case e ~ Some(op) => UnaryExpr(op, e)
}
// Zero or more ('*' | '/') factor continuations (right-recursive to stay LL(1)).
lazy val moreFactors: Syntax[Seq[Char ~ Expr]] = recursive {
epsilon(Seq.empty[Char ~ Expr]) |
((times | div) ~ factor ~ moreFactors).map {
case op ~ t ~ ots => (op ~ t) +: ots
}
}
// Either a number or a parenthesized expression (parentheses discarded).
lazy val basic: Syntax[Expr] = number | open.skip ~ expr ~ close.skip
// Or, using operators...
//
// lazy val expr: Syntax[Expr] = recursive {
// operators(factor)(
// (times | div).is(LeftAssociative),
// (plus | minus).is(LeftAssociative)
// ) {
// case (l, op, r) => BinaryExpr(op, l, r)
// }
// }
//
// Then, you can get rid of term, moreTerms, and moreFactors.
// Runs the parser on a token stream; None on any parse failure.
def apply(tokens: Iterator[Token]): Option[Expr] = Parser(expr)(tokens).getValue
}
object Main {
  /** Interactive REPL: reads expressions from standard input and prints
    * their syntax trees, until an empty line or end-of-input is reached.
    */
  def main(args: Array[String]): Unit = {
    // Refuse to run with a non-LL(1) grammar; print the conflicts instead.
    if (!CalcParser.expr.isLL1) {
      CalcParser.debug(CalcParser.expr, false)
      return
    }
    println("Welcome to the awesome calculator expression parser.")
    while (true) {
      print("Enter an expression: ")
      val line = scala.io.StdIn.readLine()
      // readLine returns null on end-of-input (e.g. Ctrl-D or piped input);
      // the previous code threw a NullPointerException on line.isEmpty.
      if (line == null || line.isEmpty) {
        return
      }
      CalcParser(CalcLexer(line)) match {
        case None => println("Could not parse your line...")
        case Some(parsed) => println("Syntax tree: " + parsed)
      }
    }
  }
}
/* Copyright 2019 EPFL, Lausanne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package calculator
import org.scalatest._
import flatspec._
// Unit tests for the calculator lexer and parser.
class Tests extends AnyFlatSpec with Inside {
// The parser must be LL(1), otherwise parse results are unspecified.
"Parser" should "be LL(1)" in {
assert(CalcParser.expr.isLL1)
}
// End-to-end check: lex and parse a mixed-precedence expression and
// verify the shape of the resulting tree.
it should "be able to parse some strings" in {
val result = CalcParser(CalcLexer("1 + 3 * (5! / 7) + 42"))
assert(result.nonEmpty)
val parsed = result.get
// Expected: ((1 + (3 * ((5!) / 7))) + 42), since + is left-associative
// and * / bind tighter than + -.
inside(parsed) {
case BinaryExpr('+', BinaryExpr('+', one, mult), fortytwo) => {
assert(one == LitExpr(1))
assert(fortytwo == LitExpr(42))
inside(mult) {
case BinaryExpr('*', three, BinaryExpr('/', UnaryExpr('!', five), seven)) => {
assert(three == LitExpr(3))
assert(five == LitExpr(5))
assert(seven == LitExpr(7))
}
}
}
}
}
}
\ No newline at end of file
**For a brief overview of Scallion and its purpose, you can watch [this
video](https://mediaspace.epfl.ch/media/0_lypn7l0x).** What follows below is
a slightly more detailed description, and an example project you can use
to familiarize yourself with Scallion.
## Introduction to Parser Combinators
The next part of the compiler you will be working on is the parser. The
goal of the parser is to convert the sequence of tokens generated by the
lexer into an Amy *abstract syntax tree* (AST).
There are many approaches to writing parsers, such as:
- Writing the parser by hand directly in the compiler's language using
mutually recursive functions, or
- Writing the parser in a *domain specific language* (DSL) and using a
parser generator (such as Bison) to produce the parser.
Another approach, which we will be using, is *parser combinators*. The
idea behind the approach is very simple:
- Have a set of simple primitive parsers, and
- Have ways to combine them together into more and more complex
parsers. Hence the name *parser combinators*.
Usually, those primitive parsers and combinators are provided as a
library directly in the language used by the compiler. In our case, we
will be working with **Scallion**, a Scala parser combinators library
developed by *LARA*.
Parser combinators have many advantages -- the main one being easy to
write, read and maintain.
## Scallion Parser Combinators
### Documentation
In this document, we will introduce parser combinators in Scallion and
showcase how to use them. This document is not intended to be a complete
reference to Scallion. Fortunately, the library comes with a
[comprehensive
API](https://epfl-lara.github.io/scallion) which
fulfills that role. Feel free to refer to it while working on your
project!
### Playground Project
We have set up [an example project](scallion-playground) that
implements a lexer and parser for a simple expression language using
Scallion. Feel free to experiment and play with it. The project
showcases the API of Scallion and some of the more advanced combinators.
### Setup
In Scallion, parsers are defined within a trait called `Syntaxes`. This
trait takes as parameters two types:
- The type of tokens,
- The type of *token kinds*. Token kinds represent groups of tokens.
They abstract away all the details found in the actual tokens, such
as for instance positions or identifier names. Each token has a
unique kind.
In our case, the tokens will be of type `Token` that we introduced and
used in the previous project. The token kinds will be `TokenKind`, which
we have already defined for you.
object Parser extends Pipeline[Iterator[Token], Program]
with Parsers {
type Token = myproject.Token
type Kind = myproject.TokenKind
// Indicates the kind of the various tokens.
override def getKind(token: Token): TokenKind = TokenKind.of(token)
// Your parser implementation goes here.
}
The `Parsers` trait (mixed into the `Parser` object above) comes from
Scallion and provides all functions and types you will use to define
your grammar and AST translation.
### Writing Parsers
When writing a parser using parser combinators, one defines many smaller
parsers and combines them together into more and more complex parsers.
The top-level, most complex, of those parsers then defines the entire
syntax for the language. In our case, that top-level parser will be
called `program`.
All those parsers are objects of the type `Syntax[A]`. The type
parameter `A` indicates the type of values produced by the parser. For
instance, a parser of type `Syntax[Int]` produces `Int`s and a parser of
type `Syntax[Expr]` produces `Expr`s. Our top-level parser has the
following signature:
lazy val program: Syntax[Program] = ...
Contrary to the types of tokens and token kinds, which are fixed, the
type of values produced is a type parameter of the various `Syntax`s.
This allows your different parsers to produce different types of values.
The various parsers are stored as `val` members of the `Parser` object.
In the case of mutually dependent parsers, we use `lazy val` instead.
lazy val definition: Syntax[ClassOrFunDef] =
functionDefinition | abstractClassDefinition | caseClassDefinition
lazy val functionDefinition: Syntax[ClassOrFunDef] = ...
lazy val abstractClassDefinition: Syntax[ClassOrFunDef] = ...
lazy val caseClassDefinition: Syntax[ClassOrFunDef] = ...
### Running Parsers
Parsers of type `Syntax[A]` can be converted to objects of type
`Parser[A]`, which have an `apply` method which takes as parameter an
iterator of tokens and returns a value of type `ParseResult[A]`, which
can be one of three things:
- A `Parsed(value, rest)`, which indicates that the parser was
successful and produced the value `value`. The entirety of the input
iterator was consumed by the parser.
- An `UnexpectedToken(token, rest)`, which indicates that the parser
encountered an unexpected token `token`. The input iterator was
consumed up to the erroneous token.
- An `UnexpectedEnd(rest)`, which indicates that the end of the
iterator was reached and the parser could not finish at this point.
The input iterator was completely consumed.
In each case, the additional value `rest` is itself some sort of a
`Parser[A]`. That parser represents the parser after the successful
parse or at the point of error. This parser could be used to provide
useful error messages or even to resume parsing.
override def run(ctx: Context)(tokens: Iterator[Token]): Program = {
import ctx.reporter._
val parser = Parser(program)
parser(tokens) match {
case Parsed(result, rest) => result
case UnexpectedEnd(rest) => fatal("Unexpected end of input.")
case UnexpectedToken(token, rest) => fatal("Unexpected token: " + token)
}
}
### Parsers and Grammars
As you will see, parsers built using parser combinators will look a lot
like grammars. However, unlike grammars, parsers not only describe the
syntax of your language, but also directly specify how to turn this
syntax into a value. Also, as we will see, parser combinators have a
richer vocabulary than your usual *BNF* grammars.
Interestingly, a lot of concepts that you have seen on grammars, such as
`FIRST` sets and nullability can be straightforwardly transposed to
parsers.
#### FIRST set
In Scallion, parsers offer a `first` method which returns the set of
token kinds that are accepted as a first token.
definition.first === Set(def, abstract, case)
#### Nullability
Parsers have a `nullable` method which checks for nullability of a
parser. The method returns `Some(value)` if the parser would produce
`value` given an empty input token sequence, and `None` if the parser
would not accept the empty sequence.
### Basic Parsers
We can now finally have a look at the toolbox we have at our disposition
to build parsers, starting from the basic parsers. Each parser that you
will write, however complex, is a combination of these basic parsers.
The basic parsers play the same role as terminal symbols do in grammars.
#### Elem
The first of the basic parsers is `elem(kind)`. The function `elem`
takes as argument the kind of tokens to be accepted by the parser. The
value produced by the parser is the token that was matched. For
instance, here is how to match against the *end-of-file* token.
val eof: Syntax[Token] = elem(EOFKind)
#### Accept
The function `accept` is a variant of `elem` which directly applies a
transformation to the matched token when it is produced.
val identifier: Syntax[String] = accept(IdentifierKind) {
case IdentifierToken(name) => name
}
#### Epsilon
The parser `epsilon(value)` is a parser that produces the `value`
without consuming any input. It corresponds to the *𝛆* found in
grammars.
### Parser Combinators
In this section, we will see how to combine parsers together to create
more complex parsers.
#### Disjunction
The first combinator we have is disjunction, that we write, for parsers
`p1` and `p2`, simply `p1 | p2`. When both `p1` and `p2` are of type
`Syntax[A]`, the disjunction `p1 | p2` is also of type `Syntax[A]`. The
disjunction operator is associative and commutative.
Disjunction works just as you think it does. If either of the parsers
`p1` or `p2` would accept the sequence of tokens, then the disjunction
also accepts the tokens. The value produced is the one produced by
either `p1` or `p2`.
Note that `p1` and `p2` must have disjoint `first` sets. This
restriction ensures that no ambiguities can arise and that parsing can
be done efficiently.[^1] We will see later how to automatically detect
when this is not the case and how to fix the issue.
#### Sequencing
The second combinator we have is sequencing. We write, for parsers `p1`
and `p2`, the sequence of `p1` and `p2` as `p1 ~ p2`. When `p1` is of
type `A` and `p2` of type `B`, their sequence is of type `A ~ B`, which
is simply a pair of an `A` and a `B`.
If the parser `p1` accepts the prefix of a sequence of tokens and `p2`
accepts the postfix, the parser `p1 ~ p2` accepts the entire sequence
and produces the pair of values produced by `p1` and `p2`.
Note that the `first` set of `p2` should be disjoint from the `first`
set of all sub-parsers in `p1` that are *nullable* and in trailing
position (available via the `followLast` method). This restriction
ensures that the combinator does not introduce ambiguities.
#### Transforming Values
The method `map` makes it possible to apply a transformation to the
values produced by a parser. Using `map` does not influence the sequence
of tokens accepted or rejected by the parser, it merely modifies the
value produced. Generally, you will use `map` on a sequence of parsers,
as in:
lazy val abstractClassDefinition: Syntax[ClassOrFunDef] =
(kw("abstract") ~ kw("class") ~ identifier).map {
case kw ~ _ ~ id => AbstractClassDef(id).setPos(kw)
}
The above parser accepts abstract class definitions in Amy syntax. It
does so by accepting the sequence of keywords `abstract` and `class`,
followed by any identifier. The method `map` is used to convert the
produced values into an `AbstractClassDef`. The position of the keyword
`abstract` is used as the position of the definition.
#### Recursive Parsers
It is highly likely that some of your parsers will require to
recursively invoke themselves. In this case, you should indicate that
the parser is recursive using the `recursive` combinator:
lazy val expr: Syntax[Expr] = recursive {
...
}
If you were to omit it, a `StackOverflow` exception would be triggered
during the initialisation of your `Parser` object.
The `recursive` combinator in itself does not change the behaviour of
the underlying parser. It is there to *tie the knot*[^2].
In practice, it is only required in very few places. In order to avoid
`StackOverflow` exceptions during initialisation, you should make sure
that all recursive parsers (stored in `lazy val`s) must not be able to
reenter themselves without going through a `recursive` combinator
somewhere along the way.
#### Other Combinators
So far, many of the combinators that we have seen, such as disjunction
and sequencing, directly correspond to constructs found in `BNF`
grammars. Some of the combinators that we will see now are more
expressive and implement useful patterns.
##### Optional parsers using opt
The combinator `opt` makes a parser optional. The value produced by the
parser is wrapped in `Some` if the parser accepts the input sequence and
in `None` otherwise.
opt(p) === p.map(Some(_)) | epsilon(None)
##### Repetitions using many and many1
The combinator `many` returns a parser that accepts any number of
repetitions of its argument parser, including 0. The variant `many1`
forces the parser to match at least once.
##### Repetitions with separators repsep and rep1sep
The combinator `repsep` returns a parser that accepts any number of
repetitions of its argument parser, separated by an other parser,
including 0. The variant `rep1sep` forces the parser to match at least
once.
The separator parser is restricted to the type `Syntax[Unit]` to ensure
that important values do not get ignored. You may use `unit()` on a
parser to turn its value to `Unit` if you explicitly want to ignore the
values a parser produces.
##### Binary operators with operators
Scallion also contains combinators to easily build parsers for infix
binary operators, with different associativities and priority levels.
This combinator is defined in an additional trait called `Operators`,
which you should mix into `Parsers` if you want to use the combinator.
By default, it should already be mixed-in.
val times: Syntax[String] =
accept(OperatorKind("*")) {
case _ => "*"
}
...
lazy val operation: Syntax[Expr] =
operators(number)(
// Defines the different operators, by decreasing priority.
times | div is LeftAssociative,
plus | minus is LeftAssociative,
...
) {
// Defines how to apply the various operators.
case (lhs, "*", rhs) => Times(lhs, rhs).setPos(lhs)
...
}
Documentation for `operators` is [available on this
page](https://epfl-lara.github.io/scallion/scallion/Operators.html).
##### Upcasting
In Scallion, the type `Syntax[A]` is invariant with `A`, meaning that,
even when `A` is a (strict) subtype of some type `B`, we *won't* have
that `Syntax[A]` is a subtype of `Syntax[B]`. To upcast a `Syntax[A]` to
a syntax `Syntax[B]` (when `A` is a subtype of `B`), you should use the
`.up[B]` method.
For instance, you may need to upcast a syntax of type
`Syntax[Literal[_]]` to a `Syntax[Expr]` in your assignment. To do so,
simply use `.up[Expr]`.
### LL(1) Checking
In Scallion, non-LL(1) parsers can be written, but the result of
applying such a parser is not specified. In practice, we therefore
restrict ourselves only to LL(1) parsers. The reason behind this is that
LL(1) parsers are unambiguous and can be run in time linear in the input
size.
Writing LL(1) parsers is non-trivial. However, some of the higher-level
combinators of Scallion already alleviate part of this pain. In
addition, LL(1) violations can be detected before the parser is run.
Syntaxes have an `isLL1` method which returns `true` if the parser is
LL(1) and `false` otherwise, and so without needing to see any tokens of
input.
#### Conflict Witnesses
In case your parser is not LL(1), the method `conflicts` of the parser
will return the set of all `LL1Conflict`s. The various conflicts are:
- `NullableConflict`, which indicates that two branches of a
disjunction are nullable.
- `FirstConflict`, which indicates that the `first` set of two
branches of a disjunction are not disjoint.
- `FollowConflict`, which indicates that the `first` set of a nullable
parser is not disjoint from the `first` set of a parser that
directly follows it.
The `LL1Conflict`s objects contain fields which can help you pinpoint
the exact location of conflicts in your parser and hopefully help you
fix those.
The helper method `debug` prints a summary of the LL(1) conflicts of a
parser. We added code in the handout skeleton so that, by default, a
report is outputted in case of conflicts when you initialise your
parser.
[^1]: Scallion is not the only parser combinator library to exist, far
from it! Many of those libraries do not have this restriction. Those
libraries generally need to backtrack to try the different
alternatives when a branch fails.
[^2]: See [a good explanation of what tying the knot means in the
context of lazy
languages.](https://stackoverflow.com/questions/357956/explanation-of-tying-the-knot)
## Demo
```
(func $Factorial_f (param i32 i32) (result i32) (local i32)
;;> fn f(i: Int(32), j: Int(32)): Int(32) = {
;;| val res: Int(32) =
;;| (i + j);
;;| res
;;| }
;;> i
local.get 0
;;> j
local.get 1
;;> (i + j)
i32.add
;;> val res: Int(32)
local.set 2
;;> res
local.get 2
)
(func $Factorial_fact (param i32) (result i32)
;;> fn fact(i: Int(32)): Int(32) = {
;;| (if((i < 2)) {
;;| 1
;;| } else {
;;| (i * fact((i - 1)))
;;| })
;;| }
;;> i
local.get 0
;;> 2
i32.const 2
;;> (i < 2)
i32.lt_s
;;> (if((i < 2)) {
;;| 1
;;| } else {
;;| (i * fact((i - 1)))
;;| })
if (result i32)
;;> 1
i32.const 1
else
;;> i
local.get 0
;;> fact((i - 1))
;;> i
local.get 0
;;> 1
i32.const 1
;;> (i - 1)
i32.sub
call $Factorial_fact
;;> (i * fact((i - 1)))
i32.mul
end
)
```
## WASM basics
### Stack machine
- WASM is a stack based machine.
- WASM has types. We will use exclusively i32.
- Instructions can push or pop values from the stack.
- i32.const x : push x to the stack.
- i32.add : pop 2 values, add them and push the result.
- drop : pop a value and ignore it.
- Locals can store values inside a function. Useful for val definitions among others.
- local.get x : get xth local
- local.set x : set xth local
- Globals store program wide values.
- global.get x : get xth global
- global.set x : set xth global
- Control flow.
- if : pop value from stack, if 0 goto else, otherwise continue.
- call : pop arguments from the stack, jump to function.
## Function calls
How to call a function:
- Push the required number of arguments on the stack.
- Call the function. The call instruction will pop the arguments and place them in the locals.
- The result will be placed on top of the stack.
```
(func $f (param i32 i32) (result i32)
local.get 0
local.get 1
i32.add
)
(
i32.const 3 ;; arg 0
i32.const 4 ;; arg 1
;; A
call $f
;; B
)
A:
| |
| 4 | <-- arg 1
| 3 | <-- arg 0
|-------|
B:
| |
| |
| 7 | <-- result
|-------|
```
## Store
```
Store 3 at address 48
| |
| |
| |
|--------| <-- bottom of the stack
`i32.const 48`
| |
| |
| 48 | <-- address
|--------| <-- bottom of the stack
`i32.const 3`
| |
| 3 | <-- value
| 48 | <-- address
|--------| <-- bottom of the stack
`i32.store` pops 2 values from the stack
| |
| |
| |
|--------| <-- bottom of the stack
Heap
| address | 0 | 1 | 2 | .. | 47 | 48 | 49 | .. |
|---------|----|----|----|----|----|----|----|----|
| value | 0 | 0 | 0 | .. | 0 | 3 | 0 | .. |
^
value written
```
## Values
Very similar to java.
- Ints are represented simply with an i32.
- Bools are represented with an i32, false = 0, true = 1.
- Unit is represented with an i32 with value 0.
### Strings
```
| address | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
|---------|----|----|----|----|----|----|----|----|
| value | 104| 101| 108| 108| 111| 33 | 0 | 0 |
| ascii | h | e | l | l | o | ! | \0 | |
| |
| |
| 24 | <-- pointer to string
|--------| <-- bottom of the stack
```
### ADTs
- store the value on the heap to reduce the size to the size of a pointer.
- store which constructor the value holds.
```scala
def getList(): List = { ... }
val ls: List = getList();
// What is the size of list here?
// Is it a Nil or a Cons?
```
```
Cons(42, Nil())
| address | value |
|---------|---------|
| 0 | 1 | \
| 1 | | | constructor id.
| 2 | | | Cons
| 3 | | /
| 4 | 42 | \
| 5 | | | first member: int
| 6 | | | 42
| 7 | | /
| 8 | 1234 | \
| 9 | | | second member: pointer to Nil
| 10 | | | 1234
| 11 | | /
Field offset = 4 + 4 * field number
==> Utils.scala:adtField
```
## Allocation
Utils.scala:memoryBoundary is the index of a global variable that holds a pointer to the next free bytes.
### Example in pseudocode:
Start of the program:
global.set(memoryBoundary, 0)
We want to allocate "hello!" = 7 bytes (don't forget the null terminator).
Store current memory pointer as pointer to our new string:
hello_string = global.get(memoryBoundary)
Increment the memory boundary by 7 (size of string).
global.set(memoryBoundary, global.get(memoryBoundary) + 7)
### With webassembly instructions:
```
;; With memoryBoundary = 0.
;; Load the current boundary for string
global.get 0
;; Load it again for the arithmetic
global.get 0
;; length of string
i32.const 7
;; base + length = new boundary
i32.add
;; store new boundary
global.set 0
;; now the string pointer is on the stack, we just
;; need to copy the character's bytes into it.
...
```
## Pattern matching
A pattern matching expression:
e match {
case p1 => e1
...
case pn => en
}
can be considered to be equivalent to the following pseudocode:
val v = e;
if (matchAndBind(v, p1)) e1
else if (matchAndBind(v, p2)) e2
else if ...
else if (matchAndBind(v, pn)) en
else error("Match error!")
matchAndBind is equivalent to this:
WildcardPattern:
"case _ => ..."
matchAndBind(v, _) = true
IdPattern:
"case id => ..."
matchAndBind(v, id) = { id = v; true }
LiteralPattern:
"case 3 => ..."
matchAndBind(v, lit) = { v == lit }
CaseClassPattern:
"case Cons(x, _) => ..."
matchAndBind(C_1(v_1, ..., v_n), C_2(p_1, ..., p_m)) = {
C_1 == C_2 &&
matchAndBind(v_1, p_1) &&
...
matchAndBind(v_m, p_m)
}
sbt.version=1.10.7
addSbtPlugin("com.lightbend.sbt" % "sbt-proguard" % "0.3.0")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0")
\ No newline at end of file
package amyc
import ast._
import utils._
import parsing._
import analyzer._
import codegen._
import interpreter.Interpreter
import java.io.File
object Main extends MainHelpers {
  /** Parses the command-line flags into a Context; any argument that is
    * not a recognized flag is treated as an input file name.
    */
  private def parseArgs(args: Array[String]): Context = {
    var ctx = Context(new Reporter, Nil)
    args foreach {
      case "--printTokens" => ctx = ctx.copy(printTokens = true)
      case "--printTrees" => ctx = ctx.copy(printTrees = true)
      case "--printNames" => ctx = ctx.copy(printNames = true)
      case "--interpret" => ctx = ctx.copy(interpret = true)
      case "--type-check" => ctx = ctx.copy(typeCheck = true)
      case "--help" => ctx = ctx.copy(help = true)
      case file => ctx = ctx.copy(files = ctx.files :+ file)
    }
    ctx
  }
  /** Entry point: builds the compiler pipeline according to the flags and
    * runs it on the input files; exits with status 1 on a fatal error.
    */
  def main(args: Array[String]): Unit = {
    val ctx = parseArgs(args)
    if (ctx.help) {
      val helpMsg = {
        // Fixed typo in the user-facing message: "namas" -> "names".
        """Welcome to the Amy reference compiler, v.1.5
          |
          |Default behavior is to compile the program to WebAssembly and print the following files:
          |(1) the resulting code in WebAssembly text format (.wat),
          |(2) the resulting code in WebAssembly binary format (.wasm),
          |
          |Options:
          | --printTokens Print lexer tokens (with positions) after lexing and exit
          | --printTrees Print trees after parsing and exit
          | --printNames Print trees with unique names after name analyzer and exit
          | --interpret Interpret the program instead of compiling
          | --type-check Type-check the program and print trees
          | --help Print this message
        """.stripMargin
      }
      println(helpMsg)
      sys.exit(0)
    }
    // Select how far the pipeline runs based on the flags: each --printX
    // flag stops the pipeline after the corresponding phase.
    val pipeline = {
      AmyLexer.andThen(
        if (ctx.printTokens) DisplayTokens
        else Parser.andThen(
          if (ctx.printTrees) treePrinterN("Trees after parsing")
          else NameAnalyzer.andThen(
            if (ctx.printNames) treePrinterS("Trees after name analysis")
            else TypeChecker.andThen(
              if (ctx.typeCheck) then treePrinterS("Trees after type checking")
              else (
                if (ctx.interpret) then Interpreter
                else CodeGen.andThen(CodePrinter))))))}
    val files = ctx.files.map(new File(_))
    try {
      if (files.isEmpty) {
        ctx.reporter.fatal("No input files")
      }
      // Fail fast on the first input file that does not exist.
      files.find(!_.exists()).foreach { f =>
        ctx.reporter.fatal(s"File not found: ${f.getName}")
      }
      pipeline.run(ctx)(files)
      ctx.reporter.terminateIfErrors()
    } catch {
      case AmycFatalError(_) =>
        sys.exit(1)
    }
  }
}
// Helper pipeline stages used by Main: tree printers for the nominal and
// symbolic program representations.
trait MainHelpers {
import SymbolicTreeModule.{Program => SP}
import NominalTreeModule.{Program => NP}
// Pipeline stage that prints a symbolic program (ignoring the symbol
// table part of its input) under the given title.
def treePrinterS(title: String): Pipeline[(SP, SymbolTable), Unit] = {
new Pipeline[(SP, SymbolTable), Unit] {
def run(ctx: Context)(v: (SP, SymbolTable)) = {
println(title)
println(SymbolicPrinter(v._1)(true))
}
}
}
// Pipeline stage that prints a nominal program under the given title.
def treePrinterN(title: String): Pipeline[NP, Unit] = {
new Pipeline[NP, Unit] {
def run(ctx: Context)(v: NP) = {
println(title)
println(NominalPrinter(v))
}
}
}
}
package amyc
package analyzer
import amyc.utils._
import amyc.ast.{Identifier, NominalTreeModule => N, SymbolicTreeModule => S}
// Name analyzer for Amy
// Takes a nominal program (names are plain string, qualified names are string pairs)
// and returns a symbolic program, where all names have been resolved to unique Identifiers.
// Rejects programs that violate the Amy naming rules.
// Also populates symbol table.
object NameAnalyzer extends Pipeline[N.Program, (S.Program, SymbolTable)] {
  def run(ctx: Context)(p: N.Program): (S.Program, SymbolTable) = {
    import ctx.reporter._

    // Step 0: Initialize symbol table
    val table = new SymbolTable

    // Step 1: Add modules. Module names must be unique across the whole program.
    val modNames = p.modules.groupBy(_.name)
    modNames.foreach{ case (name, modules) =>
      if (modules.size > 1) {
        fatal(s"Two modules named $name in program", modules.head.position)
      }
    }
    modNames.keys.toList foreach table.addModule

    // Step 2: Check name uniqueness in modules
    // (a type, constructor or function name may appear only once per module).
    p.modules.foreach { m =>
      val names = m.defs.groupBy(_.name)
      names.foreach{ case (name, defs) =>
        if (defs.size > 1) {
          fatal(s"Two definitions named $name in module ${m.name}", defs.head)
        }
      }
    }

    // Step 3: Discover types (abstract classes) first, so constructors and
    // function signatures below can refer to them.
    for {
      m <- p.modules
      case N.AbstractClassDef(name) <- m.defs
    } table.addType(m.name, name)

    // Resolves a nominal type tree to a symbolic type.
    // An unqualified class name is looked up in `inModule`.
    def transformType(tt: N.TypeTree, inModule: String): S.Type = {
      tt.tpe match {
        case N.IntType => S.IntType
        case N.BooleanType => S.BooleanType
        case N.StringType => S.StringType
        case N.UnitType => S.UnitType
        case N.ClassType(qn@N.QualifiedName(module, name)) =>
          table.getType(module getOrElse inModule, name) match {
            case Some(symbol) =>
              S.ClassType(symbol)
            case None =>
              fatal(s"Could not find type $qn", tt)
          }
      }
    }

    // Step 4: Discover type constructors (case classes).
    for {
      m <- p.modules
      case cc@N.CaseClassDef(name, fields, parent) <- m.defs
    } {
      val argTypes = fields map (tt => transformType(tt, m.name))
      // The parent must be an abstract class declared in the same module.
      val retType = table.getType(m.name, parent).getOrElse(fatal(s"Parent class $parent not found", cc))
      table.addConstructor(m.name, name, argTypes, retType)
    }

    // Step 5: Discover functions signatures.
    for {
      m <- p.modules
      case N.FunDef(name, params, retType1, _) <- m.defs
    } {
      val argTypes = params map (p => transformType(p.tt, m.name))
      val retType2 = transformType(retType1, m.name)
      table.addFunction(m.name, name, argTypes, retType2)
    }

    // Step 6: We now know all definitions in the program.
    // Reconstruct modules and analyse function bodies/ expressions

    // Rebuilds a symbolic definition from its nominal counterpart, using the
    // signatures recorded in steps 3-5 (so lookups with .get are safe here).
    def transformDef(df: N.ClassOrFunDef, module: String): S.ClassOrFunDef = { df match {
      case N.AbstractClassDef(name) =>
        S.AbstractClassDef(table.getType(module, name).get)
      case N.CaseClassDef(name, _, _) =>
        val Some((sym, sig)): Option[(Identifier, ConstrSig)] = table.getConstructor(module, name) : @unchecked
        S.CaseClassDef(
          sym,
          sig.argTypes map S.TypeTree.apply,
          sig.parent
        )
      case fd: N.FunDef =>
        transformFunDef(fd, module)
    }}.setPos(df)

    // Rebuilds a symbolic function definition; assigns fresh symbols to the
    // parameters and resolves every name in the body.
    def transformFunDef(fd: N.FunDef, module: String): S.FunDef = {
      val N.FunDef(name, params, retType, body) = fd
      val Some((sym, sig)) = table.getFunction(module, name) : @unchecked

      // Parameter names must be pairwise distinct.
      params.groupBy(_.name).foreach { case (name, ps) =>
        if (ps.size > 1) {
          fatal(s"Two parameters named $name in function ${fd.name}", fd)
        }
      }

      val paramNames = params.map(_.name)

      val newParams = params zip sig.argTypes map { case (pd@N.ParamDef(name, tt), tpe) =>
        val s = Identifier.fresh(name)
        S.ParamDef(s, S.TypeTree(tpe).setPos(tt)).setPos(pd)
      }

      // Maps nominal parameter names to their freshly created symbols.
      val paramsMap = paramNames.zip(newParams.map(_.name)).toMap

      S.FunDef(
        sym,
        newParams,
        S.TypeTree(sig.retType).setPos(retType),
        transformExpr(body)(module, (paramsMap, Map()))
      ).setPos(fd)
    }

    // Resolves all names in `expr`.
    // `names` carries (parameter bindings, local bindings); locals shadow parameters.
    def transformExpr(expr: N.Expr)
                    (implicit module: String, names: (Map[String, Identifier], Map[String, Identifier])): S.Expr = {
      val (params, locals) = names
      val res = expr match {
        case N.Variable(name) =>
          S.Variable(
            locals.getOrElse(name, // Local variables shadow parameters!
              params.getOrElse(name,
                fatal(s"Variable $name not found", expr))))
        case N.IntLiteral(value) =>
          S.IntLiteral(value)
        case N.BooleanLiteral(value) =>
          S.BooleanLiteral(value)
        case N.StringLiteral(value) =>
          S.StringLiteral(value)
        case N.UnitLiteral() =>
          S.UnitLiteral()
        case N.Plus(lhs, rhs) =>
          S.Plus(transformExpr(lhs), transformExpr(rhs))
        case N.Minus(lhs, rhs) =>
          S.Minus(transformExpr(lhs), transformExpr(rhs))
        case N.Times(lhs, rhs) =>
          S.Times(transformExpr(lhs), transformExpr(rhs))
        case N.Div(lhs, rhs) =>
          S.Div(transformExpr(lhs), transformExpr(rhs))
        case N.Mod(lhs, rhs) =>
          S.Mod(transformExpr(lhs), transformExpr(rhs))
        case N.LessThan(lhs, rhs) =>
          S.LessThan(transformExpr(lhs), transformExpr(rhs))
        case N.LessEquals(lhs, rhs) =>
          S.LessEquals(transformExpr(lhs), transformExpr(rhs))
        case N.And(lhs, rhs) =>
          S.And(transformExpr(lhs), transformExpr(rhs))
        case N.Or(lhs, rhs) =>
          S.Or(transformExpr(lhs), transformExpr(rhs))
        case N.Equals(lhs, rhs) =>
          S.Equals(transformExpr(lhs), transformExpr(rhs))
        case N.Concat(lhs, rhs) =>
          S.Concat(transformExpr(lhs), transformExpr(rhs))
        case N.Not(e) =>
          S.Not(transformExpr(e))
        case N.Neg(e) =>
          S.Neg(transformExpr(e))
        case N.Call(qname, args) =>
          // A call site may target either a constructor or a function;
          // constructors take precedence in the lookup below.
          val owner = qname.module.getOrElse(module)
          val name = qname.name
          val entry = table.getConstructor(owner, name).orElse(table.getFunction(owner, name))
          entry match {
            case None =>
              fatal(s"Function or constructor $qname not found", expr)
            case Some((sym, sig)) =>
              if (sig.argTypes.size != args.size) {
                fatal(s"Wrong number of arguments for function/constructor $qname", expr)
              }
              S.Call(sym, args map transformExpr)
          }
        case N.Sequence(e1, e2) =>
          S.Sequence(transformExpr(e1), transformExpr(e2))
        case N.Let(vd, value, body) =>
          // Redefining a local is an error; shadowing a parameter is only a warning.
          if (locals.contains(vd.name)) {
            fatal(s"Variable redefinition: ${vd.name}", vd)
          }
          if (params.contains(vd.name)) {
            warning(s"Local variable ${vd.name} shadows function parameter", vd)
          }
          val sym = Identifier.fresh(vd.name)
          val tpe = transformType(vd.tt, module)
          S.Let(
            S.ParamDef(sym, S.TypeTree(tpe)).setPos(vd),
            transformExpr(value),
            // Only the body sees the new binding, not the bound value.
            transformExpr(body)(module, (params, locals + (vd.name -> sym)))
          )
        case N.Ite(cond, thenn, elze) =>
          S.Ite(transformExpr(cond), transformExpr(thenn), transformExpr(elze))
        case N.Match(scrut, cases) =>
          // Transforms a case: the right-hand side sees the names bound by the pattern.
          def transformCase(cse: N.MatchCase) = {
            val N.MatchCase(pat, rhs) = cse
            val (newPat, moreLocals) = transformPattern(pat)
            S.MatchCase(newPat, transformExpr(rhs)(module, (params, locals ++ moreLocals)).setPos(rhs)).setPos(cse)
          }

          // Transforms a pattern and returns the (name -> symbol) bindings it introduces.
          def transformPattern(pat: N.Pattern): (S.Pattern, List[(String, Identifier)]) = {
            val (newPat, newNames): (S.Pattern, List[(String, Identifier)]) = pat match {
              case N.WildcardPattern() =>
                (S.WildcardPattern(), List())
              case N.IdPattern(name) =>
                if (locals.contains(name)) {
                  fatal(s"Pattern identifier $name already defined", pat)
                }
                if (params.contains(name)) {
                  warning("Suspicious shadowing by an Id Pattern", pat)
                }
                // An id pattern with the same name as a nullary constructor is
                // probably a typo for `name()`; warn about it.
                table.getConstructor(module, name) match {
                  case Some((_, ConstrSig(Nil, _, _))) =>
                    warning(s"There is a nullary constructor in this module called '$name'. Did you mean '$name()'?", pat)
                  case _ =>
                }
                val sym = Identifier.fresh(name)
                (S.IdPattern(sym), List(name -> sym))
              case N.LiteralPattern(lit) =>
                (S.LiteralPattern(transformExpr(lit).asInstanceOf[S.Literal[Any]]), List())
              case N.CaseClassPattern(constr, args) =>
                val (sym, sig) = table
                  .getConstructor(constr.module.getOrElse(module), constr.name)
                  .getOrElse(fatal(s"Constructor $constr not found", pat))
                if (sig.argTypes.size != args.size) {
                  fatal(s"Wrong number of args for constructor $constr", pat)
                }
                val (newPatts, moreLocals0) = (args map transformPattern).unzip
                // A name may be bound at most once in a single pattern.
                val moreLocals = moreLocals0.flatten
                moreLocals.groupBy(_._1).foreach { case (name, pairs) =>
                  if (pairs.size > 1) {
                    fatal(s"Multiple definitions of $name in pattern", pat)
                  }
                }
                (S.CaseClassPattern(sym, newPatts), moreLocals)
            }
            (newPat.setPos(pat), newNames)
          }

          S.Match(transformExpr(scrut), cases map transformCase)

        case N.Error(msg) =>
          S.Error(transformExpr(msg))
      }
      res.setPos(expr)
    }

    // Rebuild each module with resolved definitions and optional main expression.
    val newProgram = S.Program(
      p.modules map { case mod@N.ModuleDef(name, defs, optExpr) =>
        S.ModuleDef(
          table.getModule(name).get,
          defs map (transformDef(_, name)),
          optExpr map (transformExpr(_)(name, (Map(), Map())))
        ).setPos(mod)
      }
    ).setPos(p)

    (newProgram, table)
  }
}
package amyc.analyzer
import amyc.ast.Identifier
import amyc.ast.SymbolicTreeModule._
import amyc.utils.UniqueCounter
import scala.collection.mutable.HashMap
/** Common interface of function and constructor signatures in the symbol table. */
trait Signature[RT <: Type]{
  // Types of the arguments, in order.
  val argTypes: List[Type]
  // Return type; constructors refine it to ClassType (see ConstrSig).
  val retType: RT
}
/**
 * The signature of a function in the symbol table
 *
 * @param argTypes Types of the args of the function, in order
 * @param retType  Return type of the function
 * @param owner    Symbol of the module in which the function is defined
 */
case class FunSig(argTypes: List[Type], retType: Type, owner: Identifier) extends Signature[Type]
/**
 * The signature of a constructor in the symbol table
 *
 * @param argTypes Types of the args of the constructor, in order
 * @param parent   Identifier of the abstract class that the constructor extends
 * @param index    Constructors extending a parent are numbered, starting at 0 for each parent.
 *                 This is useful for code generation, where we need a runtime representation of which
 *                 instance of the parent type a value represents.
 */
case class ConstrSig(argTypes: List[Type], parent: Identifier, index: Int) extends Signature[ClassType] {
  // A constructor always returns an instance of its parent class.
  val retType = ClassType(parent)
}
// A class that represents a dictionary of symbols for an Amy program
class SymbolTable {
  // Maps (module name, definition name) to the unique symbol of that definition.
  private val defsByName = HashMap[(String, String), Identifier]()
  // Maps module names to their unique symbols.
  private val modules = HashMap[String, Identifier]()

  // Maps a type symbol to the symbol of the module that declares it.
  private val types = HashMap[Identifier, Identifier]()
  // Maps function symbols to their signatures.
  private val functions = HashMap[Identifier, FunSig]()
  // Maps constructor symbols to their signatures.
  private val constructors = HashMap[Identifier, ConstrSig]()

  // Maps an abstract class symbol to its constructors, in definition order.
  private val typesToConstructors = HashMap[Identifier, List[Identifier]]()

  // Hands out per-parent constructor indexes, starting at 0.
  private val constrIndexes = new UniqueCounter[Identifier]

  /** Registers a new module and returns its fresh symbol. */
  def addModule(name: String) = {
    val s = Identifier.fresh(name)
    modules += name -> s
    s
  }
  def getModule(name: String) = modules.get(name)

  /** Registers a type `name` declared in module `owner`; the module must already exist. */
  def addType(owner: String, name: String) = {
    val s = Identifier.fresh(name)
    defsByName += (owner, name) -> s
    // Bug fix: the error message used to interpolate the type name ($name)
    // instead of the missing module's name ($owner), matching addFunction below.
    types += (s -> modules.getOrElse(owner, sys.error(s"Module $owner not found!")))
    s
  }
  /** Looks up a type by owner module and name; returns None for non-type definitions. */
  def getType(owner: String, name: String) =
    defsByName.get((owner, name)) filter types.contains
  def getType(symbol: Identifier) = types.get(symbol)

  /** Registers a constructor of abstract class `parent`, assigning it the next index. */
  def addConstructor(owner: String, name: String, argTypes: List[Type], parent: Identifier) = {
    val s = Identifier.fresh(name)
    defsByName += (owner, name) -> s
    constructors += s -> ConstrSig(
      argTypes,
      parent,
      constrIndexes.next(parent)
    )
    typesToConstructors += parent -> (typesToConstructors.getOrElse(parent, Nil) :+ s)
    s
  }
  /** Looks up a constructor by owner module and name, with its signature. */
  def getConstructor(owner: String, name: String): Option[(Identifier, ConstrSig)] = {
    for {
      sym <- defsByName.get((owner, name))
      sig <- constructors.get(sym)
    } yield (sym, sig)
  }
  def getConstructor(symbol: Identifier) = constructors.get(symbol)

  /** All constructors of abstract class `t`, in definition order. */
  def getConstructorsForType(t: Identifier) = typesToConstructors.get(t)

  /** Registers a function `name` in module `owner`; the module must already exist. */
  def addFunction(owner: String, name: String, argTypes: List[Type], retType: Type) = {
    val s = Identifier.fresh(name)
    defsByName += (owner, name) -> s
    functions += s -> FunSig(argTypes, retType, getModule(owner).getOrElse(sys.error(s"Module $owner not found!")))
    s
  }
  /** Looks up a function by owner module and name, with its signature. */
  def getFunction(owner: String, name: String): Option[(Identifier, FunSig)] = {
    for {
      sym <- defsByName.get((owner, name))
      sig <- functions.get(sym)
    } yield (sym, sig)
  }
  def getFunction(symbol: Identifier) = functions.get(symbol)
}
package amyc
package analyzer
import amyc.utils._
import amyc.ast.SymbolicTreeModule._
import amyc.ast.Identifier
// The type checker for Amy
// Takes a symbolic program and rejects it if it does not follow the Amy typing rules.
object TypeChecker extends Pipeline[(Program, SymbolTable), (Program, SymbolTable)] {
  def run(ctx: Context)(v: (Program, SymbolTable)): (Program, SymbolTable) = {
    import ctx.reporter._

    val (program, table) = v

    // A typing constraint: the type we inferred (`found`) must unify with the
    // type the context requires (`expected`); `pos` is used for error reporting.
    case class Constraint(found: Type, expected: Type, pos: Position)

    // Represents a type variable.
    // It extends Type, but it is meant only for internal type checker use,
    // since no Amy value can have such type.
    case class TypeVariable private (id: Int) extends Type
    object TypeVariable {
      private val c = new UniqueCounter[Unit]
      def fresh(): TypeVariable = TypeVariable(c.next(()))
    }

    // Generates typing constraints for an expression `e` with a given expected type.
    // The environment `env` contains all currently available bindings (you will have to
    // extend these, e.g., to account for local variables).
    // Returns a list of constraints among types. These will later be solved via unification.
    def genConstraints(e: Expr, expected: Type)(implicit env: Map[Identifier, Type]): List[Constraint] = {
      // This helper returns a list of a single constraint recording the type
      // that we found (or generated) for the current expression `e`
      def topLevelConstraint(found: Type): List[Constraint] =
        List(Constraint(found, expected, e.position))

      e match {
        case IntLiteral(_) =>
          topLevelConstraint(IntType)
        case Equals(lhs, rhs) =>
          // HINT: Take care to implement the specified Amy semantics
          ??? // TODO
        case Match(scrut, cases) =>
          // Returns additional constraints from within the pattern with all bindings
          // from identifiers to types for names bound in the pattern.
          // (This is analogous to `transformPattern` in NameAnalyzer.)
          def patternBindings(pat: Pattern, expected: Type): (List[Constraint], Map[Identifier, Type]) = {
            ??? // TODO
          }

          def handleCase(cse: MatchCase, scrutExpected: Type): List[Constraint] = {
            val (patConstraints, moreEnv) = patternBindings(cse.pat, scrutExpected)
            ??? // TODO
          }

          // The scrutinee and every pattern must agree on the same (fresh) type.
          val st = TypeVariable.fresh()
          genConstraints(scrut, st) ++
          cases.flatMap(cse => handleCase(cse, st))

        ??? // TODO: Implement the remaining cases
      }
    }

    // Given a list of constraints `constraints`, replace every occurrence of type variable
    // with id `from` by type `to`.
    def subst_*(constraints: List[Constraint], from: Int, to: Type): List[Constraint] = {
      constraints map { case Constraint(found, expected, pos) =>
        Constraint(subst(found, from, to), subst(expected, from, to), pos)
      }
    }

    // Do a single substitution (non-recursive: only a direct match is replaced).
    def subst(tpe: Type, from: Int, to: Type): Type = {
      tpe match {
        case TypeVariable(`from`) => to
        case other => other
      }
    }

    // Solve the given set of typing constraints and report errors
    // using `ctx.reporter.error` if they are not satisfiable.
    // We consider a set of constraints to be satisfiable exactly if they unify.
    def solveConstraints(constraints: List[Constraint]): Unit = {
      constraints match {
        case Nil => ()
        case Constraint(found, expected, pos) :: more =>
          // HINT: You can use the `subst_*` helper above to replace a type variable
          // by another type in your current set of constraints.
          ??? // TODO
      }
    }

    // Putting it all together to type-check each module's functions and main expression.
    program.modules.foreach { mod =>
      // Function bodies are checked against their declared return type.
      mod.defs.collect { case FunDef(_, params, retType, body) =>
        val env = params.map{ case ParamDef(name, tt) => name -> tt.tpe }.toMap
        solveConstraints(genConstraints(body, retType.tpe)(env))
      }

      // The module expression may have any type, so check it against a fresh variable.
      val tv = TypeVariable.fresh()
      mod.optExpr.foreach(e => solveConstraints(genConstraints(e, tv)(Map())))
    }

    v
  }
}
package amyc.ast
object Identifier {
  // One counter per base name, so ids count up independently for each name.
  private val counter = new amyc.utils.UniqueCounter[String]

  /** Creates a new identifier; uniqueness comes from reference equality. */
  def fresh(name: String): Identifier = new Identifier(name)
}
// Denotes a unique identifier in an Amy program
// Notice that we rely on reference equality to compare Identifiers.
// The numeric id will be generated lazily,
// so the Identifiers are numbered in order when we print the program.
final class Identifier private(val name: String) {
  // Lazy on purpose: the id is assigned the first time fullName is requested,
  // which makes printed numbering follow print order, not creation order.
  private lazy val id = Identifier.counter.next(name)

  // Name qualified with the unique numeric id, e.g. "x_3".
  def fullName = s"${name}_$id"

  override def toString: String = name
}
package amyc.ast
import scala.language.implicitConversions
import amyc.utils._
// A printer for Amy trees
trait Printer {
  val treeModule: TreeModule
  import treeModule._

  // How to print a (possibly unique) name and a qualified name;
  // instantiated differently by the nominal and symbolic printers.
  implicit def printName(name: Name)(implicit printUniqueIds: Boolean): Document
  implicit def printQName(name: QualifiedName)(implicit printUniqueIds: Boolean): Document

  // Lets plain string literals be used wherever a Document is expected.
  protected implicit def stringToDoc(s: String): Raw = Raw(s)

  /** Renders any tree to a string; `printUniqueIDs` also prints numeric name suffixes. */
  def apply(t: Tree)(implicit printUniqueIDs: Boolean = false): String = {
    // Renders a parenthesized infix binary operation.
    def binOp(e1: Expr, op: String, e2: Expr) = "(" <:> rec(e1) <:> " " + op + " " <:> rec(e2) <:> ")"

    // Main recursive renderer; `parens` controls whether multi-line
    // expressions (Sequence/Let) are wrapped in parentheses.
    def rec(t: Tree, parens: Boolean = true): Document = t match {
      /* Definitions */
      case Program(modules) =>
        Stacked(modules map (rec(_)), emptyLines = true)
      case ModuleDef(name, defs, optExpr) =>
        Stacked(
          "object " <:> name,
          "",
          Indented(Stacked(defs ++ optExpr.toList map (rec(_, false)), emptyLines = true)),
          "end " <:> name,
          ""
        )
      case AbstractClassDef(name) =>
        "abstract class " <:> printName(name)
      case CaseClassDef(name, fields, parent) =>
        // Fields are unnamed type trees, so a placeholder name "v" is printed.
        def printField(f: TypeTree) = "v: " <:> rec(f)
        "case class " <:> name <:> "(" <:> Lined(fields map printField, ", ") <:> ") extends " <:> parent
      case FunDef(name, params, retType, body) =>
        Stacked(
          "def " <:> name <:> "(" <:> Lined(params map (rec(_)), ", ") <:> "): " <:> rec(retType) <:> " = {",
          Indented(rec(body, false)),
          "}"
        )
      case ParamDef(name, tpe) =>
        name <:> ": " <:> rec(tpe)

      /* Expressions */
      case Variable(name) =>
        name
      case IntLiteral(value) =>
        value.toString
      case BooleanLiteral(value) =>
        value.toString
      case StringLiteral(value) =>
        "\"" + value + '"'
      case UnitLiteral() =>
        "()"
      case Plus(lhs, rhs) =>
        binOp(lhs, "+", rhs)
      case Minus(lhs, rhs) =>
        binOp(lhs, "-", rhs)
      case Times(lhs, rhs) =>
        binOp(lhs, "*", rhs)
      case Div(lhs, rhs) =>
        binOp(lhs, "/", rhs)
      case Mod(lhs, rhs) =>
        binOp(lhs, "%", rhs)
      case LessThan(lhs, rhs) =>
        binOp(lhs, "<", rhs)
      case LessEquals(lhs, rhs) =>
        binOp(lhs, "<=", rhs)
      case And(lhs, rhs) =>
        binOp(lhs, "&&", rhs)
      case Or(lhs, rhs) =>
        binOp(lhs, "||", rhs)
      case Equals(lhs, rhs) =>
        binOp(lhs, "==", rhs)
      case Concat(lhs, rhs) =>
        binOp(lhs, "++", rhs)
      case Not(e) =>
        "!(" <:> rec(e) <:> ")"
      case Neg(e) =>
        "-(" <:> rec(e) <:> ")"
      case Call(name, args) =>
        name <:> "(" <:> Lined(args map (rec(_)), ", ") <:> ")"
      case Sequence(lhs, rhs) =>
        val main = Stacked(
          rec(lhs, false) <:> ";",
          rec(rhs, false),
        )
        if (parens) {
          Stacked(
            "(",
            Indented(main),
            ")"
          )
        } else {
          main
        }
      case Let(df, value, body) =>
        val main = Stacked(
          "val " <:> rec(df) <:> " =",
          Indented(rec(value)) <:> ";",
          rec(body, false) // For demonstration purposes, the scope or df is indented
        )
        if (parens) {
          Stacked(
            "(",
            Indented(main),
            ")"
          )
        } else {
          main
        }
      case Ite(cond, thenn, elze) =>
        Stacked(
          "(if(" <:> rec(cond) <:> ") {",
          Indented(rec(thenn)),
          "} else {",
          Indented(rec(elze)),
          "})"
        )
      case Match(scrut, cases) =>
        Stacked(
          rec(scrut) <:> " match {",
          Indented(Stacked(cases map (rec(_)))),
          "}"
        )
      case Error(msg) =>
        "error(" <:> rec(msg) <:> ")"

      /* cases and patterns */
      case MatchCase(pat, expr) =>
        Stacked(
          "case " <:> rec(pat) <:> " =>",
          Indented(rec(expr))
        )
      case WildcardPattern() =>
        "_"
      case IdPattern(name) =>
        name
      case LiteralPattern(lit) =>
        rec(lit)
      case CaseClassPattern(name, args) =>
        name <:> "(" <:> Lined(args map (rec(_)), ", ") <:> ")"

      /* Types */
      case TypeTree(tp) =>
        tp match {
          case IntType => "Int(32)"
          case BooleanType => "Boolean"
          case StringType => "String"
          case UnitType => "Unit"
          case ClassType(name) => name
        }
    }

    rec(t).print
  }
}
object NominalPrinter extends Printer {
  val treeModule: NominalTreeModule.type = NominalTreeModule
  import NominalTreeModule._

  /** Nominal names are plain strings; render them verbatim. */
  implicit def printName(name: Name)(implicit printUniqueIds: Boolean): Document = Raw(name)

  /** Renders a qualified name as `Module.name` when a module qualifier is present. */
  implicit def printQName(name: QualifiedName)(implicit printUniqueIds: Boolean): Document = {
    val rendered = name.module.fold(name.name)(mod => s"$mod.${name.name}")
    Raw(rendered)
  }
}
// The default symbolic printer instance.
object SymbolicPrinter extends SymbolicPrinter
trait SymbolicPrinter extends Printer {
  val treeModule: SymbolicTreeModule.type = SymbolicTreeModule
  import SymbolicTreeModule._

  /** Renders a symbol, appending its unique numeric id when requested. */
  implicit def printName(name: Name)(implicit printUniqueIds: Boolean): Document =
    if (printUniqueIds) name.fullName else name.name

  /** Symbolic qualified names are plain identifiers; delegate to printName. */
  @inline implicit def printQName(name: QualifiedName)(implicit printUniqueIds: Boolean): Document =
    printName(name)
}
package amyc.ast
import amyc.utils.Positioned
/* A polymorphic module containing definitions of Amy trees.
*
* This trait represents either nominal trees (where names have not been resolved)
 * or symbolic trees (where names/qualified names have been resolved to unique identifiers).
* This is done by having two type fields within the module,
* which will be instantiated differently by the two different modules.
*
*/
trait TreeModule { self =>
  /* Represents the type for the name for this tree module.
   * (It will be either a plain string, or a unique symbol)
   */
  type Name

  // Represents a name within an module
  type QualifiedName

  // A printer that knows how to print trees in this module.
  // The modules will instantiate it as appropriate
  val printer: Printer { val treeModule: self.type }

  // Common ancestor for all trees; every tree carries a position.
  trait Tree extends Positioned {
    override def toString: String = printer(this)
  }

  // Expressions
  trait Expr extends Tree

  // Variables
  case class Variable(name: Name) extends Expr

  // Literals
  trait Literal[+T] extends Expr { val value: T }
  case class IntLiteral(value: Int) extends Literal[Int]
  case class BooleanLiteral(value: Boolean) extends Literal[Boolean]
  case class StringLiteral(value: String) extends Literal[String]
  case class UnitLiteral() extends Literal[Unit] { val value: Unit = () }

  // Binary operators
  case class Plus(lhs: Expr, rhs: Expr) extends Expr
  case class Minus(lhs: Expr, rhs: Expr) extends Expr
  case class Times(lhs: Expr, rhs: Expr) extends Expr
  case class Div(lhs: Expr, rhs: Expr) extends Expr
  case class Mod(lhs: Expr, rhs: Expr) extends Expr
  case class LessThan(lhs: Expr, rhs: Expr) extends Expr
  case class LessEquals(lhs: Expr, rhs: Expr) extends Expr
  case class And(lhs: Expr, rhs: Expr) extends Expr
  case class Or(lhs: Expr, rhs: Expr) extends Expr
  case class Equals(lhs: Expr, rhs: Expr) extends Expr
  // String concatenation (++)
  case class Concat(lhs: Expr, rhs: Expr) extends Expr

  // Unary operators
  case class Not(e: Expr) extends Expr
  case class Neg(e: Expr) extends Expr

  // Function/constructor call
  case class Call(qname: QualifiedName, args: List[Expr]) extends Expr
  // The ; operator
  case class Sequence(e1: Expr, e2: Expr) extends Expr
  // Local variable definition
  case class Let(df: ParamDef, value: Expr, body: Expr) extends Expr
  // If-then-else
  case class Ite(cond: Expr, thenn: Expr, elze: Expr) extends Expr
  // Pattern matching; a match must have at least one case.
  case class Match(scrut: Expr, cases: List[MatchCase]) extends Expr {
    require(cases.nonEmpty)
  }
  // Represents a computational error; prints its message, then exits
  case class Error(msg: Expr) extends Expr

  // Cases and patterns for Match expressions
  case class MatchCase(pat: Pattern, expr: Expr) extends Tree

  abstract class Pattern extends Tree
  case class WildcardPattern() extends Pattern // _
  case class IdPattern(name: Name) extends Pattern // x
  case class LiteralPattern[+T](lit: Literal[T]) extends Pattern // 42, true
  case class CaseClassPattern(constr: QualifiedName, args: List[Pattern]) extends Pattern // C(arg1, arg2)

  // Definitions
  trait Definition extends Tree { val name: Name }
  case class ModuleDef(name: Name, defs: List[ClassOrFunDef], optExpr: Option[Expr]) extends Definition
  trait ClassOrFunDef extends Definition
  case class FunDef(name: Name, params: List[ParamDef], retType: TypeTree, body: Expr) extends ClassOrFunDef {
    def paramNames = params.map(_.name)
  }
  case class AbstractClassDef(name: Name) extends ClassOrFunDef
  // fields are unnamed: only their types are recorded.
  case class CaseClassDef(name: Name, fields: List[TypeTree], parent: Name) extends ClassOrFunDef
  case class ParamDef(name: Name, tt: TypeTree) extends Definition

  // Types
  trait Type
  case object IntType extends Type {
    override def toString: String = "Int"
  }
  case object BooleanType extends Type {
    override def toString: String = "Boolean"
  }
  case object StringType extends Type {
    override def toString: String = "String"
  }
  case object UnitType extends Type {
    override def toString: String = "Unit"
  }
  case class ClassType(qname: QualifiedName) extends Type {
    override def toString: String = printer.printQName(qname)(false).print
  }

  // A wrapper for types that is also a Tree (i.e. has a position)
  case class TypeTree(tpe: Type) extends Tree

  // All is wrapped in a program
  case class Program(modules: List[ModuleDef]) extends Tree
}
/* A module containing trees where the names have not been resolved.
 * Instantiates Name to String and QualifiedName to a pair of Strings
 * representing (module, name) (where module is optional)
 */
object NominalTreeModule extends TreeModule {
  type Name = String

  case class QualifiedName(module: Option[String], name: String) {
    override def toString: String = printer.printQName(this)(false).print
  }

  val printer = NominalPrinter
}
/* A module containing trees where the names have been resolved to unique identifiers.
 * Both Name and QualifiedName are instantiated to Identifier.
 */
object SymbolicTreeModule extends TreeModule {
  type Name = Identifier
  type QualifiedName = Identifier

  val printer = SymbolicPrinter
}
package amyc
package codegen
import analyzer._
import amyc.ast.Identifier
import amyc.ast.SymbolicTreeModule.{Call => AmyCall, Div => AmyDiv, And => AmyAnd, Or => AmyOr, _}
import amyc.utils.{Context, Pipeline}
import wasm._
import Instructions._
import Utils._
// Generates WebAssembly code for an Amy program
object CodeGen extends Pipeline[(Program, SymbolTable), Module] {
  def run(ctx: Context)(v: (Program, SymbolTable)): Module = {
    val (program, table) = v

    // Generate code for an Amy module
    def cgModule(moduleDef: ModuleDef): List[Function] = {
      val ModuleDef(name, defs, optExpr) = moduleDef
      // Generate code for all functions
      // (built-ins are skipped: they are hard-coded in codegen.Utils or the wrapper).
      defs.collect { case fd: FunDef if !builtInFunctions(fullName(name, fd.name)) =>
        cgFunction(fd, name, false)
      } ++
      // Generate code for the "main" function, which contains the module expression
      optExpr.toList.map { expr =>
        val mainFd = FunDef(Identifier.fresh("main"), Nil, TypeTree(IntType), expr)
        cgFunction(mainFd, name, true)
      }
    }

    // Generate code for a function in module 'owner'
    def cgFunction(fd: FunDef, owner: Identifier, isMain: Boolean): Function = {
      // Note: We create the wasm function name from a combination of
      // module and function name, since we put everything in the same wasm module.
      val name = fullName(owner, fd.name)
      Function(name, fd.params.size, isMain){ lh =>
        // Parameters occupy the first wasm local slots, in declaration order.
        val locals = fd.paramNames.zipWithIndex.toMap
        val body = cgExpr(fd.body)(locals, lh)
        val comment = Comment(fd.toString)
        if (isMain) {
          body <:> Drop // Main functions do not return a value,
          // so we need to drop the value generated by their body
        } else {
          comment <:> body
        }
      }
    }

    // Generate code for an expression expr.
    // Additional arguments are a mapping from identifiers (parameters and variables) to
    // their index in the wasm local variables, and a LocalsHandler which will generate
    // fresh local slots as required.
    def cgExpr(expr: Expr)(implicit locals: Map[Identifier, Int], lh: LocalsHandler): Code = {
      expr match {
        case IntLiteral(i) =>
          // Push i to the stack.
          // The comments are optional but can help you debug.
          Comment(expr.toString) <:> Const(i)
        case Match(scrut, cases) =>
          // Checks if a value matches a pattern.
          // Assumes value is on top of stack (and CONSUMES it)
          // Returns the code to check the value, and a map of bindings.
          def matchAndBind(pat: Pattern): (Code, Map[Identifier, Int]) = pat match {
            case IdPattern(id) =>
              val idLocal = lh.getFreshLocal()
              (Comment(pat.toString) <:>
               // Assign val to id.
               SetLocal(idLocal) <:>
               // Return true (IdPattern always matches).
               Const(1),
               // Let the code generation of the expression which corresponds to this pattern
               // know that the bound id is at local idLocal.
               Map(id -> idLocal))
            case _ => ???
          }

          ???
        case _ => ???
      }
    }

    Module(
      program.modules.last.name.name,
      defaultImports,
      globalsNo,
      wasmFunctions ++ (program.modules flatMap cgModule)
    )
  }
}
package amyc
package codegen
import wasm.Module
import amyc.utils.{Context, Pipeline, Env}
import scala.sys.process._
import java.io._
// Writes a wasm Module as a .wat text file and compiles it to a .wasm binary
// using the external wat2wasm tool.
object CodePrinter extends Pipeline[Module, Unit]{
  def run(ctx: Context)(m: Module) = {
    val outDirName = "wasmout"

    def pathWithExt(ext: String) = s"$outDirName/${nameWithExt(ext)}"
    def nameWithExt(ext: String) = s"${m.name}.$ext"

    // Per-OS paths for a bundled wat2wasm binary (./bin) and one on the system path.
    val (local, inPath) = {
      import Env._
      os match {
        case Linux => ("./bin/wat2wasm", "wat2wasm")
        case Windows => ("./bin/wat2wasm.exe", "wat2wasm.exe")
        case Mac => ("./bin/wat2wasm", "wat2wasm")
      }
    }

    val w2wOptions = s"${pathWithExt("wat")} -o ${pathWithExt("wasm")}"

    val outDir = new File(outDirName)
    if (!outDir.exists()) {
      outDir.mkdir()
    }

    m.writeWasmText(pathWithExt("wat"))

    try {
      try {
        // Try the bundled binary first...
        s"$local $w2wOptions".!!
      } catch {
        case _: IOException =>
          // ...and fall back to a wat2wasm found on the system path.
          s"$inPath $w2wOptions".!!
      }
    } catch {
      // Neither binary could be launched.
      case _: IOException =>
        ctx.reporter.fatal(
          "wat2wasm utility was not found under ./bin or in system path, " +
          "or did not have permission to execute. Make sure it is either in the system path, or in <root of the project>/bin"
        )
      // wat2wasm ran but exited with a non-zero status (`!!` throws in that case).
      case _: RuntimeException =>
        ctx.reporter.fatal(s"wat2wasm failed to translate WebAssembly text file ${pathWithExt("wat")} to binary")
    }
  }
}
package amyc
package codegen
import amyc.ast.Identifier
import wasm.Function
import wasm.Instructions._
// Utilities for CodeGen
object Utils {
// The index of the global variable that represents the free memory boundary
val memoryBoundary: Int = 0
// The index of the global variable that represents the input buffer
val inputBuffer: Int = 1
// Total number of global variables used by generated modules
val globalsNo = 2

// The default imports we will pass to a wasm Module:
// WASI system calls for writing to stdout and reading from stdin.
val defaultImports: List[String] = List(
  "\"wasi_snapshot_preview1\" \"fd_write\" (func $fd_write (param i32 i32 i32 i32) (result i32))",
  "\"wasi_snapshot_preview1\" \"fd_read\" (func $fd_read (param i32 i32 i32 i32) (result i32))"
)

// We don't generate code for these functions in CodeGen (they are hard-coded here or in js wrapper)
val builtInFunctions: Set[String] = Set(
  "Std_printString",
  "Std_digitToString",
  "Std_readInt",
  "Std_readString"
)
/** Utilities */

/** Builds the wasm-level name of definition `df` owned by module `owner` (e.g. "Std_printString"). */
def fullName(owner: Identifier, df: Identifier): String = s"${owner.name}_${df.name}"
// Given a pointer to an ADT on the top of the stack,
// will point at its field in index (and consume the ADT).
// 'index' MUST be 0-based.
def adtField(index: Int): Code = {
  // Field i lives at byte offset 4 * (i + 1) from the ADT pointer
  // (the first word is presumably the constructor tag — confirm in CodeGen).
  val byteOffset = 4 * (index + 1)
  Comment(s"adtField index: $index") <:> Const(byteOffset) <:> Add
}
// Emits code that adds one to wasm local `local`, in place.
def incr(local: Int): Code = {
  val plusOne = GetLocal(local) <:> Const(1) <:> Add
  plusOne <:> SetLocal(local)
}
// A fresh, globally unique label name for wasm control-flow blocks.
def getFreshLabel(name: String = "label") = {
  val sym = Identifier.fresh(name)
  sym.fullName
}
// Creates a known string constant s in memory
def mkString(s: String): Code = {
  // Zero-pad up to the next multiple of 4 bytes. Since the pad is 1..4 bytes,
  // the stored string always ends with at least one NUL byte.
  val rawLen = s.length
  val padBytes = 4 - rawLen % 4
  val padded = s + ("\u0000" * padBytes)
  // Store each byte of the padded string at memoryBoundary + offset.
  val writeChars = padded.zipWithIndex.toList.map { case (ch, offset) =>
    GetGlobal(memoryBoundary) <:> Const(offset) <:> Add <:>
      Const(ch.toInt) <:> Store8
  }
  // Leave the string's address on the stack, then bump the boundary past it.
  val bumpBoundary =
    GetGlobal(memoryBoundary) <:>
      GetGlobal(memoryBoundary) <:> Const(rawLen + padBytes) <:> Add <:>
      SetGlobal(memoryBoundary)
  Comment(s"mkString: $s") <:> writeChars <:> bumpBoundary
}
// Built-in implementation of string length:
// walks bytes from the string pointer (param 0) until a 0 byte is found.
val stringLenImpl: Function = {
  Function("String_len", 1, false) { lh =>
    val size = lh.getFreshLocal()
    val label = getFreshLabel()
    Loop(label) <:>
    // Load current character
    GetLocal(0) <:> Load8_u <:>
    // If != 0
    If_void <:>
    // Increment pointer and size
    incr(0) <:> incr(size) <:>
    // Jump to loop
    Br(label) <:>
    Else <:>
    End <:>
    End <:>
    // Result: the number of non-zero bytes counted.
    GetLocal(size)
  }
}
// Built-in implementation of concatenation:
// copies both input strings (params 0 and 1) into fresh memory at the boundary,
// pads to a multiple of 4 bytes, and returns the new string's address.
val concatImpl: Function = {
  Function("String_concat", 2, false) { lh =>
    val ptrS = lh.getFreshLocal() // source read pointer
    val ptrD = lh.getFreshLocal() // destination write pointer
    val label = getFreshLabel()

    // Copies bytes from ptrS to ptrD until a 0 byte is read (the 0 is not copied).
    def mkLoop: Code = {
      val label = getFreshLabel()
      Loop(label) <:>
      // Load current character
      GetLocal(ptrS) <:> Load8_u <:>
      // If != 0
      If_void <:>
      // Copy to destination
      GetLocal(ptrD) <:>
      GetLocal(ptrS) <:> Load8_u <:>
      Store8 <:>
      // Increment pointers
      incr(ptrD) <:> incr(ptrS) <:>
      // Jump to loop
      Br(label) <:>
      Else <:>
      End <:>
      End
    }

    // Instantiate ptrD to previous memory, ptrS to first string
    GetGlobal(memoryBoundary) <:>
    SetLocal(ptrD) <:>
    GetLocal(0) <:>
    SetLocal(ptrS) <:>
    // Copy first string
    mkLoop <:>
    // Set ptrS to second string
    GetLocal(1) <:>
    SetLocal(ptrS) <:>
    // Copy second string
    mkLoop <:>
    //
    // Pad with zeros until multiple of 4
    // (write zeros until ptrD % 4 == 3, so ptrD + 1 is 4-aligned)
    //
    Loop(label) <:>
    // Write 0
    GetLocal(ptrD) <:> Const(0) <:> Store8 <:>
    // Check if multiple of 4, + 3
    GetLocal(ptrD) <:> Const(4) <:> Rem <:>
    Const(3) <:> Eq <:>
    // If not
    If_void <:>
    Else <:>
    // Increment pointer and go back
    incr(ptrD) <:>
    Br(label) <:>
    End <:>
    End <:>
    //
    // Put string pointer to stack, set new memory boundary and return
    GetGlobal(memoryBoundary) <:> GetLocal(ptrD) <:> Const(1) <:> Add <:> SetGlobal(memoryBoundary)
  }
}
// Built-in implementation of digit-to-string conversion:
// stores '0' + digit as a 4-byte word at the memory boundary and returns its address.
val digitToStringImpl: Function = {
  Function("Std_digitToString", 1, false) { lh =>
    // We know we have to create a string of total size 4 (digit code + padding), so we do it all together
    // We do not need to shift the digit due to little endian structure!
    GetGlobal(memoryBoundary) <:> GetLocal(0) <:> Const('0'.toInt) <:> Add <:> Store <:>
    // Load memory boundary to stack, then move it by 4
    GetGlobal(memoryBoundary) <:>
    GetGlobal(memoryBoundary) <:> Const(4) <:> Add <:> SetGlobal(memoryBoundary)
  }
}
// You don't need to understand this or printString, but know that we use the following for
// reading and writing from stdin and to stdout in WASI.
// https://github.com/WebAssembly/WASI/blob/main/legacy/preview1/docs.md
val readStringImpl: Function = {
// Contract (from the code below): leaves on the stack a pointer to a freshly
// allocated, null-terminated, zero-padded copy of the next input line, and
// advances memoryBoundary past it. Trailing '\n' is not included in the copy.
Function("Std_readString", 0, false) { lh =>
// bytesRead: host-reported byte count, then length of the current line.
val bytesRead = lh.getFreshLocal()
// ptr: read cursor into the buffered input.
val ptr = lh.getFreshLocal()
// ptrD: write cursor into the freshly allocated destination string.
val ptrD = lh.getFreshLocal()
// NOTE: this code assumes the input will never contain a null byte.
// This is a decently reasonable assumption.
// The only consequence of having a null byte is that any input after that in that read will be lost.
// Notice that due to how host_read *works*,
// this might read multiple lines of user input.
// We thus need to proceed as follows:
// - Keep a global pointer of buffered reading
// - Read a single line of input
// - Shift the buffered reading pointer
// - If no line is available, read a new buffer from the host
// We can't read "until a certain character is encountered". So instead we read everything, then split by lines.
// Since every line is only read once by an `fd_read` call, every further call must use the earlier read's content to get the next line. This is global state that needs to be maintained.
// E.g. if we call Std_readString 3 times with the input
// ```
// 42
// 43
// [... given later]
// 44
// ```
// 1. Std_readString -> fd_read -> obtains `42\n43` -> returns 42
// 2. Std_readString -> fd_read already called -> returns 43
// 3. Std_readString -> fd_read already called but empty -> fd_read -> obtains `44` -> return 44
//
// NOTE(review): `inputBuffer` is a global declared elsewhere in this file —
// from its use below it holds the cursor into previously read input, with
// 0 meaning "never read" (see the If_i32 test on it further down). Confirm
// its declaration is initialized to 0.
//
// TODO: when we read an input that doesn't have a newline, we might have been unlucky.
// we thus might need to read again and concatenate these two strings.
// NOTE: must keep alignment
val bufferSize = 1024
assert(bufferSize % 4 == 0)
// host_read takes a pointer and a length, and returns the number of bytes read.
// We will read at most bufferSize bytes.
// The pointer is the global memory boundary
def readFromHost = {
val label = getFreshLabel()
// We want to call fd_read. We need to set the iovec structure.
// We will read up to bufferSize bytes.
// Scratch layout used here (relative to memoryBoundary):
//   +bufferSize   : iovec.buf  -> memoryBoundary (data lands at the boundary)
//   +bufferSize+4 : iovec.len  -> bufferSize - 1
//   +bufferSize+8 : fd_read's "bytes read" out-parameter
// 1) Set the pointer to the buffer in a definitely-free position
// 2) Set the length
// 3) Call fd_read
// 4) Check the return value (if an error occured, we trap)
// 5) Get the number of bytes read
// 6) Get the pointer to the buffer
// 1) Set the pointer to the buffer in a definitely-free position
GetGlobal(memoryBoundary) <:> Const(bufferSize) <:> Add <:>
GetGlobal(memoryBoundary) <:> Store <:>
// 2) Set the length
GetGlobal(memoryBoundary) <:> Const(bufferSize+4) <:> Add <:>
// Read at most bufferSize - 1 bytes so the padding loop below always has
// room for at least one terminating zero byte.
Const(bufferSize - 1) <:> Store <:>
// 3) Call fd_read
// 3.1) File descriptor (0 for stdin)
Const(0) <:>
// 3.2) Pointer to the list
GetGlobal(memoryBoundary) <:> Const(bufferSize) <:> Add <:>
// 3.3) Length of the list
Const(1) <:>
// 3.4) Pointer to the output size
GetGlobal(memoryBoundary) <:> Const(bufferSize+8) <:> Add <:>
// 3.5) Call fd_read
Call("fd_read") <:>
// 4) Check the return value (if an error occured, we trap)
// fd_read returns a nonzero errno on failure.
If_void <:>
Unreachable <:>
Else <:>
End <:>
// 5) Get the number of bytes read
GetGlobal(memoryBoundary) <:> Const(bufferSize+8) <:> Add <:> Load <:>
SetLocal(bytesRead) <:>
// 6) Load the pointer to the buffer
// This (the old memoryBoundary) is readFromHost's result, left on the stack.
GetGlobal(memoryBoundary) <:>
// Increment the memory boundary by the number of bytes read
GetLocal(bytesRead) <:> GetGlobal(memoryBoundary) <:> Add <:> SetGlobal(memoryBoundary) <:>
// Add padding and null termination. We can do this because we read at most bufferSize-1 bytes.
Loop(label) <:>
// Add padding/null termination
GetGlobal(memoryBoundary) <:> Const(0) <:> Store8 <:>
GetGlobal(memoryBoundary) <:> Const(1) <:> Add <:> SetGlobal(memoryBoundary) <:>
// Check if the allocator is aligned.
GetGlobal(memoryBoundary) <:> Const(4) <:> Rem <:>
If_void <:>
// If not, jump back to the loop.
Br(label) <:>
Else <:>
End <:>
End
}
val lenLoop = getFreshLabel()
val cloneLoop = getFreshLabel()
val paddingLoop = getFreshLabel()
GetGlobal(inputBuffer) <:>
If_i32 <:>
// The input buffer has already been read at least once. Check if we're at the end!
GetGlobal(inputBuffer) <:>
Load8_u <:>
Const(0) <:>
Eq <:>
Else <:>
// Special case: the read buffer has never been read
// We thus always need to read.
Const(1) <:>
End <:>
// If we need to, read from the host.
If_void <:>
readFromHost <:>
SetGlobal(inputBuffer) <:>
Else <:>
End <:>
// The start of the string:
// (kept on the stack until the final SetLocal(ptr) after the length loop)
GetGlobal(inputBuffer) <:>
// Start reading a single string.
GetGlobal(inputBuffer) <:>
SetLocal(ptr) <:>
Const(0) <:>
SetLocal(bytesRead) <:>
// Scan forward until '\n' or the terminating 0, counting line length into
// bytesRead and updating inputBuffer to where the *next* call should resume.
Loop(lenLoop) <:>
GetLocal(ptr) <:> Load8_u <:>
If_void <:>
GetLocal(ptr) <:> Load8_u <:>
Const('\n'.toInt) <:> Eq <:>
If_void <:>
// Set the input buffer to the next element of the read
GetLocal(ptr) <:> Const(1) <:> Add <:> SetGlobal(inputBuffer) <:>
Else <:>
incr(bytesRead) <:>
incr(ptr) <:>
Br(lenLoop) <:>
End <:>
Else <:>
// Input buffer is empty! Next time we'll have to reread.
GetLocal(ptr) <:> SetGlobal(inputBuffer) <:>
End <:>
End <:>
// Pops the line-start pointer pushed before the loop.
SetLocal(ptr) <:>
GetGlobal(memoryBoundary) <:> SetLocal(ptrD) <:>
// bytesRead = number of non-zero bytes in the string. Time to clone!
Loop(cloneLoop) <:>
GetLocal(bytesRead) <:>
If_void <:>
GetLocal(ptrD) <:>
GetLocal(ptr) <:> Load8_u <:>
Store8 <:>
incr(ptr) <:>
incr(ptrD) <:>
GetLocal(bytesRead) <:> Const(1) <:> Sub <:> SetLocal(bytesRead) <:>
Br(cloneLoop) <:>
Else <:>
End <:>
End <:>
// Zero-fill up to (and including) the last byte of the current 4-byte word,
// providing null termination plus alignment padding.
Loop(paddingLoop) <:>
// Write 0
GetLocal(ptrD) <:> Const(0) <:> Store8 <:>
// Check if multiple of 4
GetLocal(ptrD) <:> Const(4) <:> Rem <:>
Const(3) <:> Eq <:>
// If not
If_void <:>
Else <:>
// Increment pointer and go back
incr(ptrD) <:>
Br(paddingLoop) <:>
End <:>
End <:>
// Return the start of the cloned string (memoryBoundary was not moved during
// the clone; ptrD carried the write cursor), then bump the allocator past it.
GetGlobal(memoryBoundary) <:>
GetLocal(ptrD) <:> Const(1) <:> Add <:> SetGlobal(memoryBoundary)
}
}
val readIntImpl: Function = {
  Function("Std_readInt", 0, false) { lh =>
    // Reads one line of input (via Std_readString) and parses it as a
    // (possibly negative) decimal integer, trapping on a non-digit character.
    val isNegative = lh.getFreshLocal()
    val ptr = lh.getFreshLocal()
    val tmp = lh.getFreshLocal()
    // Initialized to 0, by convention.
    val value = lh.getFreshLocal()
    val loop = getFreshLabel()
    // Reads a digit from ptr and leaves it on the stack
    // Traps if there is no digit.
    def readDigit: Code = {
      GetLocal(ptr) <:> Load8_u <:>
      Const('0'.toInt) <:> Sub <:>
      SetLocal(tmp) <:> GetLocal(tmp) <:>
      // TODO: this should be an unsigned comparison
      // NOTE(review): with the signed Lt, a byte below '0' makes tmp negative,
      // and (negative < 10) holds — such bytes are wrongly accepted as digits
      // instead of trapping. Switch to an unsigned compare once available.
      Const(10) <:> Lt_s <:>
      If_i32 <:>
      GetLocal(tmp) <:>
      Else <:>
      Unreachable <:>
      End
    }
    // First, read a string.
    // (FIX: a stray, statement-position `GetGlobal(memoryBoundary)` that used
    // to precede this call was removed — its value was discarded by Scala and
    // never reached the emitted code, so this is purely dead-code removal.)
    Call("Std_readString") <:>
    SetLocal(ptr) <:>
    // Then, convert it to an integer
    // Check if there's a minus sign
    GetLocal(ptr) <:> Load8_u <:>
    Const('-'.toInt) <:> Eq <:>
    If_void <:>
    // Is negative: skip the sign and remember it for the final negation.
    incr(ptr) <:>
    Const(1) <:>
    SetLocal(isNegative) <:>
    Else <:>
    End <:>
    // Then, read all remaining digits until the terminating 0 byte,
    // accumulating value = value * 10 + digit.
    Loop(loop) <:>
    GetLocal(ptr) <:> Load8_u <:>
    Const(0) <:> Eq <:>
    If_void <:>
    Else <:> // Not 0! Read the digit.
    GetLocal(value) <:>
    Const(10) <:>
    Mul <:>
    readDigit <:>
    Add <:>
    SetLocal(value) <:>
    incr(ptr) <:>
    Br(loop) <:>
    End <:>
    End <:>
    // Finally, negate if necessary (0 - value), leaving the result on the stack.
    GetLocal(isNegative) <:>
    If_i32 <:>
    Const(0) <:>
    GetLocal(value) <:>
    Sub <:>
    Else <:>
    GetLocal(value) <:>
    End
  }
}
val printStringImpl: Function = {
// Writes the string argument (parameter 0, a pointer) followed by a newline
// to stdout via WASI fd_write, trapping on I/O error. Returns the byte count
// reported for the string write itself (the newline's count is not included:
// the value loaded at step 4 stays on the stack through the balanced steps
// 5-7 and becomes the function result).
Function("Std_printString", 1, false) { lh =>
// To print a string and a newline, we need to:
// 1) Turn our null-terminated string into a wide pointer
// 2) Call fd_write
// 3) Trap if the return value is not 0 (I/O error)
// 4) Get the number of bytes written
// 5) Create a wide pointer to a string of a single newline.
// 6) Call fd_write
// 7) Trap if the return value is not 0 (I/O error)
// 8) Return the number of bytes written
// Scratch layout (relative to memoryBoundary):
//   +0  : iovec.buf   +4 : iovec.len   +8 : "bytes written" out-parameter
//   +12 : the one-character newline string
// 1) Turn our null-terminated string into a wide pointer
GetGlobal(memoryBoundary) <:>
GetLocal(0) <:> Store <:>
GetGlobal(memoryBoundary) <:> Const(4) <:> Add <:>
// Get the string length
GetLocal(0) <:> Call("String_len") <:> Store <:>
// 2) Call fd_write
// 2.1) File descriptor (1 for stdout)
Const(1) <:>
// 2.2) Pointer to the list
GetGlobal(memoryBoundary) <:>
// 2.3) Length of the list
Const(1) <:>
// 2.4) Pointer to the output size
GetGlobal(memoryBoundary) <:> Const(8) <:> Add <:>
Call("fd_write") <:>
// 3) Trap if the return value is not 0 (I/O error)
If_void <:>
Unreachable <:>
Else <:>
End <:>
// 4) Load the number of bytes written
GetGlobal(memoryBoundary) <:> Const(8) <:> Add <:> Load <:>
// 5) Create a wide pointer to a string of a single newline
// 5.1) Write the newline
// (a full i32 store: '\n' in the low byte, zero padding above it)
GetGlobal(memoryBoundary) <:> Const(12) <:> Add <:>
Const('\n'.toInt) <:> Store <:>
// 5.2) Write the pointer
GetGlobal(memoryBoundary) <:>
GetGlobal(memoryBoundary) <:> Const(12) <:> Add <:> Store <:>
// 5.3) Write the length
GetGlobal(memoryBoundary) <:> Const(4) <:> Add <:>
Const(1) <:> Store <:>
// 6) Call fd_write
// 6.1) File descriptor (1 for stdout)
Const(1) <:>
// 6.2) Pointer to the list
GetGlobal(memoryBoundary) <:>
// 6.3) Length of the list
Const(1) <:>
// 6.4) Pointer to the output size
GetGlobal(memoryBoundary) <:> Const(8) <:> Add <:>
Call("fd_write") <:>
// 7) Trap if the return value is not 0 (I/O error)
If_void <:>
Unreachable <:>
Else <:>
End
// 8) Return the number of bytes written
// (implicit via the stack)
}
}
// All hand-written wasm support functions defined in this object, collected
// for the code generator to emit into every module.
val wasmFunctions = List(stringLenImpl, concatImpl, digitToStringImpl, readStringImpl, readIntImpl, printStringImpl)
}
package amyc
package interpreter
import utils._
import ast.SymbolicTreeModule._
import ast.Identifier
import analyzer.SymbolTable
// An interpreter for Amy programs, implemented in Scala
object Interpreter extends Pipeline[(Program, SymbolTable), Unit] {

  // A class that represents a value computed by interpreting an expression
  abstract class Value {
    def asInt: Int = this.asInstanceOf[IntValue].i
    def asBoolean: Boolean = this.asInstanceOf[BooleanValue].b
    def asString: String = this.asInstanceOf[StringValue].s

    override def toString: String = this match {
      case IntValue(i) => i.toString
      case BooleanValue(b) => b.toString
      case StringValue(s) => s
      case UnitValue => "()"
      case CaseClassValue(constructor, args) =>
        constructor.name + "(" + args.map(_.toString).mkString(", ") + ")"
    }
  }
  case class IntValue(i: Int) extends Value
  case class BooleanValue(b: Boolean) extends Value
  case class StringValue(s: String) extends Value
  case object UnitValue extends Value
  case class CaseClassValue(constructor: Identifier, args: List[Value]) extends Value

  def run(ctx: Context)(v: (Program, SymbolTable)): Unit = {
    val (program, table) = v

    // These built-in functions do not have an Amy implementation in the program,
    // instead their implementation is encoded in this map
    val builtIns: Map[(String, String), (List[Value]) => Value] = Map(
      ("Std", "printInt") -> { args => println(args.head.asInt); UnitValue },
      ("Std", "printString") -> { args => println(args.head.asString); UnitValue },
      ("Std", "readString") -> { args => StringValue(scala.io.StdIn.readLine()) },
      ("Std", "readInt") -> { args =>
        val input = scala.io.StdIn.readLine()
        try {
          IntValue(input.toInt)
        } catch {
          case ne: NumberFormatException =>
            ctx.reporter.fatal(s"""Could not parse "$input" to Int""")
        }
      },
      ("Std", "intToString") -> { args => StringValue(args.head.asInt.toString) },
      ("Std", "digitToString") -> { args => StringValue(args.head.asInt.toString) }
    )

    // Utility functions to interface with the symbol table.
    def isConstructor(name: Identifier) = table.getConstructor(name).isDefined
    def findFunctionOwner(functionName: Identifier) = table.getFunction(functionName).get.owner.name
    def findFunction(owner: String, name: String) = {
      program.modules.find(_.name.name == owner).get.defs.collectFirst {
        case fd@FunDef(fn, _, _, _) if fn.name == name => fd
      }.get
    }

    // Interprets an expression, using evaluations for local variables contained in 'locals'.
    // Evaluation order is always left-to-right; And/Or short-circuit.
    def interpret(expr: Expr)(implicit locals: Map[Identifier, Value]): Value = {
      expr match {
        case Variable(name) =>
          // Name analysis guarantees every variable is bound.
          locals(name)
        case IntLiteral(i) =>
          IntValue(i)
        case BooleanLiteral(b) =>
          BooleanValue(b)
        case StringLiteral(s) =>
          StringValue(s)
        case UnitLiteral() =>
          UnitValue
        case Plus(lhs, rhs) =>
          IntValue(interpret(lhs).asInt + interpret(rhs).asInt)
        case Minus(lhs, rhs) =>
          IntValue(interpret(lhs).asInt - interpret(rhs).asInt)
        case Times(lhs, rhs) =>
          IntValue(interpret(lhs).asInt * interpret(rhs).asInt)
        case Div(lhs, rhs) =>
          val lv = interpret(lhs).asInt
          val rv = interpret(rhs).asInt
          if (rv == 0) ctx.reporter.fatal("Division by zero")
          else IntValue(lv / rv)
        case Mod(lhs, rhs) =>
          val lv = interpret(lhs).asInt
          val rv = interpret(rhs).asInt
          if (rv == 0) ctx.reporter.fatal("Modulo by zero")
          else IntValue(lv % rv)
        case LessThan(lhs, rhs) =>
          BooleanValue(interpret(lhs).asInt < interpret(rhs).asInt)
        case LessEquals(lhs, rhs) =>
          BooleanValue(interpret(lhs).asInt <= interpret(rhs).asInt)
        case And(lhs, rhs) =>
          // Scala's && short-circuits, matching Amy's semantics.
          BooleanValue(interpret(lhs).asBoolean && interpret(rhs).asBoolean)
        case Or(lhs, rhs) =>
          BooleanValue(interpret(lhs).asBoolean || interpret(rhs).asBoolean)
        case Equals(lhs, rhs) =>
          // Amy equality semantics: Int/Boolean/Unit compare by value;
          // strings and case-class values compare by reference.
          val lv = interpret(lhs)
          val rv = interpret(rhs)
          (lv, rv) match {
            case (IntValue(i1), IntValue(i2)) => BooleanValue(i1 == i2)
            case (BooleanValue(b1), BooleanValue(b2)) => BooleanValue(b1 == b2)
            case (UnitValue, UnitValue) => BooleanValue(true)
            case _ => BooleanValue(lv eq rv)
          }
        case Concat(lhs, rhs) =>
          StringValue(interpret(lhs).asString + interpret(rhs).asString)
        case Not(e) =>
          BooleanValue(!interpret(e).asBoolean)
        case Neg(e) =>
          IntValue(-interpret(e).asInt)
        case Call(qname, args) =>
          // Arguments are always evaluated first, left to right.
          val argValues = args.map(a => interpret(a))
          if (isConstructor(qname)) {
            // Constructor call: build a case-class value.
            CaseClassValue(qname, argValues)
          } else {
            val owner = findFunctionOwner(qname)
            builtIns.get((owner, qname.name)) match {
              case Some(builtin) =>
                builtin(argValues)
              case None =>
                // User-defined function: its body sees ONLY its parameters,
                // not the caller's locals.
                val fd = findFunction(owner, qname.name)
                interpret(fd.body)(fd.params.map(_.name).zip(argValues).toMap)
            }
          }
        case Sequence(e1, e2) =>
          // Evaluate e1 for its side effects, discard its value.
          interpret(e1)
          interpret(e2)
        case Let(df, value, body) =>
          // The bound value is evaluated in the OLD environment.
          interpret(body)(locals + (df.name -> interpret(value)))
        case Ite(cond, thenn, elze) =>
          if (interpret(cond).asBoolean) interpret(thenn) else interpret(elze)
        case Match(scrut, cases) =>
          val evS = interpret(scrut)

          // Returns a list of pairs id -> value,
          // where id has been bound to value within the pattern.
          // Returns None when the pattern fails to match.
          // Note: Only works on well typed patterns (which have been ensured by the type checker).
          def matchesPattern(v: Value, pat: Pattern): Option[List[(Identifier, Value)]] = {
            ((v, pat): @unchecked) match {
              case (_, WildcardPattern()) =>
                Some(Nil)
              case (_, IdPattern(name)) =>
                Some(List(name -> v))
              case (IntValue(i1), LiteralPattern(IntLiteral(i2))) =>
                if (i1 == i2) Some(Nil) else None
              case (BooleanValue(b1), LiteralPattern(BooleanLiteral(b2))) =>
                if (b1 == b2) Some(Nil) else None
              case (StringValue(_), LiteralPattern(StringLiteral(_))) =>
                // Strings compare by reference in Amy; a literal pattern
                // denotes a fresh string, so it can never match.
                None
              case (UnitValue, LiteralPattern(UnitLiteral())) =>
                Some(Nil)
              case (CaseClassValue(con1, realArgs), CaseClassPattern(con2, formalArgs)) =>
                // Same constructor and all sub-patterns must match;
                // bindings of the sub-patterns are concatenated.
                if (con1 == con2 && realArgs.size == formalArgs.size) {
                  val subMatches = realArgs.zip(formalArgs).map { case (a, p) => matchesPattern(a, p) }
                  if (subMatches.forall(_.isDefined)) Some(subMatches.flatMap(_.get))
                  else None
                } else {
                  None
                }
            }
          }

          // Main "loop" of the implementation: Go through every case,
          // check if the pattern matches, and if so return the evaluation of the case expression.
          // LazyList keeps this lazy: patterns after the first match are not tried.
          cases.to(LazyList).map { matchCase =>
            val MatchCase(pat, rhs) = matchCase
            (rhs, matchesPattern(evS, pat))
          }.find(_._2.isDefined) match {
            case Some((rhs, Some(moreLocals))) =>
              interpret(rhs)(locals ++ moreLocals)
            case _ =>
              // No case matched
              ctx.reporter.fatal(s"Match error: ${evS.toString}@${scrut.position}")
          }
        case Error(msg) =>
          // error(msg) aborts the program with the evaluated message.
          ctx.reporter.fatal(interpret(msg).asString)
      }
    }

    // Interpret the main expression of every module, in program order.
    for {
      m <- program.modules
      e <- m.optExpr
    } {
      interpret(e)(Map())
    }
  }
}