Newer
Older
import java.util.concurrent.*;
import scala.concurrent.duration.*
import scala.collection.mutable.*
import Stats.*
import java.util.concurrent.atomic.AtomicInteger
/** Outcome of running a batch of operations under the scheduler. */
sealed abstract class Result

/** Normal completion; `rets` holds the values returned by the operations. */
case class RetVal(rets: List[Any]) extends Result

/** A thread crashed: `msg` describes the failure and `stackTrace` is the stack of the thrown exception. */
case class Except(msg: String, stackTrace: Array[StackTraceElement]) extends Result
/** A class that maintains a schedule and a set of thread IDs. The schedule is
* advanced after each operation of a SchedulableBuffer is performed. Note: the
* schedule that is actually executed may deviate from the input schedule because
* of the adjustments that had to be made for locks.
*/
/** Upper bound on the number of logged operations; exceeding it is reported as a possible deadlock. */
val maxOps: Int = 500
/** Remaining schedule: the order in which fake thread IDs are allowed to proceed. */
private var schedule = sched
/** Number of threads taking part in the current run (set by `runInParallel`). */
private var numThreads: Int = 0
/** Maps a JVM thread ID to the scheduler's own (fake) thread ID. */
private val realToFakeThreadId: Map[Long, Int] = Map()
/** Log of performed operations; a mutable buffer so that appending is cheap. */
private val opLog: ListBuffer[String] = ListBuffer()
/** Recorded state of every thread that has not finished yet, keyed by fake thread ID. */
private val threadStates: Map[Int, ThreadState] = Map()
/** Runs a set of operations in parallel, one thread per operation, as per the
* schedule. Each operation may consist of many primitive operations, like reads
* or writes to a shared data structure, each of which should be executed using
* the function `exec`.
* @param timeout
* maximum time to wait for all threads to finish, in milliseconds
* @param ops
* the operations to run; operation `i` runs in its own thread and its result
* is stored at index `i`
* @return
* `RetVal` with the operations' results (in `ops` order) if all threads
* completed on time; `Timeout` carrying the operation log if the time limit
* expired; presumably the captured `Except` if a thread crashed (that branch
* is not visible in this copy)
*/
def runInParallel(timeout: Long, ops: List[() => Any]): Result =
numThreads = ops.length
val threadRes = Array.fill(numThreads) { None: Any }
// NOTE(review): `exception` is assigned by worker threads without any visible
// synchronization or @volatile; confirm the main thread is guaranteed to see it.
var exception: Option[Except] = None
val syncObject = new Object()
var completed = new AtomicInteger(0)
// create threads
val threads = ops.zipWithIndex.map {
case (op, i) =>
// NOTE(review): this case body looks truncated in this copy. It must yield
// something with `start()` (see below), `fakeId` is used but never defined,
// and the `catch` below has no matching `try`. Presumably a `new Thread(...)`
// wrapper that calls `setThreadId(...)` and opens the `try` was lost - confirm.
updateThreadState(Start)
val res = op()
updateThreadState(End)
threadRes(i) = res
// notify the main thread if all threads have completed
// NOTE(review): nothing visible increments `completed` or notifies
// `syncObject`; the code doing so appears to be missing here.
catch
case e: Throwable
if exception != None => // do nothing here and silently fail
case e: Throwable =>
log(s"throw ${e.toString}")
exception = Some(Except(
s"Thread $fakeId crashed on the following schedule: \n" + opLog.mkString(
"\n"
),
e.getStackTrace
))
// println(s"$fakeId: ${e.toString}")
// Runtime.getRuntime().halt(0) //exit the JVM and all running threads (no other way to kill other threads)
)
}
// start all threads
threads.foreach(_.start())
// wait for all threads to complete, or for an exception to be thrown, or for the time out to expire
var remTime = timeout
syncObject.synchronized {
// Waits at most once (no retry loop). `timed` is not defined in this view;
// presumably (via `Stats`) it runs the first block and passes the elapsed
// time to the second, which charges it against `remTime` - confirm.
timed { if completed.get() != ops.length then syncObject.wait(timeout) } {
time => remTime -= time
}
// NOTE(review): the closing brace of `syncObject.synchronized` and the `if`
// that this `else if` continues (presumably returning the captured
// `exception`) appear to be missing in this copy.
else if remTime <= 1
then // timeout ? using 1 instead of zero to allow for some errors
Timeout(opLog.mkString("\n"))
else
// every thing executed normally
RetVal(threadRes.toList)
/** Records `state` as the calling thread's current state and then reacts to it.
  *
  * - `Start`: blocks until every participating thread has registered (`waitStart`).
  * - `End`: removes the thread from the schedule.
  * - `Running`: continues immediately.
  * - `Sync` on a lock the thread already holds: re-entrant acquisition, so the
  *   thread moves straight to `Running`.
  * - anything else (including `Sync` on a lock not yet held): waits for its turn.
  *
  * @param state the new state of the calling thread
  */
def updateThreadState(state: ThreadState): Unit =
  val tid = threadId
  synchronized {
    threadStates(tid) = state
  }
  state match
    case Sync(lockToAquire, locks) =>
      if locks.contains(lockToAquire) then
        // Re-acquiring a lock this thread already holds: no scheduling decision needed.
        updateThreadState(Running(lockToAquire +: locks))
      else
        // Taking a new lock is a scheduling point: wait until `decide` selects us.
        waitForTurn
    case Start => waitStart()
    case End => removeFromSchedule(tid)
    case Running(_) =>
    case _ => waitForTurn // Wait, SyncUnique, VariableReadWrite
/** Start barrier: blocks the calling thread until every participating thread has
  * recorded a state. The thread whose arrival brings the count to `numThreads`
  * wakes up all threads waiting on this object's monitor.
  */
def waitStart(): Unit =
  synchronized {
    // NOTE(review): a single `wait()` (not a loop) is used, so a spurious wakeup
    // would let a thread start early. A plain `while` on
    // `threadStates.size < numThreads` is not a safe drop-in replacement: finished
    // threads are removed from `threadStates`, so the size can drop again and
    // leave late wakers blocked forever. A monotonic "started" counter would be needed.
    if threadStates.size < numThreads then
      wait()
    else
      notifyAll()
  }
/** Locks held by the calling thread, read from its recorded state under the scheduler's monitor. */
def threadLocks =
  synchronized {
    val ownState = threadStates(threadId)
    ownState.locks
  }
/** The calling thread's recorded state, read under the scheduler's monitor. */
def threadState =
  synchronized {
    val me = threadId
    threadStates(me)
  }
/** Replaces the recorded state of every thread except the caller with `f(state)`.
  * The caller's own state is left untouched.
  */
def mapOtherStates(f: ThreadState => ThreadState): Unit =
  val caller = threadId
  synchronized {
    threadStates.mapValuesInPlace { (id, current) =>
      if id == caller then current else f(current)
    }
  }
/** Appends `str` to the operation log when called from a thread registered with
  * the scheduler; calls from other threads are ignored.
  *
  * Each entry is prefixed with the fake thread ID and indented by two spaces per
  * ID above 1. Embedded newlines are re-indented the same way, so that the
  * interleaving of threads is readable.
  *
  * @param str the message to record
  */
def log(str: String) =
  if realToFakeThreadId.contains(Thread.currentThread().getId()) then
    val space = " " * ((threadId - 1) * 2)
    val s = space + threadId + ":" + "\n".r.replaceAllIn(str, "\n" + space + " ")
    // println(s)
    // Previously `s` was built and discarded, so `opLog` stayed empty: the `maxOps`
    // guard in `exec` never fired and timeout/crash reports had no schedule.
    // NOTE(review): `opLog` is a plain ListBuffer appended without locking; this
    // presumably relies on the scheduler letting one thread run at a time - confirm.
    opLog += s
/** Executes a read or write operation on a shared data structure as per the
  * given schedule.
  *
  * When called from a thread the scheduler does not manage, `primop` simply runs.
  * Otherwise the calling thread first waits for its turn, logs `msg` (unless it
  * is empty), runs `primop`, and then logs `postMsg(result)` if one is supplied.
  *
  * @param primop  the primitive operation to perform
  * @param msg     a message corresponding to the operation that will be logged;
  *                an empty string logs nothing
  * @param postMsg optional function turning the operation's result into a message
  * @return the result of `primop`
  * @throws Exception if more than `maxOps` operations have been logged (a possible deadlock)
  */
def exec[T](primop: => T)(
    msg: => String,
    postMsg: => Option[T => String] = None
): T =
  if !realToFakeThreadId.contains(Thread.currentThread().getId()) then
    primop
  else
    updateThreadState(VariableReadWrite(threadLocks))
    val m = msg
    if m != "" then log(m)
    if opLog.size > maxOps then
      throw new Exception(
        s"Total number of reads/writes performed by threads exceed $maxOps. A possible deadlock!"
      )
    val res = primop
    // `None` (the default) means there is nothing to log afterwards; the former
    // `match` without a `None` case threw a MatchError here.
    postMsg.foreach(format => log(format(res)))
    res
/** Associates the calling JVM thread with the scheduler's fake thread ID `fakeId`. */
private def setThreadId(fakeId: Int) =
  val realId = Thread.currentThread().getId()
  synchronized {
    realToFakeThreadId.update(realId, fakeId)
  }
/** The scheduler's fake ID for the calling thread.
  *
  * @throws Exception if the calling thread was never registered with the scheduler
  *   (the error message attributes this to accessing shared variables in a constructor)
  */
def threadId: Int =
  realToFakeThreadId.getOrElse(
    Thread.currentThread().getId(),
    throw new Exception(
      "You are accessing shared variables in the constructor. This is not allowed. The variables are already initialized!"
    )
  )
/** Whether the schedule is non-empty and its head names a thread other than `tid`.
  *
  * NOTE(review): despite the name, this is `true` when it is *not* `tid`'s turn.
  * No caller is visible in this copy, so the semantics are kept as-is - confirm intent.
  */
private def isTurn(tid: Int) = synchronized {
  schedule.headOption.exists(_ != tid)
}
/** Whether the calling thread may stop waiting for its turn.
  *
  * Returns `true` only if the latest scheduling decision (`canContinue`) selected
  * the calling thread. Before returning, the thread's state is moved to `Running`.
  * For a `SyncUnique`, other threads in `SyncUnique` on the same lock are first
  * switched to `Wait`. Returns `false` when another thread was selected or when no
  * decision exists yet.
  */
def canProceed(): Boolean =
  val tid = threadId
  canContinue match
    case Some((i, state)) if i == tid =>
      state match
        case Sync(lockToAquire, locks) =>
          updateThreadState(Running(lockToAquire +: locks))
        case SyncUnique(lockToAquire, locks) =>
          // Other threads competing for the same unique lock must now wait.
          mapOtherStates {
            case SyncUnique(lockToAquire2, locks2) if lockToAquire2 == lockToAquire =>
              Wait(lockToAquire2, locks2)
            case other => other
          }
          updateThreadState(Running(lockToAquire +: locks))
        case VariableReadWrite(locks) =>
          updateThreadState(Running(locks))
      true
    case Some(_) =>
      // Another thread has been selected: keep waiting.
      false
    case None =>
      // No decision has been made yet (previously an unhandled MatchError).
      false
/** Rotating index that `decide` uses to pick among several unblocked threads.
  * It is meant for the case where the schedule no longer determines which thread
  * should have the preference to execute.
  */
var threadPreference: Int = 0
/** Makes the next scheduling decision: picks the thread allowed to continue and
* returns its fake ID together with its recorded state.
*
* A thread can continue if it is in `VariableReadWrite`, or if it waits for a lock
* (`CanContinueIfAcquiresLock`) that is either not held by anyone or already held
* by that thread. In the code visible here, when exactly one such thread exists
* its next entry in `schedule` is consumed; otherwise the rotating
* `threadPreference` index picks one. With no live threads, the previous decision
* (`canContinue`) is returned unchanged.
*
* @throws Exception on deadlock: every live thread is in `Wait`, or no thread can continue
*/
def decide(): Option[(Int, ThreadState)] =
if !threadStates.isEmpty
then // The last thread to enter the decision loop makes the decision.
// println(s"$threadId: I'm taking a decision")
if threadStates.values.forall {
case e: Wait => true
case _ => false
}
then
val waiting = threadStates.keys.map(_.toString).mkString(", ")
val s = if threadStates.size > 1 then "s" else ""
val are = if threadStates.size > 1 then "are" else "is"
throw new Exception(
s"Deadlock: Thread$s $waiting $are waiting but all others have ended and cannot notify them."
)
else
// Threads can be in Wait, Sync, SyncUnique, and VariableReadWrite mode.
// Let's determine which ones can continue.
// `notFree`: every lock currently held by some live thread.
val notFree = threadStates.collect { case (id, state) =>
state.locks
}.flatten.toSet
val threadsNotBlocked = threadStates.toSeq.filter {
case (id, v: VariableReadWrite) => true
case (id, v: CanContinueIfAcquiresLock) =>
!notFree(v.lockToAquire) || (v.locks contains v.lockToAquire)
case _ => false
}
if threadsNotBlocked.isEmpty then
val waiting = threadStates.keys.map(_.toString).mkString(", ")
val s = if threadStates.size > 1 then "s" else ""
val are = if threadStates.size > 1 then "are" else "is"
val whoHasLock = threadStates.toSeq.flatMap { case (id, state) =>
state.locks.map(lock => (lock, id))
}.toMap
// NOTE(review): the line opening this collection appears to be missing
// (presumably `val reason = threadStates.collect {`); `reason` is used below
// but is not defined in this copy. Also, this branch is only reached when
// every lock-waiting thread's lock is in `notFree`, so the guard
// `!notFree(...)` below can never match - it looks inverted; confirm.
case (id, state: CanContinueIfAcquiresLock)
if !notFree(state.lockToAquire) =>
s"Thread $id is waiting on lock ${state.lockToAquire} held by thread ${whoHasLock(state.lockToAquire)}"
}.mkString("\n")
throw new Exception(
s"Deadlock: Thread$s $waiting are interlocked. Indeed:\n$reason"
)
else if threadsNotBlocked.size == 1
then // Do not consume the schedule if only one thread can execute.
// NOTE(review): the comment above contradicts this branch, which does remove
// an entry from `schedule`.
val next = schedule.indexWhere(t =>
threadsNotBlocked.exists { case (id, state) => id == t }
)
// NOTE(review): `next` is -1 when no entry of `schedule` names an unblocked
// thread (e.g. the schedule is exhausted); `schedule(next)` would then fail.
// println(s"$threadId: schedule is $schedule, next chosen is ${schedule(next)}")
val chosenOne =
schedule(next) // TODO: Make schedule a mutable list.
schedule = schedule.take(next) ++ schedule.drop(next + 1)
Some((chosenOne, threadStates(chosenOne)))
else
threadPreference = (threadPreference + 1) % threadsNotBlocked.size
val chosenOne =
threadsNotBlocked(threadPreference) // Maybe another strategy
// NOTE(review): `chosenOne` is computed but never returned (e.g. as
// `Some(chosenOne)`), so this branch does not yield the declared `Option`;
// code appears to be missing here.
/*
val tnb = threadsNotBlocked.map(_._1).mkString(",")
val s = if (schedule.isEmpty) "empty" else schedule.mkString(",")
val only = if (schedule.isEmpty) "" else " only"
throw new Exception(s"The schedule is $s but$only threads ${tnb} can continue")*/
// No live threads left: keep the previous decision.
else canContinue
// var waitingForDecision = Map[Int, Option[Int]]() // Mapping from thread ids to a number indicating who is going to make the choice.
/** Result of the latest scheduling decision: the fake ID and recorded state of the
  * thread authorized to continue. It is `None` until a decision has been made.
  */
