diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index edab9e06b7335e81ec5d7c2dcd14547a1b06550a..636f8fd1a27c7518acfb0ce91adb83c521f4d9f7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -21,4 +21,4 @@ test: tags: - cs320 script: - - sbt "runMain lecture1.javaThreads" + - sbt "runMain lecture1.javaThreads; runMain lecture1.scalaThreadsWrapper; runMain lecture1.ExampleThread; runMain lecture3.intersectionWrong; runMain lecture3.intersectionCorrect; runMain lecture3.intersectionNoSideEffect; runMain lecture3.parallelGraphContraction; runMain lecture3.parallelGraphContractionCorrect; runMain midterm22.mock1; runMain midterm22.part3; test" diff --git a/build.sbt b/build.sbt index 5f2272d6f4733212721543aef3fa721a55264f7f..f93a9289a63b922e2f82b2c0e10223cea6f86457 100644 --- a/build.sbt +++ b/build.sbt @@ -4,8 +4,12 @@ scalaVersion := "3.2.0" libraryDependencies ++= Seq( ("com.storm-enroute" %% "scalameter-core" % "0.21").cross(CrossVersion.for3Use2_13), "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4", - "org.scalameta" %% "munit" % "0.7.26" % Test + "junit" % "junit" % "4.13" % Test, + "com.github.sbt" % "junit-interface" % "0.13.3" % Test, ) scalacOptions ++= Seq("-unchecked", "-deprecation") +Test / parallelExecution := false +Test / testOptions += Tests.Argument(TestFrameworks.JUnit) + enablePlugins(JmhPlugin, ScalafmtPlugin) diff --git a/src/main/scala/benchmarks/midterm22/AbstractCollectionBenchmark.scala b/src/main/scala/benchmarks/midterm22/AbstractCollectionBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..16ecbb7cc075fb172fd69597d76a8145ea862939 --- /dev/null +++ b/src/main/scala/benchmarks/midterm22/AbstractCollectionBenchmark.scala @@ -0,0 +1,23 @@ +package benchmarks.midterm22 + +import org.openjdk.jmh.annotations.* + +@State(Scope.Benchmark) +//@Fork(jvmArgsAppend = Array("-Djava.util.concurrent.ForkJoinPool.common.parallelism=4")) +abstract class AbstractCollectionBenchmark: + @Param(Array("10000", "100000", "1000000", "10000000")) + var size: Int = _ + + @Param(Array("Vector", "Array", "ArrayBuffer", "List")) + var collection: String = _ + + var haystack: Iterable[Int] = _ + + @Setup(Level.Invocation) + def setup() = + val gen = (1 to (size * 2) by 2) + haystack = collection match + case "Vector" => gen.toVector + case "Array" => gen.toArray + case "ArrayBuffer" => gen.toBuffer + case "List" => gen.toList diff --git a/src/main/scala/benchmarks/midterm22/CollectionBenchmark.scala b/src/main/scala/benchmarks/midterm22/CollectionBenchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..fc2bd0f5345fed09d434cab8d2bd5bcb19a483ef --- /dev/null +++ b/src/main/scala/benchmarks/midterm22/CollectionBenchmark.scala @@ -0,0 +1,12 @@ +package benchmarks.midterm22 + +import org.openjdk.jmh.annotations.* + +class CollectionBenchmark extends AbstractCollectionBenchmark: + @Benchmark + def take() = + haystack.take(size / 2) + + @Benchmark + def drop() = + haystack.drop(size / 2) diff --git a/src/main/scala/benchmarks/midterm22/Part2Benchmark.scala b/src/main/scala/benchmarks/midterm22/Part2Benchmark.scala new file mode 100644 index 0000000000000000000000000000000000000000..9e97cc16865290f24a958c7f06624e1d1215ad07 --- /dev/null +++ b/src/main/scala/benchmarks/midterm22/Part2Benchmark.scala @@ -0,0 +1,20 @@ +package benchmarks.midterm22 + +import midterm22.contains + +import org.openjdk.jmh.annotations.* + +class MidtermPart2Benchmark extends AbstractCollectionBenchmark: + val needle = 10 + + @Param(Array("true", "false")) + var parallel: Boolean = _ + + @Setup(Level.Invocation) + override def setup() = + super.setup() + midterm22.parallelismEnabled = parallel + + @Benchmark + def containsBenchmark() = + contains(haystack, needle) diff --git a/src/main/scala/instrumentation/Monitor.scala b/src/main/scala/instrumentation/Monitor.scala new file mode 100644 index 0000000000000000000000000000000000000000..d8b3b68ce3aea5ab0d8e43ea34abf6f57bc7925c --- /dev/null +++ b/src/main/scala/instrumentation/Monitor.scala @@ -0,0 +1,40 @@ +package instrumentation + +class Dummy + +trait Monitor: + given dummy: Dummy = new Dummy + + def wait()(implicit i: Dummy = dummy) = waitDefault() + + def synchronized[T](e: => T)(implicit i: Dummy = dummy) = synchronizedDefault( + e + ) + + def notify()(implicit i: Dummy = dummy) = notifyDefault() + + def notifyAll()(implicit i: Dummy = dummy) = notifyAllDefault() + + private val lock = new AnyRef + + // Can be overridden. + def waitDefault(): Unit = lock.wait() + def synchronizedDefault[T](toExecute: => T): T = lock.synchronized(toExecute) + def notifyDefault(): Unit = lock.notify() + def notifyAllDefault(): Unit = lock.notifyAll() + +trait LockFreeMonitor extends Monitor: + override def waitDefault() = + throw new Exception("Please use lock-free structures and do not use wait()") + override def synchronizedDefault[T](toExecute: => T): T = + throw new Exception( + "Please use lock-free structures and do not use synchronized()" + ) + override def notifyDefault() = + throw new Exception( + "Please use lock-free structures and do not use notify()" + ) + override def notifyAllDefault() = + throw new Exception( + "Please use lock-free structures and do not use notifyAll()" + ) diff --git a/src/main/scala/midterm22/Mock1.scala b/src/main/scala/midterm22/Mock1.scala new file mode 100644 index 0000000000000000000000000000000000000000..ffd473476da4ad48e5a651bd05dd7cc3515aacf7 --- /dev/null +++ b/src/main/scala/midterm22/Mock1.scala @@ -0,0 +1,14 @@ +package midterm22 + +import scala.collection.mutable.Set + +@main def mock1() = + val values = Set[Int]() + for _ <- 1 to 100000 do + var sum = 0 + val t1 = task { sum += 1 } + val t2 = task { sum += 1 } + t1.join() + t2.join() + values += sum + println(values) diff --git a/src/main/scala/midterm22/Mock2.scala b/src/main/scala/midterm22/Mock2.scala new file mode 100644 index 0000000000000000000000000000000000000000..d9ef1cd11b353d52aa92018f495fe1d834458386 --- /dev/null +++ b/src/main/scala/midterm22/Mock2.scala @@ -0,0 +1,20 @@ +package midterm22 + +import instrumentation.Monitor + +class Account(private var amount: Int = 0) extends Monitor: + def transfer(target: Account, n: Int) = + this.synchronized { + target.synchronized { + this.amount -= n + target.amount += n + } + } + +@main def mock2() = + val a = new Account(50) + val b = new Account(70) + val t1 = task { a.transfer(b, 10) } + val t2 = task { b.transfer(a, 10) } + t1.join() + t2.join() diff --git a/src/main/scala/midterm22/Part1.scala b/src/main/scala/midterm22/Part1.scala new file mode 100644 index 0000000000000000000000000000000000000000..90c19bcb3bb71577774ba50501b6518dd52ec127 --- /dev/null +++ b/src/main/scala/midterm22/Part1.scala @@ -0,0 +1,45 @@ +package midterm22 + +import scala.collection.parallel.Task +import scala.collection.parallel.CollectionConverters.* + +// Questions 1-3 + +// See tests in midterm22.Part1Test. +// Run with `sbt "testOnly midterm22.Part1Test2"`. + +def parallel3[A, B, C](op1: => A, op2: => B, op3: => C): (A, B, C) = + val res1 = task { op1 } + val res2 = task { op2 } + val res3 = op3 + (res1.join(), res2.join(), res3) + +def find(arr: Array[Int], value: Int, threshold: Int): Option[Int] = + def findHelper(start: Int, end: Int): Option[Int] = + if end - start <= threshold then + var i = start + while i < end do + if arr(i) == value then return Some(value) + i += 1 + None + else + val inc = (end - start) / 3 + val (res1, res2, res3) = parallel3( + findHelper(start, start + inc), + findHelper(start + inc, start + 2 * inc), + findHelper(start + 2 * inc, end) + ) + res1.orElse(res2).orElse(res3) + findHelper(0, arr.size) + +def findAggregated(arr: Array[Int], value: Int): Option[Int] = + val no: Option[Int] = None + val yes: Option[Int] = Some(value) + def f = (x1: Option[Int], x2: Int) => if x2 == value then Some(x2) else x1 + def g = (x1: Option[Int], x2: Option[Int]) => if x1 != None then x1 else x2 + arr.par.aggregate(no)(f, g) + +@main def part1() = + println(find(Array(1, 2, 3), 2, 1)) + +// See tests in Part1Test diff --git a/src/main/scala/midterm22/Part2.scala b/src/main/scala/midterm22/Part2.scala new file mode 100644 index 0000000000000000000000000000000000000000..15dabfed3a5385e2f3a6e87d23a2146dd8d51f07 --- /dev/null +++ b/src/main/scala/midterm22/Part2.scala @@ -0,0 +1,50 @@ +package midterm22 + +import scala.annotation.nowarn + +// Questions 4-7 + +// See tests in midterm22.Part2Test. +// Run with `sbt testOnly midterm22.Part2Test`. + +/* +Answers to the exam questions: + When called with a Vector: + The total amount of work is O(n), as it is dominated by the time needed to + read the array. More precisely W(n) = c + 2*W(n/2) = O(n). + + The depth is O(log(n)), because every recursion takes constant time + and we divide the size of the input by 2 every time, i.e. D(n) = c + D(n/2) = O(log(n)). + + Note however that in practice it is often still faster to manipulate + start and end indices rather than using take and drop. + + When called with a List: + Every recursion takes up to time O(n) rather than constant time. + + The total amount of work is O(n) times the number of recursion, because + take and drop takes time O(n) on lists. Precisely, W(n) = n + 2*W(n/2) = O(log(n)*n) + + The depth is computed similarly: D(n) = n + D(n/2) = O(n), i.e. + +Note: these are theoretical results. In practice, you should always double-check +that kind of conclusions with benchmarks. We did so in +`midterm-code/src/main/scala/bench`. Results are available in `bench-results`. +From these results, we can conclude that +1. Vectors are indeed faster in this case, and +2. parallelization of `contains` yields a 2x speedup. + */ +@nowarn +def contains[A](l: Iterable[A], elem: A): Boolean = + val n = l.size + if n <= 5 then + for i <- l do + if i == elem then + return true + false + else + val (p0, p1) = parallel( + contains(l.take(n / 2), elem), + contains(l.drop(n / 2), elem) + ) + p0 || p1 diff --git a/src/main/scala/midterm22/Part3.scala b/src/main/scala/midterm22/Part3.scala new file mode 100644 index 0000000000000000000000000000000000000000..e1b4b20602c801c20d3fc19e3da25baf66748749 --- /dev/null +++ b/src/main/scala/midterm22/Part3.scala @@ -0,0 +1,14 @@ +package midterm22 + +// Question 8 + +// Run with `sbt runMain midterm22.part3` + +@main def part3() = + def thread(b: => Unit) = + val t = new Thread: + override def run() = b + t + val t = thread { println(s"Hello World") } + t.join() + println(s"Hello") diff --git a/src/main/scala/midterm22/Part4.scala b/src/main/scala/midterm22/Part4.scala new file mode 100644 index 0000000000000000000000000000000000000000..f158e895729dc5ea30a081b397fbc691076892da --- /dev/null +++ b/src/main/scala/midterm22/Part4.scala @@ -0,0 +1,23 @@ +package midterm22 + +import java.util.concurrent.atomic.AtomicInteger +import instrumentation.* + +import instrumentation.Monitor +// Questions 9-15 + +// See tests in midterm22.Part4Test. +// Run with `sbt testOnly midterm22.Part4Test`. + +class Node( + // Globally unique identifier. Different for each instance. + val guid: Int +) extends Monitor + +// This function might be called concurrently. +def lockFun(nodes: List[Node], fn: (e: Int) => Unit): Unit = + if nodes.size > 0 then + nodes.head.synchronized { + fn(nodes(0).guid) + lockFun(nodes.tail, fn) + } diff --git a/src/main/scala/midterm22/Part6.scala b/src/main/scala/midterm22/Part6.scala new file mode 100644 index 0000000000000000000000000000000000000000..934857c04f8846b9d7cfdb507080b5961df32e4e --- /dev/null +++ b/src/main/scala/midterm22/Part6.scala @@ -0,0 +1,20 @@ +package midterm22 + +import instrumentation.Monitor + +// Question 21 + +// See tests in midterm22.Part6Test. +// Run with `sbt testOnly midterm22.Part6Test`. + +class TicketsManager(totalTickets: Int) extends Monitor: + var remainingTickets = totalTickets + + // This method might be called concurrently + def getTicket(): Boolean = + if remainingTickets > 0 then + this.synchronized { + remainingTickets -= 1 + } + true + else false diff --git a/src/main/scala/midterm22/Part7.scala b/src/main/scala/midterm22/Part7.scala new file mode 100644 index 0000000000000000000000000000000000000000..1632a6b14b7017bce3926c757f86c254a4b0c27e --- /dev/null +++ b/src/main/scala/midterm22/Part7.scala @@ -0,0 +1,47 @@ +package midterm22 + +import instrumentation.Monitor + +// Questions 22-24 + +// See tests in midterm22.Part7Test. +// Run with `sbt testOnly midterm22.Part7Test`. + +class NIC(private val _index: Int, private var _assigned: Boolean) + extends Monitor: + def index = _index + def assigned = _assigned + def assigned_=(v: Boolean) = _assigned = v + +class NICManager(n: Int): + // Creates a list with n NICs + val nics = (for i <- 0 until n yield NIC(i, false)).toList + + // This method might be called concurrently + def assignNICs(limitRecvNICs: Boolean = false): (Int, Int) = + var recvNIC: Int = 0 + var sendNIC: Int = 0 + var gotRecvNIC: Boolean = false + var gotSendNIC: Boolean = false + + /// Obtaining receiving NIC... + while !gotRecvNIC do + nics(recvNIC).synchronized { + if !nics(recvNIC).assigned then + nics(recvNIC).assigned = true + gotRecvNIC = true + else recvNIC = (recvNIC + 1) % (if limitRecvNICs then n - 1 else n) + } + // Successfully obtained receiving NIC + + // Obtaining sending NIC... + while !gotSendNIC do + nics(sendNIC).synchronized { + if !nics(sendNIC).assigned then + nics(sendNIC).assigned = true + gotSendNIC = true + else sendNIC = (sendNIC + 1) % n + } + // Successfully obtained sending NIC + + return (recvNIC, sendNIC) diff --git a/src/main/scala/midterm22/Part8.scala b/src/main/scala/midterm22/Part8.scala new file mode 100644 index 0000000000000000000000000000000000000000..7242c058fc3561454b4c333a85d792e6e4ab2302 --- /dev/null +++ b/src/main/scala/midterm22/Part8.scala @@ -0,0 +1,84 @@ +package midterm22 + +import scala.collection.concurrent.{TrieMap, Map} +import java.util.concurrent.atomic.AtomicInteger +import scala.annotation.tailrec + +// Question 25 + +// See tests in midterm22.Part8Test. +// Run with `sbt testOnly midterm22.Part8Test`. + +// Represent a social network where user can follow each other. Each user is +// represented by an id that is an `Int`. +abstract class AbstractInstagram: + // The map storing the "following" relation of our social network. + // `graph(a)` contains the list of user ids that user `a` follows. + val graph: Map[Int, List[Int]] = new TrieMap[Int, List[Int]]() + + // The maximum user id allocated until now. This value should be incremented + // by one each time a new user is added. + val maxId = new AtomicInteger(0) + + // Allocates a new user and returns its unique id. Internally, this should + // also create an empty list at the corresponding id in `graph`. The + // implementation must be thread-safe. + def add(): Int + + // Make `a` follow `b`. The implementation must be thread-safe. + def follow(a: Int, b: Int): Unit + + // Makes `a` unfollow `b`. The implementation must be thread-safe. + def unfollow(a: Int, b: Int): Unit + + // Removes user with id `a`. This should also remove all references to `a` + // in `graph`. The implementation must be thread-safe. + def remove(a: Int): Unit + +class Instagram extends AbstractInstagram: + // This method is worth 6 points. + def add(): Int = + // It is important to increment and read the value in one atomic step. See + // test `testParallelWrongAdd` for an alternative wrong implementation. + val id = maxId.incrementAndGet + // Note: it is also correct to use `graph.putIfAbsent`, but not needed as + // `id` is always new and therefore absent from the map at this point. + graph.update(id, Nil) + id + + // This method is worth 8 points. + def remove(a: Int): Unit = + graph.remove(a) + // Iterate through all keys to make sure that nobody follows `a` anymore. + // For each key, we need to unfollow a in a thread-safe manner. Calling + // `unfollow` is the simplest way to so, as it is already guaranteed to be + // thread-safe. See test `testParallelWrongRemove` for an example of wrong + // implementation. + for b <- graph.keys do unfollow(b, a) + + // This method is worth 10 points. + def unfollow(a: Int, b: Int) = + // Here, it is important to read the value only once. First calling + // `.contains(a)` and then `graph(a)` does not work, as `a` might be removed + // between the two calls. See `testParallelWrongUnfollow` for an example of + // this wrong implementation. + val prev = graph.get(a) + // Returns silently if `a` does not exist. + if prev.isEmpty then return + // We replace the list of users that `a` follows in an atomic manner. If the + // list of followed users changed concurrently, we start over. + if !graph.replace(a, prev.get, prev.get.filter(_ != b)) then unfollow(a, b) + + // This method is worth 12 points. + def follow(a: Int, b: Int) = + val prev = graph.get(a) + // Returns silently if `a` does not exist. + if prev.isEmpty then return + // We replace the list of users that `a` follows in an atomic manner. If the + // list of followed users changed concurrently, we start over. + if !graph.replace(a, prev.get, b :: prev.get) then follow(a, b) + // Difficult: this handles the case where `b` is concurrently removed by + // another thread. To detect this case, we must check if `b` still exists + // after we have followed it, and unfollow it if it is not the case. See + // test `testParallelFollowABRemoveB`. This last step is worth 4 points. + else if !graph.contains(b) then unfollow(a, b) diff --git a/src/main/scala/midterm22/appendix/appendix.scala b/src/main/scala/midterm22/appendix/appendix.scala new file mode 100644 index 0000000000000000000000000000000000000000..05c1609dcf7d7fcddc6704f4ad803fdb09b529cc --- /dev/null +++ b/src/main/scala/midterm22/appendix/appendix.scala @@ -0,0 +1,133 @@ +package midterm22.appendix + +// Represents optional values. Instances of Option are either an instance of +// scala.Some or the object None. +sealed abstract class Option[+A]: + // Returns the option's value if the option is an instance of scala.Some, or + // throws an exception if the option is None. + def get: A + // Returns true if the option is an instance of scala.Some, false otherwise. + // This is equivalent to: + // option match + // case Some(v) => true + // case None => false + def isDefined: Boolean + // Returns this scala.Option if it is nonempty, otherwise return the result of + // evaluating alternative. + def orElse[B >: A](alternative: => Option[B]): Option[B] + +abstract class Iterable[+A]: + // Selects all elements except first n ones. + def drop(n: Int): Iterable[A] + // The size of this collection. + def size: Int + // Selects the first n elements. + def take(n: Int): Iterable[A] + +abstract class List[+A] extends Iterable[A]: + // Adds an element at the beginning of this list. + def ::[B >: A](elem: B): List[B] + // A copy of this sequence with an element appended. + def appended[B >: A](elem: B): List[B] + // Get the element at the specified index. + def apply(n: Int): A + // Selects all elements of this list which satisfy a predicate. + def filter(pred: (A) => Boolean): List[A] + // Selects the first element of this list. + def head: A + // Sorts this sequence according to a comparison function. + def sortWith(lt: (A, A) => Boolean): List[A] + // Selects all elements except the first. + def tail: List[A] + +abstract class Array[+A] extends Iterable[A]: + // Get the element at the specified index. + def apply(n: Int): A + +abstract class Thread: + // Subclasses should override this method. + def run(): Unit + // Causes this thread to begin execution; the Java Virtual Machine calls the + // run method of this thread. + def start(): Unit + // Waits for this thread to die. + def join(): Unit + +// Creates and starts a new task ran concurrently. +def task[T](body: => T): ForkJoinTask[T] = ??? + +abstract class ForkJoinTask[T]: + // Returns the result of the computation when it is done. + def join(): T + +// A concurrent hash-trie or TrieMap is a concurrent thread-safe lock-free +// implementation of a hash array mapped trie. +abstract class TrieMap[K, V]: + // Retrieves the value which is associated with the given key. Throws a + // NoSuchElementException if there is no mapping from the given key to a + // value. + def apply(key: K): V + // Tests whether this map contains a binding for a key. + def contains(key: K): Boolean + // Applies a function f to all elements of this concurrent map. This function + // iterates over a snapshot of the map. + def foreach[U](f: ((K, V)) => U): Unit + // Optionally returns the value associated with a key. + def get(key: K): Option[V] + // Collects all key of this map in an iterable collection. The result is a + // snapshot of the values at a specific point in time. + def keys: Iterator[K] + // Transforms this map by applying a function to every retrieved value. This + // returns a new map. + def mapValues[W](f: V => W): TrieMap[K, W] + // Associates the given key with a given value, unless the key was already + // associated with some other value. This is an atomic operation. + def putIfAbsent(k: K, v: V): Option[V] + // Removes a key from this map, returning the value associated previously with + // that key as an option. + def remove(k: K): Option[V] + // Removes the entry for the specified key if it's currently mapped to the + // specified value. This is an atomic operation. + def remove(k: K, v: V): Boolean + // Replaces the entry for the given key only if it was previously mapped to a + // given value. Returns true if the change is successful, or false otherwise. + // This is an atomic operation. + def replace(k: K, oldvalue: V, newvalue: V): Boolean + // Adds a new key/value pair to this map. + def update(k: K, v: V): Unit + // Collects all values of this map in an iterable collection. The result is a + // snapshot of the values at a specific point in time. + def values: Iterator[V] + +// An int value that may be updated atomically. +// The constructor takes the initial value at its only argument. For example, +// this create an `AtomicInteger` with an initial value of `42`: +// val myAtomicInteger = new AtomicInteger(42) +abstract class AtomicInteger: + // Atomically adds the given value to the current value and returns the + // updated value. + def addAndGet(delta: Int): Int + // Atomically sets the value to the given updated value if the current value + // == the expected value. Returns true if the change is successful, or false + // otherwise. This is an atomic operation. + def compareAndSet(oldvalue: Int, newvalue: Int): Boolean + // Gets the current value. This is an atomic operation. + def get(): Int + // Atomically increments by one the current value. This is an atomic operation. + def incrementAndGet(): Int + +// --------------------------------------------- + +// Needed so that we can compile successfully, but not included for students. +// See Option class doc instead. +abstract class Some[+A](value: A) extends Option[A] +object Some: + def apply[A](value: A): Some[A] = ??? +val None: Option[Nothing] = ??? + +object List: + def apply[A](values: A*): List[A] = ??? +val Nil: List[Nothing] = ??? + +object Array: + def apply[A](values: A*): Array[A] = ??? diff --git a/src/main/scala/midterm22/common.scala b/src/main/scala/midterm22/common.scala new file mode 100644 index 0000000000000000000000000000000000000000..a88ea2522dee53405c5e7766057055601a24715f --- /dev/null +++ b/src/main/scala/midterm22/common.scala @@ -0,0 +1,30 @@ +package midterm22 + +import java.util.concurrent.ForkJoinTask +import java.util.concurrent.RecursiveTask +import java.util.concurrent.ForkJoinWorkerThread +import java.util.concurrent.ForkJoinPool +import java.util.concurrent.atomic.AtomicInteger + +val forkJoinPool = ForkJoinPool() +var parallelismEnabled = true +var tasksCreated: AtomicInteger = AtomicInteger(0) + +def schedule[T](body: => T): ForkJoinTask[T] = + val t = new RecursiveTask[T]: + def compute = body + Thread.currentThread match + case wt: ForkJoinWorkerThread => t.fork() + case _ => forkJoinPool.execute(t) + t + +def task[T](body: => T): ForkJoinTask[T] = + tasksCreated.incrementAndGet + schedule(body) + +def parallel[A, B](op1: => A, op2: => B): (A, B) = + if parallelismEnabled then + val res1 = task { op1 } + val res2 = op2 + (res1.join(), res2) + else (op1, op2) diff --git a/src/test/scala/instrumentation/MockedMonitor.scala b/src/test/scala/instrumentation/MockedMonitor.scala new file mode 100644 index 0000000000000000000000000000000000000000..69a889f46983c888d34331985bf9beaa472e6db4 --- /dev/null +++ b/src/test/scala/instrumentation/MockedMonitor.scala @@ -0,0 +1,46 @@ +package instrumentation + +trait MockedMonitor extends Monitor: + def scheduler: Scheduler + + // Can be overriden. + override def waitDefault() = + scheduler.log("wait") + scheduler updateThreadState Wait(this, scheduler.threadLocks.tail) + override def synchronizedDefault[T](toExecute: =>T): T = + scheduler.log("synchronized check") + val prevLocks = scheduler.threadLocks + scheduler updateThreadState Sync(this, prevLocks) // If this belongs to prevLocks, should just continue. + scheduler.log("synchronized -> enter") + try + toExecute + finally + scheduler updateThreadState Running(prevLocks) + scheduler.log("synchronized -> out") + override def notifyDefault() = + scheduler mapOtherStates { + state => state match + case Wait(lockToAquire, locks) if lockToAquire == this => SyncUnique(this, state.locks) + case e => e + } + scheduler.log("notify") + override def notifyAllDefault() = + scheduler mapOtherStates { + state => state match + case Wait(lockToAquire, locks) if lockToAquire == this => Sync(this, state.locks) + case SyncUnique(lockToAquire, locks) if lockToAquire == this => Sync(this, state.locks) + case e => e + } + scheduler.log("notifyAll") + +abstract class ThreadState: + def locks: Seq[AnyRef] +trait CanContinueIfAcquiresLock extends ThreadState: + def lockToAquire: AnyRef +case object Start extends ThreadState { def locks: Seq[AnyRef] = Seq.empty } +case object End extends ThreadState { def locks: Seq[AnyRef] = Seq.empty } +case class Wait(lockToAquire: AnyRef, locks: Seq[AnyRef]) extends ThreadState +case class SyncUnique(lockToAquire: AnyRef, locks: Seq[AnyRef]) extends ThreadState with CanContinueIfAcquiresLock +case class Sync(lockToAquire: AnyRef, locks: Seq[AnyRef]) extends ThreadState with CanContinueIfAcquiresLock +case class Running(locks: Seq[AnyRef]) extends ThreadState +case class VariableReadWrite(locks: Seq[AnyRef]) extends ThreadState diff --git a/src/test/scala/instrumentation/Scheduler.scala b/src/test/scala/instrumentation/Scheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..1f77342c7283aaf1a69ab8f929e15a62358e9b9e --- /dev/null +++ b/src/test/scala/instrumentation/Scheduler.scala @@ -0,0 +1,276 @@ +package instrumentation + +import java.util.concurrent._; +import scala.concurrent.duration._ +import scala.collection.mutable._ +import Stats._ + +import java.util.concurrent.atomic.AtomicInteger + +sealed abstract class Result +case class RetVal(rets: List[Any]) extends Result +case class Except(msg: String, stackTrace: Array[StackTraceElement]) extends Result +case class Timeout(msg: String) extends Result + +/** + * A class that maintains schedule and a set of thread ids. + * The schedules are advanced after an operation of a SchedulableBuffer is performed. + * Note: the real schedule that is executed may deviate from the input schedule + * due to the adjustments that had to be made for locks + */ +class Scheduler(sched: List[Int]): + val maxOps = 500 // a limit on the maximum number of operations the code is allowed to perform + + private var schedule = sched + private var numThreads = 0 + private val realToFakeThreadId = Map[Long, Int]() + private val opLog = ListBuffer[String]() // a mutable list (used for efficient concat) + private val threadStates = Map[Int, ThreadState]() + + /** + * Runs a set of operations in parallel as per the schedule. + * Each operation may consist of many primitive operations like reads or writes + * to shared data structure each of which should be executed using the function `exec`. + * @timeout in milliseconds + * @return true - all threads completed on time, false -some tests timed out. + */ + def runInParallel(timeout: Long, ops: List[() => Any]): Result = + numThreads = ops.length + val threadRes = Array.fill(numThreads) { None: Any } + var exception: Option[Except] = None + val syncObject = new Object() + var completed = new AtomicInteger(0) + // create threads + val threads = ops.zipWithIndex.map { + case (op, i) => + new Thread(new Runnable() { + def run(): Unit = { + val fakeId = i + 1 + setThreadId(fakeId) + try { + updateThreadState(Start) + val res = op() + updateThreadState(End) + threadRes(i) = res + // notify the main thread if all threads have completed + if completed.incrementAndGet() == ops.length then { + syncObject.synchronized { syncObject.notifyAll() } + } + } catch { + case e: Throwable if exception != None => // do nothing here and silently fail + case e: Throwable => + log(s"throw ${e.toString}") + exception = Some(Except(s"Thread $fakeId crashed on the following schedule: \n" + opLog.mkString("\n"), + e.getStackTrace)) + syncObject.synchronized { syncObject.notifyAll() } + //println(s"$fakeId: ${e.toString}") + //Runtime.getRuntime().halt(0) //exit the JVM and all running threads (no other way to kill other threads) + } + } + }) + } + // start all threads + threads.foreach(_.start()) + // wait for all threads to complete, or for an exception to be thrown, or for the time out to expire + var remTime = timeout + syncObject.synchronized { + timed { if completed.get() != ops.length then syncObject.wait(timeout) } { time => remTime -= time } + } + if exception.isDefined then + exception.get + else if remTime <= 1 then // timeout ? using 1 instead of zero to allow for some errors + Timeout(opLog.mkString("\n")) + else + // every thing executed normally + RetVal(threadRes.toList) + + // Updates the state of the current thread + def updateThreadState(state: ThreadState): Unit = + val tid = threadId + synchronized { + threadStates(tid) = state + } + state match + case Sync(lockToAquire, locks) => + if locks.indexOf(lockToAquire) < 0 then waitForTurn else + // Re-aqcuiring the same lock + updateThreadState(Running(lockToAquire +: locks)) + case Start => waitStart() + case End => removeFromSchedule(tid) + case Running(_) => + case _ => waitForTurn // Wait, SyncUnique, VariableReadWrite + + def waitStart(): Unit = + //while (threadStates.size < numThreads) { + //Thread.sleep(1) + //} + synchronized { + if threadStates.size < numThreads then + wait() + else + notifyAll() + } + + def threadLocks = + synchronized { + threadStates(threadId).locks + } + + def threadState = + synchronized { + threadStates(threadId) + } + + def mapOtherStates(f: ThreadState => ThreadState) = + val exception = threadId + synchronized { + for k <- threadStates.keys if k != exception do + threadStates(k) = f(threadStates(k)) + } + + def log(str: String) = + if (realToFakeThreadId contains Thread.currentThread().getId()) then + val space = (" " * ((threadId - 1) * 2)) + val s = space + threadId + ":" + "\n".r.replaceAllIn(str, "\n" + space + " ") + //println(s) + opLog += s + + /** + * Executes a read or write operation to a global data structure as per the given schedule + * @param msg a message corresponding to the operation that will be logged + */ + def exec[T](primop: => T)(msg: => String, postMsg: => Option[T => String] = None): T = + if ! (realToFakeThreadId contains Thread.currentThread().getId()) then + primop + else + updateThreadState(VariableReadWrite(threadLocks)) + val m = msg + if m != "" then log(m) + if opLog.size > maxOps then + throw new Exception(s"Total number of reads/writes performed by threads exceed $maxOps. A possible deadlock!") + val res = primop + postMsg match + case Some(m) => log(m(res)) + case None => + res + + private def setThreadId(fakeId: Int) = synchronized { + realToFakeThreadId(Thread.currentThread.getId) = fakeId + } + + def threadId = + try + realToFakeThreadId(Thread.currentThread().getId()) + catch + case e: NoSuchElementException => + throw new Exception("You are accessing shared variables in the constructor. This is not allowed. The variables are already initialized!") + + private def isTurn(tid: Int) = synchronized { + (!schedule.isEmpty && schedule.head != tid) + } + + def canProceed(): Boolean = + val tid = threadId + canContinue match + case Some((i, state)) if i == tid => + //println(s"$tid: Runs ! Was in state $state") + canContinue = None + state match + case Sync(lockToAquire, locks) => updateThreadState(Running(lockToAquire +: locks)) + case SyncUnique(lockToAquire, locks) => + mapOtherStates { + _ match + case SyncUnique(lockToAquire2, locks2) if lockToAquire2 == lockToAquire => Wait(lockToAquire2, locks2) + case e => e + } + updateThreadState(Running(lockToAquire +: locks)) + case VariableReadWrite(locks) => updateThreadState(Running(locks)) + true + case Some((i, state)) => + //println(s"$tid: not my turn but $i !") + false + case None => + false + + var threadPreference = 0 // In the case the schedule is over, which thread should have the preference to execute. + + /** returns true if the thread can continue to execute, and false otherwise */ + def decide(): Option[(Int, ThreadState)] = + if !threadStates.isEmpty then // The last thread who enters the decision loop takes the decision. + //println(s"$threadId: I'm taking a decision") + if threadStates.values.forall { case e: Wait => true case _ => false } then + val waiting = threadStates.keys.map(_.toString).mkString(", ") + val s = if threadStates.size > 1 then "s" else "" + val are = if threadStates.size > 1 then "are" else "is" + throw new Exception(s"Deadlock: Thread$s $waiting $are waiting but all others have ended and cannot notify them.") + else + // Threads can be in Wait, Sync, SyncUnique, and VariableReadWrite mode. + // Let's determine which ones can continue. + val notFree = threadStates.collect { case (id, state) => state.locks }.flatten.toSet + val threadsNotBlocked = threadStates.toSeq.filter { + case (id, v: VariableReadWrite) => true + case (id, v: CanContinueIfAcquiresLock) => !notFree(v.lockToAquire) || (v.locks contains v.lockToAquire) + case _ => false + } + if threadsNotBlocked.isEmpty then + val waiting = threadStates.keys.map(_.toString).mkString(", ") + val s = if threadStates.size > 1 then "s" else "" + val are = if threadStates.size > 1 then "are" else "is" + val whoHasLock = threadStates.toSeq.flatMap { case (id, state) => state.locks.map(lock => (lock, id)) }.toMap + val reason = threadStates.collect { + case (id, state: CanContinueIfAcquiresLock) if !notFree(state.lockToAquire) => + s"Thread $id is waiting on lock ${state.lockToAquire} held by thread ${whoHasLock(state.lockToAquire)}" + }.mkString("\n") + throw new Exception(s"Deadlock: Thread$s $waiting are interlocked. Indeed:\n$reason") + else if threadsNotBlocked.size == 1 then // Do not consume the schedule if only one thread can execute. + Some(threadsNotBlocked(0)) + else + val next = schedule.indexWhere(t => threadsNotBlocked.exists { case (id, state) => id == t }) + if next != -1 then + //println(s"$threadId: schedule is $schedule, next chosen is ${schedule(next)}") + val chosenOne = schedule(next) // TODO: Make schedule a mutable list. + schedule = schedule.take(next) ++ schedule.drop(next + 1) + Some((chosenOne, threadStates(chosenOne))) + else + threadPreference = (threadPreference + 1) % threadsNotBlocked.size + val chosenOne = threadsNotBlocked(threadPreference) // Maybe another strategy + Some(chosenOne) + //threadsNotBlocked.indexOf(threadId) >= 0 + /* + val tnb = threadsNotBlocked.map(_._1).mkString(",") + val s = if (schedule.isEmpty) "empty" else schedule.mkString(",") + val only = if (schedule.isEmpty) "" else " only" + throw new Exception(s"The schedule is $s but$only threads ${tnb} can continue")*/ + else canContinue + + /** + * This will be called before a schedulable operation begins. + * This should not use synchronized + */ + var numThreadsWaiting = new AtomicInteger(0) + //var waitingForDecision = Map[Int, Option[Int]]() // Mapping from thread ids to a number indicating who is going to make the choice. + var canContinue: Option[(Int, ThreadState)] = None // The result of the decision thread Id of the thread authorized to continue. + private def waitForTurn = + synchronized { + if numThreadsWaiting.incrementAndGet() == threadStates.size then + canContinue = decide() + notifyAll() + //waitingForDecision(threadId) = Some(numThreadsWaiting) + //println(s"$threadId Entering waiting with ticket number $numThreadsWaiting/${waitingForDecision.size}") + while !canProceed() do wait() + } + numThreadsWaiting.decrementAndGet() + + /** + * To be invoked when a thread is about to complete + */ + private def removeFromSchedule(fakeid: Int) = synchronized { + //println(s"$fakeid: I'm taking a decision because I finished") + schedule = schedule.filterNot(_ == fakeid) + threadStates -= fakeid + if numThreadsWaiting.get() == threadStates.size then + canContinue = decide() + notifyAll() + } + + def getOperationLog() = opLog diff --git a/src/test/scala/instrumentation/Stats.scala b/src/test/scala/instrumentation/Stats.scala new file mode 100644 index 0000000000000000000000000000000000000000..b876596ce6ad25ee20b5c5e9f34766f5c8b29956 --- /dev/null +++ b/src/test/scala/instrumentation/Stats.scala @@ -0,0 +1,19 @@ +package instrumentation + +import java.lang.management._ + +/** + * A collection of methods that can be used to collect run-time statistics about Leon programs. + * This is mostly used to test the resources properties of Leon programs + */ +object Stats: + def timed[T](code: => T)(cont: Long => Unit): T = + var t1 = System.currentTimeMillis() + val r = code + cont((System.currentTimeMillis() - t1)) + r + + def withTime[T](code: => T): (T, Long) = + var t1 = System.currentTimeMillis() + val r = code + (r, (System.currentTimeMillis() - t1)) diff --git a/src/test/scala/instrumentation/TestHelper.scala b/src/test/scala/instrumentation/TestHelper.scala new file mode 100644 index 0000000000000000000000000000000000000000..6bed59d4a44ee6903b5229f36bb1ea37ea32b08f --- /dev/null +++ b/src/test/scala/instrumentation/TestHelper.scala @@ -0,0 +1,112 @@ +package instrumentation + +import scala.util.Random +import scala.collection.mutable.{Map => MutableMap} + +import Stats._ + +object TestHelper: + val noOfSchedules = 10000 // set this to 100k during deployment + val readWritesPerThread = 20 // maximum number of read/writes possible in one thread + val contextSwitchBound = 10 + val testTimeout = 240 // the total time out for a test in seconds + val schedTimeout = 15 // the total time out for execution of a schedule in secs + + // Helpers + /*def testManySchedules(op1: => Any): Unit = testManySchedules(List(() => op1)) + def testManySchedules(op1: => Any, op2: => Any): Unit = testManySchedules(List(() => op1, () => op2)) + def testManySchedules(op1: => Any, op2: => Any, op3: => Any): Unit = testManySchedules(List(() => op1, () => op2, () => op3)) + def testManySchedules(op1: => Any, op2: => Any, op3: => Any, op4: => Any): Unit = testManySchedules(List(() => op1, () => op2, () => op3, () => op4))*/ + + def testSequential[T](ops: Scheduler => Any)(assertions: T => (Boolean, String)) = + testManySchedules(1, + (sched: Scheduler) => { + (List(() => ops(sched)), + (res: List[Any]) => assertions(res.head.asInstanceOf[T])) + }) + + /** + * @numThreads number of threads + * @ops operations to be executed, one per thread + * @assertion as condition that will executed after all threads have completed (without exceptions) + * the arguments are the results of the threads + */ + def testManySchedules(numThreads: Int, + ops: Scheduler => + (List[() => Any], // Threads + List[Any] => (Boolean, String)) // Assertion + ) = + var timeout = testTimeout * 1000L + val threadIds = (1 to numThreads) + //(1 to scheduleLength).flatMap(_ => threadIds).toList.permutations.take(noOfSchedules).foreach { + val schedules = (new ScheduleGenerator(numThreads)).schedules() + var schedsExplored = 0 + schedules.takeWhile(_ => schedsExplored <= noOfSchedules && timeout > 0).foreach { + //case _ if timeout <= 0 => // break + case schedule => + schedsExplored += 1 + val schedr = new Scheduler(schedule) + //println("Exploring Sched: "+schedule) + val (threadOps, assertion) = ops(schedr) + if threadOps.size != numThreads then + throw new IllegalStateException(s"Number of threads: $numThreads, do not match operations of threads: $threadOps") + timed { schedr.runInParallel(schedTimeout * 1000, threadOps) } { t => timeout -= t } match + case Timeout(msg) => + throw new java.lang.AssertionError("assertion failed\n"+"The schedule took too long to complete. A possible deadlock! \n"+msg) + case Except(msg, stkTrace) => + val traceStr = "Thread Stack trace: \n"+stkTrace.map(" at "+_.toString).mkString("\n") + throw new java.lang.AssertionError("assertion failed\n"+msg+"\n"+traceStr) + case RetVal(threadRes) => + // check the assertion + val (success, custom_msg) = assertion(threadRes) + if !success then + val msg = "The following schedule resulted in wrong results: \n" + custom_msg + "\n" + schedr.getOperationLog().mkString("\n") + throw new java.lang.AssertionError("Assertion failed: "+msg) + } + if timeout <= 0 then + throw new java.lang.AssertionError("Test took too long to complete! Cannot check all schedules as your code is too slow!") + + /** + * A schedule generator that is based on the context bound + */ + class ScheduleGenerator(numThreads: Int): + val scheduleLength = readWritesPerThread * numThreads + val rands = (1 to scheduleLength).map(i => new Random(0xcafe * i)) // random numbers for choosing a thread at each position + def schedules(): LazyList[List[Int]] = + var contextSwitches = 0 + var contexts = List[Int]() // a stack of thread ids in the order of context-switches + val remainingOps = MutableMap[Int, Int]() + remainingOps ++= (1 to numThreads).map(i => (i, readWritesPerThread)) // num ops remaining in each thread + val liveThreads = (1 to numThreads).toSeq.toBuffer + + /** + * Updates remainingOps and liveThreads once a thread is chosen for a position in the schedule + */ + def updateState(tid: Int): Unit = + val remOps = remainingOps(tid) + if remOps == 0 then + liveThreads -= tid + else + remainingOps += (tid -> (remOps - 1)) + val schedule = rands.foldLeft(List[Int]()) { + case (acc, r) if contextSwitches < contextSwitchBound => + val tid = liveThreads(r.nextInt(liveThreads.size)) + contexts match + case prev :: tail if prev != tid => // we have a new context switch here + contexts +:= tid + contextSwitches += 1 + case prev :: tail => + case _ => // init case + contexts +:= tid + updateState(tid) + acc :+ tid + case (acc, _) => // here context-bound has been reached so complete the schedule without any more context switches + if !contexts.isEmpty then + contexts = contexts.dropWhile(remainingOps(_) == 0) + val tid = contexts match + case top :: tail => top + case _ => liveThreads(0) // here, there has to be threads that have not even started + updateState(tid) + acc :+ tid + } + schedule #:: schedules() diff --git a/src/test/scala/instrumentation/TestUtils.scala b/src/test/scala/instrumentation/TestUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..49328b2fd6c2a9ce303a301eaacb2af3ee4a457b --- /dev/null +++ b/src/test/scala/instrumentation/TestUtils.scala @@ -0,0 +1,33 @@ +package instrumentation + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global +import org.junit.Assert.* + +object TestUtils: + def failsOrTimesOut[T](action: => T): Boolean = + val asyncAction = Future { + action + } + try + Await.result(asyncAction, 2000.millisecond) + catch + case _: Throwable => return true + return false + + def assertDeadlock[T](action: => T): Unit = + try + action + throw new AssertionError("No error detected.") + catch + case e: AssertionError => + assert(e.getMessage.contains("Deadlock"), "No deadlock detected.") + + def assertMaybeDeadlock[T](action: => T): Unit = + try + action + throw new AssertionError("No error detected.") + catch + case e: AssertionError => + assert(e.getMessage.contains("A possible deadlock!"), "No deadlock detected.") diff --git a/src/test/scala/midterm22/Mock2Test.scala b/src/test/scala/midterm22/Mock2Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..fe3a2cc2c04b72c55c09d15bcbbb9a3baca743c1 --- /dev/null +++ b/src/test/scala/midterm22/Mock2Test.scala @@ -0,0 +1,29 @@ +package midterm22 + +import org.junit.* +import org.junit.Assert.* +import instrumentation.* + +class Mock2Test: + @Test + def test() = + TestUtils.assertDeadlock( + TestHelper.testManySchedules( + 2, + scheduler => + val a = new ScheduledAccount(50, scheduler) + val b = new ScheduledAccount(70, scheduler) + + ( + List( + () => a.transfer(b, 10), + () => b.transfer(a, 10) + ), + results => (true, "") + ) + ) + ) + + class ScheduledAccount(n: Int, val scheduler: Scheduler) + extends Account(n) + with MockedMonitor diff --git a/src/test/scala/midterm22/Part1Test.scala b/src/test/scala/midterm22/Part1Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..6e20c617f4d6ab26b819b277a65c43a6436fd5f4 --- /dev/null +++ b/src/test/scala/midterm22/Part1Test.scala @@ -0,0 +1,26 @@ +package midterm22 + +import org.junit.* +import org.junit.Assert.* + +class Part1Test: + val testArray = + Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) + + @Test + def testQuestion1Pos() = + val tasksCreatedBefore = tasksCreated.get + assertEquals(Some(18), find(testArray, 18, 3)) + assertEquals(10, tasksCreated.get - tasksCreatedBefore) + + @Test + def testQuestion1Neg() = + assertEquals(find(testArray, 20, 3), None) + + @Test + def testQuestion2Pos(): Unit = + assertEquals(findAggregated(testArray, 18), Some(18)) + + @Test + def testQuestion2Neg(): Unit = + assertEquals(findAggregated(testArray, 20), None) diff --git a/src/test/scala/midterm22/Part2Test.scala b/src/test/scala/midterm22/Part2Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..e6bcb8fe0df96686cbd9439c68cea44c90ad9e19 --- /dev/null +++ b/src/test/scala/midterm22/Part2Test.scala @@ -0,0 +1,24 @@ +package midterm22 + +import org.junit.* +import org.junit.Assert.* + +class Part2Test: + val testArray2 = Array(0, 50, 7, 1, 28, 42) + val testList2 = List(0, 50, 7, 1, 28, 42) + + @Test + def testQuestion4Pos() = + assert(contains(testArray2, 7)) + + @Test + def testQuestion4Neg() = + assert(!contains(testArray2, 8)) + + @Test + def testQuestion6Pos() = + assert(contains(testList2, 7)) + + @Test + def testQuestion6Neg() = + assert(!contains(testList2, 8)) diff --git a/src/test/scala/midterm22/Part4Test.scala b/src/test/scala/midterm22/Part4Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..7de8eb21d130210bae3deb51d1443894de8546f3 --- /dev/null +++ b/src/test/scala/midterm22/Part4Test.scala @@ -0,0 +1,250 @@ +package midterm22 + +import instrumentation.Monitor +import instrumentation.MockedMonitor + +import org.junit.* +import org.junit.Assert.* +import instrumentation.* +import java.util.concurrent.atomic.AtomicInteger + +class Part4Test: + + // This test can result in a deadlock because locks can be called in any + // order. Here, Thread 1 locks Node 3 first and then Node 2, whereas Thread 2 + // locks Node 2 first and then Node 3. This will lead to a deadlock. + @Test + def testQuestion9() = + TestUtils.assertDeadlock( + TestHelper.testManySchedules( + 2, + scheduler => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, scheduler)).toList + + // Shared by all threads + var sum: Int = 0 + def increment(e: Int) = sum += e + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes + lockFun(nodes, increment), + ), + results => (true, "") + ) + ) + ) + + // This will not lead to a deadlock because the lock acquire happens in a + // particular order. Thread 1 acquires locks in order 1->2->3->4, whereas + // Thread 2 acquires locks in order 2->3->5. + @Test + def testQuestion10() = + TestHelper.testManySchedules( + 2, + scheduler => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, scheduler)).toList + + // Shared by all threads + var sum: Int = 0 + def increment(e: Int) = sum += e + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes.sortWith((x, y) => x.guid > y.guid) + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes.sortWith((x, y) => x.guid > y.guid) + lockFun(nodes, increment), + ), + results => (true, "") + ) + ) + + + // This will not lead to a deadlock because the lock acquire happens in a + // particular order. Thread 1 acquires locks in order 4->3->2->1, whereas + // Thread 2 acquires locks in order 5->3->2. + @Test + def testQuestion11() = + TestHelper.testManySchedules( + 2, + scheduler => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, scheduler)).toList + + // Shared by all threads + var sum: Int = 0 + def increment(e: Int) = sum += e + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + ), + results => (true, "") + ) + ) + + // This test can result in a deadlock because locks are not called in any + // order. Thread 1 acquire order (3->2->4->1), Thread 2 acquire order + // (2->3->5). Thread 1 locks Node3 first and then Node2, whereas Thread 2 + // locks Node 2 first and then Node3. This will lead to a deadlock. + @Test + def testQuestion12() = + TestUtils.assertDeadlock( + TestHelper.testManySchedules( + 2, + scheduler => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, scheduler)).toList + + // Shared by all threads + var sum: Int = 0 + def increment(e: Int) = sum += e + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes.tail.appended(nodes(0)) + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes.tail.appended(nodes(0)) + lockFun(nodes, increment), + ), + results => (true, "") + ) + ) + ) + + // sum returns wrong answer because there is a data race on the sum variable. + @Test(expected = classOf[AssertionError]) + def testQuestion13() = + TestHelper.testManySchedules( + 2, + scheduler => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, scheduler)).toList + + // Shared by all threads + var sum: Int = 0 + def increment(e: Int) = + val previousSum = scheduler.exec{sum}("Get sum") + scheduler.exec{sum = previousSum + e}("Write sum") + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + ), + results => + if sum != 20 then + (false, f"Wrong sum: expected 20 but got $sum") + else + (true, "") + ) + ) + + // sum value will be correct here because "sum += e" is protected by a lock. + @Test + def testQuestion14() = + TestHelper.testManySchedules( + 2, + sched => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, sched)).toList + + val monitor = new MockedMonitor: // Monitor is a type of a lock. + def scheduler = sched + + // Shared by all threads + var sum: Int = 0 + def increment(e: Int) = + monitor.synchronized { sum += e } + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + ), + results => + if sum != 20 then + (false, f"Wrong sum: expected 20 but got $sum") + else + (true, "") + ) + ) + + // total will give correct output here as it is an atomic instruction. + @Test + def testQuestion15() = + TestHelper.testManySchedules( + 2, + sched => + val allNodes = (for i <- 0 to 6 yield ScheduledNode(i, sched)).toList + + // Shared by all threads + var total: AtomicInteger = new AtomicInteger(0) + def increment(e: Int) = + total.addAndGet(e) + + ( + List( + () => + // Thread 1 + var nodes: List[Node] = List(allNodes(1), allNodes(3), allNodes(2), allNodes(4)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + () => + // Thread 2 + var nodes: List[Node] = List(allNodes(5), allNodes(2), allNodes(3)) + nodes = nodes.sortWith((x, y) => x.guid < y.guid) + lockFun(nodes, increment), + ), + results => + if total.get != 20 then + (false, f"Wrong total: expected 20 but got $total") + else + (true, "") + ) + ) + + + class ScheduledNode(value: Int, val scheduler: Scheduler) extends Node(value) with MockedMonitor diff --git a/src/test/scala/midterm22/Part6Test.scala b/src/test/scala/midterm22/Part6Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..8b961de46ce3092112f686c539c9923e1b019d32 --- /dev/null +++ b/src/test/scala/midterm22/Part6Test.scala @@ -0,0 +1,33 @@ +package midterm22 + +import org.junit.* +import org.junit.Assert.* +import instrumentation.* + +class Part6Test: + @Test(expected = classOf[AssertionError]) + def testQuestion21() = + TestHelper.testManySchedules( + 2, + scheduler => + val ticketsManager = ScheduledTicketsManager(1, scheduler) + + ( + List( + () => + // Thread 1 + ticketsManager.getTicket(), + () => + // Thread 2 + ticketsManager.getTicket() + ), + results => + if ticketsManager.remainingTickets < 0 then + (false, "Sold more tickets than available!") + else (true, "") + ) + ) + + class ScheduledTicketsManager(totalTickets: Int, val scheduler: Scheduler) + extends TicketsManager(totalTickets) + with MockedMonitor diff --git a/src/test/scala/midterm22/Part7Test.scala b/src/test/scala/midterm22/Part7Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..5726270866668ce71221a690c99e1c65a8a8f32a --- /dev/null +++ b/src/test/scala/midterm22/Part7Test.scala @@ -0,0 +1,89 @@ +package midterm22 + +import org.junit.* +import org.junit.Assert.* +import instrumentation.* + +class Part7Test: + @Test + def testNicManagerSequential() = + val nicsManager = NICManager(4) + assertEquals((0, 1), nicsManager.assignNICs()) + assertEquals((2, 3), nicsManager.assignNICs()) + + @Test + def testQuestion22() = + testNicManagerParallel(2, 3) + + @Test + def testQuestion23() = + val nicsManager = NICManager(2) + + // Thread 1 + assertEquals((0, 1), nicsManager.assignNICs()) + nicsManager.nics(0).assigned = false + nicsManager.nics(1).assigned = false + + // Thread 2 + assertEquals((0, 1), nicsManager.assignNICs()) + nicsManager.nics(0).assigned = false + nicsManager.nics(1).assigned = false + + @Test + def testQuestion24() = + testNicManagerParallel(3, 2, true) + + @Test + def testQuestion24NotLimitingRecvNICs() = + TestUtils.assertMaybeDeadlock( + testNicManagerParallel(3, 2) + ) + + def testNicManagerParallel( + threads: Int, + nics: Int, + limitRecvNICs: Boolean = false + ) = + TestHelper.testManySchedules( + threads, + scheduler => + val nicsManager = ScheduledNicsManager(nics, scheduler) + val tasks = for i <- 0 until threads yield () => + // Thread i + val (recvNIC, sendNIC) = nicsManager.assignNICs(limitRecvNICs) + + // Do something with NICs... + + // Un-assign NICs + nicsManager.nics(recvNIC).assigned = false + nicsManager.nics(sendNIC).assigned = false + ( + tasks.toList, + results => + if nicsManager.nics.count(_.assigned) != 0 then + (false, f"All NICs should have been released.") + else (true, "") + ) + ) + + class ScheduledNicsManager(n: Int, scheduler: Scheduler) + extends NICManager(n): + class ScheduledNIC( + _index: Int, + _assigned: Boolean, + val scheduler: Scheduler + ) extends NIC(_index, _assigned) + with MockedMonitor: + override def index = scheduler.exec { super.index }( + "", + Some(res => f"read NIC.index == $res") + ) + override def assigned = scheduler.exec { super.assigned }( + "", + Some(res => f"read NIC.assigned == $res") + ) + override def assigned_=(v: Boolean) = scheduler.exec { super.assigned = v }( + f"write NIC.assigned = $v" + ) + override val nics = + (for i <- 0 until n yield ScheduledNIC(i, false, scheduler)).toList diff --git a/src/test/scala/midterm22/Part8Test.scala b/src/test/scala/midterm22/Part8Test.scala new file mode 100644 index 0000000000000000000000000000000000000000..0a7e22a25b80f8edfc5901426604a64dbea5b996 --- /dev/null +++ b/src/test/scala/midterm22/Part8Test.scala @@ -0,0 +1,380 @@ +package midterm22 + +import org.junit.* +import org.junit.Assert.* +import instrumentation.* +import annotation.nowarn + +import scala.collection.concurrent.TrieMap +import scala.collection.concurrent.{TrieMap, Map} + +class Part8Test: + @Test + def usage() = + val insta = Instagram() + assertEquals(1, insta.add()) + assertEquals(2, insta.add()) + insta.follow(1, 2) + assertEquals(insta.graph, Map(1 -> List(2), 2 -> List())) + insta.follow(2, 1) + insta.unfollow(1, 2) + assertEquals(insta.graph, Map(1 -> List(), 2 -> List(1))) + insta.follow(3, 1) // fails silently + assertEquals(insta.graph, Map(1 -> List(), 2 -> List(1))) + insta.remove(1) + assertEquals(insta.graph, Map(2 -> List())) + insta.unfollow(1, 2) // fails silently + + @Test + def testParallelFollowABRemoveA() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + val u1 = insta.add() + val u2 = insta.add() + + ( + List( + () => + // Thread 1 + insta.follow(u1, u2), + () => + // Thread 2 + insta.remove(u1) + ), + results => + val size = insta.graph.size + if size != 1 then + (false, f"Wrong number of user: expected 1 but got ${size}") + else validateGraph(insta) + ) + ) + + @Test + def testParallelFollowABRemoveB() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + val u1 = insta.add() + val u2 = insta.add() + + ( + List( + () => + // Thread 1 + insta.follow(u1, u2), + () => + // Thread 2 + insta.remove(u2) + ), + results => + val size = insta.graph.size + if size != 1 then + (false, f"Wrong number of user: expected 1 but got ${size}") + else validateGraph(insta) + ) + ) + + @Test + def testParallelFollowACRemoveB() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + val u1 = insta.add() + val u2 = insta.add() + val u3 = insta.add() + insta.follow(u1, u2) + + ( + List( + () => + // Thread 1 + insta.follow(u1, u3), + () => + // Thread 2 + insta.remove(u2) + ), + results => + val size = insta.graph.size + if size != 2 then + (false, f"Wrong number of user: expected 2 but got ${size}") + else validateGraph(insta) + ) + ) + + @Test + def testParallelFollow() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + val u1 = insta.add() + val u2 = insta.add() + val u3 = insta.add() + + ( + List( + () => + // Thread 1 + insta.follow(u1, u2), + () => + // Thread 2 + insta.follow(u1, u3) + ), + results => + val u1FollowingSize = insta.graph(u1).size + if u1FollowingSize != 2 then + ( + false, + f"Wrong number of users followed by user 1: expected 2 but got ${u1FollowingSize}" + ) + else validateGraph(insta) + ) + ) + + @Test + def testParallelRemove() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + // Setup + val u1 = insta.add() + val u2 = insta.add() + val u3 = insta.add() + insta.follow(u1, u2) + insta.follow(u2, u1) + insta.follow(u2, u3) + insta.follow(u3, u1) + + ( + List( + () => + // Thread 1 + insta.remove(u2), + () => + // Thread 2 + insta.remove(u3) + ), + results => + val size = insta.graph.size + if size != 1 then + (false, f"Wrong number of user: expected 1 but got ${size}") + else validateGraph(insta) + ) + ) + + // We test wrong code here, so we expect an assertion error. You can replace + // the next line by `@Test` if you want to see the error with the failing + // schedule. + @Test(expected = classOf[AssertionError]) + def testParallelWrongAdd() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + // This implementation of `add` is wrong, because two threads might + // allocate the same id. + // Consider the following schedule: + // T1: res = 1 + // T2: res = 2 + // T2: graph.update(2, Nil) + // T2: 2 + // T1: graph.update(2, Nil) + // T1: 2 + override def add(): Int = + val res = maxId.incrementAndGet + graph.update(maxId.get, Nil) + res + + ( + List( + () => + // Thread 1 + insta.add(), + () => + // Thread 2 + insta.add() + ), + results => + if results(0) != results(1) then + (false, f"Allocated twice id ${results(0)}") + else validateGraph(insta) + ) + ) + + // We test wrong code here, so we expect an assertion error. You can replace + // the next line by `@Test` if you want to see the error with the failing + // schedule. + @Test(expected = classOf[AssertionError]) + def testParallelWrongRemove() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + + // This implementation of `remove` is wrong because we don't retry to + // call `graph.replace` when it fails. Therefore, user 1 might end up + // following user 2 that has been removed, or not following user 3 + // which is concurrently followed. + override def remove(idToRemove: Int): Unit = + graph.remove(idToRemove) + for (key, value) <- graph do + graph.replace(key, value, value.filter(_ != idToRemove)) + // Note: writing `graph(key) = value.filter(_ != idToRemove)` would also + // be wrong because it does not check the previous value. + // Therefore, it could erase a concurrent update. + + val u1 = insta.add() + val u2 = insta.add() + val u3 = insta.add() + insta.follow(u1, u2) + + ( + List( + () => + // Thread 1 + insta.follow(u1, u3), + () => + // Thread 2 + insta.remove(u2) + ), + results => + val size = insta.graph.size + if insta.graph(u1).size != 1 then + (false, f"Wrong number of users followed by 1: expected 1 but got ${insta.graph(u1)}") + else validateGraph(insta) + ) + ) + + // We test wrong code here, so we expect an assertion error. You can replace + // the next line by `@Test` if you want to see the error with the failing + // schedule. + @Test(expected = classOf[AssertionError]) + def testParallelWrongUnfollow() = + TestHelper.testManySchedules( + 2, + scheduler => + val insta = new Instagram: + override val graph = + ScheduledTrieMap(TrieMap[Int, List[Int]](), scheduler) + override def unfollow(a: Int, b: Int): Unit = + if !graph.contains(a) then return + val prev = graph(a) // Might throw java.util.NoSuchElementException + if !graph.replace(a, prev, prev.filter(_ != b)) then unfollow(a, b) + + val u1 = insta.add() + val u2 = insta.add() + insta.follow(u1, u2) + + ( + List( + () => + // Thread 1 + insta.unfollow(u1, u2), + () => + // Thread 2 + insta.remove(u1) + ), + results => + val size = insta.graph.size + if size != 1 then + (false, f"Wrong number of user: expected 1 but got ${size}") + else validateGraph(insta) + ) + ) + + @nowarn + def validateGraph(insta: Instagram): (Boolean, String) = + for (a, following) <- insta.graph; b <- following do + if !insta.graph.contains(b) then + return (false, f"User $a follows non-existing user $b") + (true, "") + + final class ScheduledIterator[T]( + private val myIterator: Iterator[T], + private val scheduler: Scheduler + ) extends Iterator[T]: + override def hasNext = + myIterator.hasNext + override def next() = + scheduler.exec(myIterator.next)("", Some(res => f"Iterator.next == $res")) + override def knownSize: Int = + myIterator.knownSize + + final class ScheduledTrieMap[K, V]( + private val myMap: Map[K, V], + private val scheduler: Scheduler + ) extends Map[K, V]: + override def apply(key: K): V = + scheduler.exec(myMap(key))( + "", + Some(res => f"TrieMap.apply($key) == $res") + ) + override def contains(key: K): Boolean = + scheduler.exec(myMap.contains(key))( + "", + Some(res => f"TrieMap.contains($key) == $res") + ) + override def get(key: K): Option[V] = + scheduler.exec(myMap.get(key))( + "", + Some(res => f"TrieMap.get($key) == $res") + ) + override def addOne(kv: (K, V)) = + scheduler.exec(myMap.addOne(kv))(f"TrieMap.addOne($kv)") + this + override def subtractOne(k: K) = + scheduler.exec(myMap.subtractOne(k))(f"TrieMap.subtractOne($k)") + this + override def iterator() = + scheduler.log("TrieMap.iterator") + ScheduledIterator(myMap.iterator, scheduler) + override def replace(k: K, v: V): Option[V] = + scheduler.exec(myMap.replace(k, v))( + "", + Some(res => f"TrieMap.replace($k, $v) == $res") + ) + override def replace(k: K, oldvalue: V, newvalue: V): Boolean = + scheduler.exec(myMap.replace(k, oldvalue, newvalue))( + "", + Some(res => f"TrieMap.replace($k, $oldvalue, $newvalue) == $res") + ) + override def putIfAbsent(k: K, v: V): Option[V] = + scheduler.exec(myMap.putIfAbsent(k, v))( + "", + Some(res => f"TrieMap.putIfAbsent($k, $v)") + ) + override def remove(k: K): Option[V] = + scheduler.exec(myMap.remove(k))( + "", + Some(res => f"TrieMap.remove($k)") + ) + override def remove(k: K, v: V): Boolean = + scheduler.exec(myMap.remove(k, v))( + "", + Some(res => f"TrieMap.remove($k, $v)") + )