diff --git a/README.md b/README.md
index f0de341ec8eb4fd02aca097a86ef7bcfec502fa4..568e065546a69039d8e2a98f6c0166ed6c288414 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ Do include your own ratings in your final submission so we can check your answer
 ## Compute predictions
 
 ````
-> sbt 'runMain predict.Predictor'
+> sbt "runMain predict.Predictor --train data/ml-100k/u1.base --test data/ml-100k/u1.test --json answers.json"
 ````
 
 ## Compute recommendations
diff --git a/build.sbt b/build.sbt
index d84bf536f28231379715f23f9b92962750eabab6..762b925b91254009ff709de0fb3a1c1bb7973c48 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,7 +1,13 @@
-libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.0" % Test
-scalaVersion in ThisBuild := "2.13.3"
-enablePlugins(JavaAppPackaging)
-
 name := "m1_yourid"
 version := "1.0"
 maintainer := "your.name@epfl.ch"
+
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.0" % Test
+libraryDependencies += "org.rogach" %% "scallop" % "4.0.2"
+libraryDependencies += "org.json4s" %% "json4s-jackson" % "3.6.10"
+libraryDependencies += "org.apache.spark" %% "spark-core" % "3.0.0"
+libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.0"
+
+scalaVersion in ThisBuild := "2.12.13"
+enablePlugins(JavaAppPackaging)
+
diff --git a/src/main/scala/predict/Predictor.scala b/src/main/scala/predict/Predictor.scala
index 0c132f5935548ee2cf35af4a0367652f82902e4a..d9b871a9fedc8a99cc6d2fb26255527ba1a2e97d 100644
--- a/src/main/scala/predict/Predictor.scala
+++ b/src/main/scala/predict/Predictor.scala
@@ -1,6 +1,83 @@
 package predict
 
+import org.rogach.scallop._
+import org.json4s.jackson.Serialization
+import org.apache.spark.rdd.RDD
+
+import org.apache.spark.sql.SparkSession
+import org.apache.log4j.Logger
+import org.apache.log4j.Level
+
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+  val train = opt[String](required = true)
+  val test = opt[String](required = true)
+  val json = opt[String]()
+  verify()
+}
+
+case class Rating(user: Int, item: Int, rating: Double)
+
 object Predictor extends App {
-  println("Computing predictions ...")
-  println("Done")
+  // Remove these two lines if you need Spark's log output while debugging
+  Logger.getLogger("org").setLevel(Level.OFF)
+  Logger.getLogger("akka").setLevel(Level.OFF)
+  val spark = SparkSession.builder()
+    .master("local[1]")
+    .getOrCreate()
+  spark.sparkContext.setLogLevel("ERROR")
+
+  println("")
+  println("******************************************************")
+
+  var conf = new Conf(args)
+  println("Loading training data from: " + conf.train())
+  val trainFile = spark.sparkContext.textFile(conf.train())
+  val train = trainFile.map(l => {
+    val cols = l.split("\t").map(_.trim)
+    Rating(cols(0).toInt, cols(1).toInt, cols(2).toDouble)
+  })
+  assert(train.count == 80000, "Invalid training data")
+
+  println("Loading test data from: " + conf.test())
+  val testFile = spark.sparkContext.textFile(conf.test())
+  val test = testFile.map(l => {
+    val cols = l.split("\t").map(_.trim)
+    Rating(cols(0).toInt, cols(1).toInt, cols(2).toDouble)
+  })
+  assert(test.count == 20000, "Invalid test data")
+
+  val globalPred = 3.0
+  val globalMae = test.map(r => scala.math.abs(r.rating - globalPred)).reduce(_+_) / test.count.toDouble
+
+  // Save answers as JSON
+  def printToFile(content: String,
+                  location: String = "./answers.json") =
+    Some(new java.io.PrintWriter(location)).foreach{
+      f => try{
+        f.write(content)
+      } finally{ f.close }
+    }
+  conf.json.toOption match {
+    case None => ;
+    case Some(jsonFile) => {
+      var json = "";
+      {
+        // Limiting the scope of implicit formats with {}
+        implicit val formats = org.json4s.DefaultFormats
+        val answers: Map[String, Any] = Map(
+          "3.1.4" -> Map(
+            "global-mae" -> globalMae
+          )
+        )
+        json = Serialization.writePretty(answers)
+      }
+
+      println(json)
+      println("Saving answers in: " + jsonFile)
+      printToFile(json, jsonFile)
+    }
+  }
+
+  println("")
+  spark.close()
 }