Skip to content
Snippets Groups Projects
Commit d1b6cf9e authored by Parshikov Tikhon's avatar Parshikov Tikhon
Browse files

Optimize argtopk

parent 90c9c71c
No related branches found
No related tags found
No related merge requests found
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
"2.k10u1v864": 0.24232304952129619, "2.k10u1v864": 0.24232304952129619,
"3.k10u1v886": 0, "3.k10u1v886": 0,
"4.PredUser1Item1": 4.319093503763853, "4.PredUser1Item1": 4.319093503763853,
"5.PredUser327Item2": 2.6994178006921192, "5.PredUser327Item2": 2.9542350642575563,
"6.Mae": 0.8287277961963542 "6.Mae": 0.8311356402216432
}, },
"BR.2": { "BR.2": {
"average (ms)": 2748.0986, "average (ms)": 6474.2568,
"stddev (ms)": 0 "stddev (ms)": 0
} }
} }
\ No newline at end of file
...@@ -176,17 +176,13 @@ package object predictions ...@@ -176,17 +176,13 @@ package object predictions
//take k nearest users by similarity excluding itself //take k nearest users by similarity excluding itself
def knn(user : Int, k: Int, similarities : DenseMatrix[Double]) : Array[Int] = { def knn(user : Int, k: Int, similarities : DenseMatrix[Double]) : Array[Int] = {
val row = similarities(user, ::)
val row_index = row.t.toArray.zipWithIndex
val row_sorted = row_index.sortBy(- _._1).take(k+1)
//first element is itself so take the tail //first element is itself so take the tail
return row_sorted.map(_._2).tail return argtopk(similarities(::,user),k+1).toArray.tail
} }
def predict_knn_similarity(similarities : DenseMatrix[Double], k : Int): DenseMatrix[Double] = { def predict_knn_similarity(similarities : DenseMatrix[Double], k : Int): DenseMatrix[Double] = {
for (x <- 0 to similarities.rows - 1){ for (x <- 0 to similarities.rows - 1){
val index_user = knn(x, k, similarities) val index_user = knn(x, k, similarities)
for(y <- 0 to similarities.cols - 1) { for(y <- 0 to similarities.cols - 1) {
// if y is not an index of the largest similarity coefficients it is changed to 0 // if y is not an index of the largest similarity coefficients it is changed to 0
if (!index_user.contains(y)) similarities(x,y) = 0.0 if (!index_user.contains(y)) similarities(x,y) = 0.0
...@@ -232,8 +228,7 @@ package object predictions ...@@ -232,8 +228,7 @@ package object predictions
iterator.map(user=>{ iterator.map(user=>{
val ratings = broadcast.value val ratings = broadcast.value
val similarities = ratings * ratings.t(::,user) val similarities = ratings * ratings.t(::,user)
val row_index = similarities.toArray.zipWithIndex val res=argtopk(similarities,k+1).toArray.tail
val res = row_index.sortBy(- _._1).take(k+1).map(_._2).tail
val res_form = res.map(user_2 => (user,user_2,similarities(user_2))) val res_form = res.map(user_2 => (user,user_2,similarities(user_2)))
res_form})}).collect() res_form})}).collect()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment