package instrumentation import scala.util.Random import scala.collection.mutable.{Map as MutableMap} import Stats.* object TestHelper: val noOfSchedules = 10000 // set this to 100k during deployment val readWritesPerThread = 20 // maximum number of read/writes possible in one thread val contextSwitchBound = 10 val testTimeout = 240 // the total time out for a test in seconds val schedTimeout = 15 // the total time out for execution of a schedule in secs // Helpers /*def testManySchedules(op1: => Any): Unit = testManySchedules(List(() => op1)) def testManySchedules(op1: => Any, op2: => Any): Unit = testManySchedules(List(() => op1, () => op2)) def testManySchedules(op1: => Any, op2: => Any, op3: => Any): Unit = testManySchedules(List(() => op1, () => op2, () => op3)) def testManySchedules(op1: => Any, op2: => Any, op3: => Any, op4: => Any): Unit = testManySchedules(List(() => op1, () => op2, () => op3, () => op4))*/ def testSequential[T](ops: Scheduler => Any)(assertions: T => ( Boolean, String )) = testManySchedules( 1, (sched: Scheduler) => ( List(() => ops(sched)), (res: List[Any]) => assertions(res.head.asInstanceOf[T]) ) ) /** @numThreads * number of threads * @ops * operations to be executed, one per thread * @assertion * as condition that will executed after all threads have completed * (without exceptions) the arguments are the results of the threads */ def testManySchedules( numThreads: Int, ops: Scheduler => ( List[() => Any], // Threads List[Any] => (Boolean, String) ) // Assertion ) = var timeout = testTimeout * 1000L val threadIds = (1 to numThreads) // (1 to scheduleLength).flatMap(_ => threadIds).toList.permutations.take(noOfSchedules).foreach { val schedules = (new ScheduleGenerator(numThreads)).schedules() var schedsExplored = 0 schedules.takeWhile(_ => schedsExplored <= noOfSchedules && timeout > 0 ).foreach { // case _ if timeout <= 0 => // break case schedule => schedsExplored += 1 val schedr = new Scheduler(schedule) // println("Exploring Sched: "+schedule) val (threadOps, assertion) = ops(schedr) if threadOps.size != numThreads then throw new IllegalStateException( s"Number of threads: $numThreads, do not match operations of threads: $threadOps" ) timed { schedr.runInParallel(schedTimeout * 1000, threadOps) } { t => timeout -= t } match case Timeout(msg) => throw new java.lang.AssertionError( "assertion failed\n" + "The schedule took too long to complete. A possible deadlock! \n" + msg ) case Except(msg, stkTrace) => val traceStr = "Thread Stack trace: \n" + stkTrace.map( " at " + _.toString ).mkString("\n") throw new java.lang.AssertionError( "assertion failed\n" + msg + "\n" + traceStr ) case RetVal(threadRes) => // check the assertion val (success, custom_msg) = assertion(threadRes) if !success then val msg = "The following schedule resulted in wrong results: \n" + custom_msg + "\n" + schedr.getOperationLog().mkString( "\n" ) throw new java.lang.AssertionError("Assertion failed: " + msg) } if timeout <= 0 then throw new java.lang.AssertionError( "Test took too long to complete! Cannot check all schedules as your code is too slow!" ) /** A schedule generator that is based on the context bound */ class ScheduleGenerator(numThreads: Int): val scheduleLength = readWritesPerThread * numThreads val rands = (1 to scheduleLength).map(i => new Random(0xcafe * i) ) // random numbers for choosing a thread at each position def schedules(): LazyList[List[Int]] = var contextSwitches = 0 var contexts = List[Int]() // a stack of thread ids in the order of context-switches val remainingOps = MutableMap[Int, Int]() remainingOps ++= (1 to numThreads).map(i => (i, readWritesPerThread) ) // num ops remaining in each thread val liveThreads = (1 to numThreads).toSeq.toBuffer /** Updates remainingOps and liveThreads once a thread is chosen for a * position in the schedule */ def updateState(tid: Int): Unit = val remOps = remainingOps(tid) if remOps == 0 then liveThreads -= tid else remainingOps += (tid -> (remOps - 1)) val schedule = rands.foldLeft(List[Int]()) { case (acc, r) if contextSwitches < contextSwitchBound => val tid = liveThreads(r.nextInt(liveThreads.size)) contexts match case prev :: tail if prev != tid => // we have a new context switch here contexts +:= tid contextSwitches += 1 case prev :: tail => case _ => // init case contexts +:= tid updateState(tid) acc :+ tid case ( acc, _ ) => // here context-bound has been reached so complete the schedule without any more context switches if !contexts.isEmpty then contexts = contexts.dropWhile(remainingOps(_) == 0) val tid = contexts match case top :: tail => top case _ => liveThreads( 0 ) // here, there has to be threads that have not even started updateState(tid) acc :+ tid } schedule #:: schedules()