var canContinue: Option[(Int, ThreadState)] = None
/** Called before a schedulable operation begins: blocks the calling thread until a
  * scheduling decision allows it to proceed.
  *
  * The thread that brings the waiting count up to the number of live threads
  * computes the decision via `decide` and wakes every waiter, each of which then
  * re-checks `canProceed`.
  *
  * NOTE(review): `numThreadsWaiting` is not declared in this copy; it is presumably
  * an AtomicInteger field - confirm.
  */
private def waitForTurn =
  synchronized {
    val waitingNow = numThreadsWaiting.incrementAndGet()
    if waitingNow == threadStates.size then
      canContinue = decide()
      notifyAll()
    var mayGo = canProceed()
    while !mayGo do
      wait()
      mayGo = canProceed()
  }
  numThreadsWaiting.decrementAndGet()
/** Removes thread `fakeid` from the remaining schedule and from the live threads.
  * If every remaining thread is now waiting, makes the next scheduling decision on
  * their behalf and wakes them up.
  */
private def removeFromSchedule(fakeid: Int) = synchronized {
  threadStates.remove(fakeid)
  schedule = schedule.filter(_ != fakeid)
  val everyoneWaiting = numThreadsWaiting.get() == threadStates.size
  if everyoneWaiting then
    canContinue = decide()
    notifyAll()
}
/** The live, mutable buffer of operations logged so far, one entry per log call. */
def getOperationLog() =
  opLog