Scala RDD实现mapreduce矩阵相乘

2020-04-02  本文已影响0人  锦绣拾年

小矩阵相乘,《智能与并行程序设计》一课作业
Spark scala RDD
主要思路 键值:应该是目标矩阵的位置。
矩阵相乘,是先乘 后加
所以对于矩阵1 映射(目标矩阵坐标,列),值
对于矩阵2映射 (目标矩阵坐标,行),值
然后reduce,同样的相乘(因为目标矩阵坐标相同,矩阵1列=矩阵2行)

这时就需要再相加了
去掉键值中的第三个元素 行or列
然后reduce 相加

ex.jpg

import org.apache.spark.SparkConf
import org.apache.spark.rdd._
import org.apache.spark.SparkContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.SparkSession
import scala.util.parsing.json.JSON
import org.apache.spark.sql.{ DataFrame, Dataset, SparkSession, Row }
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.functions._
import java.util.Calendar
import org.apache.spark.SparkContext
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.functions
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.mllib.recommendation.{ALS,Rating,MatrixFactorizationModel}
import org.apache.log4j.Logger
import org.apache.log4j.Level
import java.io.File
object MatrixM {
    def main(args: Array[String]) {
//      val conf = new SparkConf()
          val conf = new SparkConf().setMaster("local").setAppName("MatrixM");//本地调试
      val sc = new SparkContext(conf)
//      val mt1 = sc.textFile(args(0));
//      val mt2 = sc.textFile(args(1))
      val mt1 = sc.textFile("./mt1.txt");
      val mt2 = sc.textFile("./mt2.txt");
      val index=3
      val col=2
      val mt1nums = mt1.map(_.split(" ").take(3).map(_.toInt))//1 1 0->1 1行,  a肯定 1列 0值
      val mt2nums = mt2.map(_.split(" ").take(3).map(_.toInt))
//      val mt1nums = mt1values.map{case Array(user,movie,rating)=>Rating(user.toInt,movie.toInt,rating.toDouble)}
      mt1nums.collect().foreach(println)
      mt1nums.collect()(0).foreach(println)
      var pairs=mt1nums.map(x=>((x(0),1,x(1)),x(2)))
      for(i<-2 to col){
        pairs=pairs.union(mt1nums.map(x=>((x(0),i,x(1)),x(2))))

      }

      for(i<-1 to index){
          pairs=pairs.union(mt2nums.map(x=>((i,x(1),x(0)),x(2))))
      }
      
      pairs=pairs.reduceByKey((x,y)=>x*y)
      pairs=pairs.sortByKey()
      pairs.collect().toArray.foreach(println)
      var newpairs=pairs.map(x=>((x._1._1,x._1._2),x._2))
      newpairs=newpairs.sortByKey()
      newpairs.collect().toArray.foreach(println)
      newpairs=newpairs.reduceByKey((x,y)=>x+y)

      newpairs=newpairs.sortByKey()

      newpairs.collect().foreach(println)

      newpairs.saveAsTextFile("output")
       sc.stop() 
  
    }
}
上一篇下一篇

猜你喜欢

热点阅读