Commit 87b29750 authored by Erick Lavoie

Added file reading and prediction

parent 667ce209
@@ -46,7 +46,7 @@ Do include your own ratings in your final submission so we can check your answer
 ## Compute predictions
 ````
-> sbt 'runMain predict.Predictor'
+> sbt "runMain predict.Predictor --train data/ml-100k/u1.base --test data/ml-100k/u1.test --json answers.json"
 ````
 ## Compute recommendations
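For reference, the `--train`/`--test` files are MovieLens 100k splits with one tab-separated `user item rating timestamp` record per line; the parser in `Predictor.scala` below keeps only the first three fields. Illustrative rows in that shape:

````
196	242	3	881250949
186	302	3	891717742
````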
build.sbt
-libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.0" % Test
-scalaVersion in ThisBuild := "2.13.3"
-enablePlugins(JavaAppPackaging)
+name := "m1_yourid"
+version := "1.0"
+maintainer := "your.name@epfl.ch"
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.0" % Test
+libraryDependencies += "org.rogach" %% "scallop" % "4.0.2"
+libraryDependencies += "org.json4s" %% "json4s-jackson" % "3.6.10"
+libraryDependencies += "org.apache.spark" %% "spark-core" % "3.0.0"
+libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.0"
+scalaVersion in ThisBuild := "2.12.13"
+enablePlugins(JavaAppPackaging)
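The Scala downgrade from 2.13.3 to 2.12.13 goes with the new dependencies: Spark 3.0.0 artifacts are published for Scala 2.12 only.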
src/main/scala/predict/Predictor.scala
package predict
import org.rogach.scallop._
import org.json4s.jackson.Serialization
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.log4j.Logger
import org.apache.log4j.Level
class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
  val train = opt[String](required = true)
  val test = opt[String](required = true)
  val json = opt[String]()
  verify()
}
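// Example (assumed invocation, mirroring the README command):
//   val conf = new Conf(Seq("--train", "data/ml-100k/u1.base", "--test", "data/ml-100k/u1.test"))
//   conf.train()       // returns "data/ml-100k/u1.base"
//   conf.json.toOption // returns None when --json is omitted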
case class Rating(user: Int, item: Int, rating: Double)
object Predictor extends App {
  println("Computing predictions ...")

  // Silence Spark and Akka logging; remove these two lines when debugging Spark
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  // Run Spark locally on a single thread
  val spark = SparkSession.builder()
    .master("local[1]")
    .getOrCreate()
  spark.sparkContext.setLogLevel("ERROR")

  println("")
  println("******************************************************")
  val conf = new Conf(args)
  println("Loading training data from: " + conf.train())
  val trainFile = spark.sparkContext.textFile(conf.train())
  // Each line is tab-separated; the first three fields are user, item, rating
  val train = trainFile.map(l => {
    val cols = l.split("\t").map(_.trim)
    Rating(cols(0).toInt, cols(1).toInt, cols(2).toDouble)
  })
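  // e.g. the line "196\t242\t3\t881250949" parses to Rating(196, 242, 3.0)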
  assert(train.count == 80000, "Invalid training data")

  println("Loading test data from: " + conf.test())
  val testFile = spark.sparkContext.textFile(conf.test())
  val test = testFile.map(l => {
    val cols = l.split("\t").map(_.trim)
    Rating(cols(0).toInt, cols(1).toInt, cols(2).toDouble)
  })
  assert(test.count == 20000, "Invalid test data")
  // Baseline: predict the constant 3.0 for every test rating and compute its MAE
  val globalPred = 3.0
  val globalMae = test.map(r => scala.math.abs(r.rating - globalPred)).reduce(_ + _) / test.count.toDouble
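  // i.e. MAE = (1 / |Test|) * Σ_{(u,i) ∈ Test} |r_ui - globalPred|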
  // Save answers as JSON
  def printToFile(content: String,
                  location: String = "./answers.json") =
    Some(new java.io.PrintWriter(location)).foreach { f =>
      try {
        f.write(content)
      } finally {
        f.close()
      }
    }
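  // Note: java.io.PrintWriter truncates an existing file, so the JSON output is overwritten on each run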
  conf.json.toOption match {
    case None => ()
    case Some(jsonFile) => {
      var json = "";
      {
        // Limiting the scope of implicit formats with {}
        implicit val formats = org.json4s.DefaultFormats
        val answers: Map[String, Any] = Map(
          "3.1.4" -> Map(
            "global-mae" -> globalMae
          )
        )
        json = Serialization.writePretty(answers)
      }
      println(json)
      println("Saving answers in: " + jsonFile)
      printToFile(json, jsonFile)
    }
  }
println("")
spark.close()
}
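For reference, a sketch of the resulting `answers.json` when the README invocation above is used; `Serialization.writePretty` pretty-prints the answers map, and the MAE value is left here as a placeholder rather than a computed result:

````
{
  "3.1.4" : {
    "global-mae" : <MAE of the constant 3.0 predictor on u1.test>
  }
}
````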