diff --git a/README.md b/README.md
index 4f29493f127a3471bfe627e4fd1e89dc166738ca..e1965f4c5c821f20158ff36e1f15eff5a8a1a8b1 100644
--- a/README.md
+++ b/README.md
@@ -46,12 +46,18 @@ Do include your own ratings in your final submission so we can check your answer
 ## Compute predictions
 
 ````
-> sbt "runMain predict.Predictor --train data/ml-100k/u1.base --test data/ml-100k/u1.test --json answers.json"
+> sbt "runMain stats.Analyzer --data data/ml-100k/u.data --json statistics.json"
+````
+
+## Compute predictions
+
+````
+> sbt "runMain predict.Predictor --train data/ml-100k/u1.base --test data/ml-100k/u1.test --json predictions.json"
 ````
 
 ## Compute recommendations
 
 ````
-> sbt 'runMain recommend.Recommender'
+> sbt "runMain recommend.Recommender --data data/ml-100k/u.data --personal data/personal.csv --json recommendations.json"
 ````
 ## Package for submission
diff --git a/data/personal.csv b/data/personal.csv
index 2beace652b5bec5999548c8b4b9977d9280e2473..091077d557e4d0caa2c1089ea8a0b009c872d15d 100644
--- a/data/personal.csv
+++ b/data/personal.csv
@@ -1,4 +1,3 @@
-id,title,
 1,Toy Story (1995),
 2,GoldenEye (1995),
 3,Four Rooms (1995),
diff --git a/src/main/scala/recommend/Recommender.scala b/src/main/scala/recommend/Recommender.scala
index 13f5f6ed34bf7ef0ed45476fba010915894e12a6..27461544d75e54b8c4cf63ba2955222380af995f 100644
--- a/src/main/scala/recommend/Recommender.scala
+++ b/src/main/scala/recommend/Recommender.scala
@@ -1,6 +1,83 @@
 package recommend
 
+import org.rogach.scallop._
+import org.json4s.jackson.Serialization
+import org.apache.spark.rdd.RDD
+
+import org.apache.spark.sql.SparkSession
+import org.apache.log4j.Logger
+import org.apache.log4j.Level
+
+class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
+  val data = opt[String](required = true)
+  val personal = opt[String](required = true)
+  val json = opt[String]()
+  verify()
+}
+
+case class Rating(user: Int, item: Int, rating: Double)
+
 object Recommender extends App {
-  println("Computing recommendations ...")
-  println("Done")
+  // Remove these lines if encountering/debugging Spark
+  Logger.getLogger("org").setLevel(Level.OFF)
+  Logger.getLogger("akka").setLevel(Level.OFF)
+  val spark = SparkSession.builder()
+    .master("local[1]")
+    .getOrCreate()
+  spark.sparkContext.setLogLevel("ERROR")
+
+  println("")
+  println("******************************************************")
+
+  var conf = new Conf(args)
+  println("Loading data from: " + conf.data())
+  val dataFile = spark.sparkContext.textFile(conf.data())
+  val data = dataFile.map(l => {
+    val cols = l.split("\t").map(_.trim)
+    Rating(cols(0).toInt, cols(1).toInt, cols(2).toDouble)
+  })
+  assert(data.count == 100000, "Invalid data")
+
+  println("Loading personal data from: " + conf.personal())
+  val personalFile = spark.sparkContext.textFile(conf.personal())
+  // TODO: Extract ratings and movie titles
+  assert(personalFile.count == 1682, "Invalid personal data")
+
+
+
+  // Save answers as JSON
+  def printToFile(content: String,
+                  location: String = "./answers.json") =
+    Some(new java.io.PrintWriter(location)).foreach{
+      f => try{
+        f.write(content)
+      } finally{ f.close }
+    }
+  conf.json.toOption match {
+    case None => ;
+    case Some(jsonFile) => {
+      var json = "";
+      {
+        // Limiting the scope of implicit formats with {}
+        implicit val formats = org.json4s.DefaultFormats
+        val answers: Map[String, Any] = Map(
+          "4.1.1" -> List[Any](
+            List(0,"Tron", 5.0),
+            List(0,"Tron", 5.0),
+            List(0,"Tron", 5.0),
+            List(0,"Tron", 5.0),
+            List(0,"Tron", 5.0)
+          )
+        )
+        json = Serialization.writePretty(answers)
+      }
+
+      println(json)
+      println("Saving answers in: " + jsonFile)
+      printToFile(json, jsonFile)
+    }
+  }
+
+  println("")
+  spark.close()
 }