Collaborative filtering with Spark's built-in ALS algorithm

2018-10-26

Environment: Spark 1.6.0, Scala 2.11.4
The dataset used is the TPC-H benchmark dataset.

Step 1: read the raw files and register the resulting DataFrames as temp tables (if the data already lives in Hive, you can query it directly with a HiveContext);
Step 2: use sqlContext to query (customer, purchased part) pairs; since there is no explicit rating information, every rating defaults to 10;
Step 3: split the data into training and test sets and train the model;
Step 4: use the trained model to recommend parts for the test set.
MyColleborativeFilter.scala

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.lit

object MyColleborativeFilter {
  case class Customer(id: Int, name: String, address: String, nation: String, phone: String, mktsegment:String,comment:String)
  case class Order(id:Int,customer:String,status:String,totalPrice:Double,date:String,priority:String,clerk:String,shipPriority:Double,comment:String)
  case class LineItem(orders:Int,part:Int)


  def main(args: Array[String]): Unit = {

    //base path of the TPC-H .tbl files, passed in as the first program argument
    val path = args(0)

    //create the SparkConf (set the app name/master here for local testing, or via spark-submit)
    val conf = new SparkConf()//.setAppName("MyRs").setMaster("local")
    //create the SparkContext and SQLContext
    val sparkContext = new SparkContext(conf)
    val sqlContext = new SQLContext(sparkContext)


    import sqlContext.implicits._

    //read the three TPC-H tables (customer, orders, lineitem) and register them as temp tables
    val customerDf = sparkContext.textFile(path+"/customer/customer.tbl")
      .map(_.split("\\|"))
      .map(u => Customer(u(0).toInt, u(1), u(2), u(3), u(4), u(5), u(6))).toDF()
    customerDf.show()
    customerDf.registerTempTable("customer")

    val orderDf=sparkContext.textFile(path+"/orders/orders.tbl")
      .map(_.split("\\|"))
      .map(u=>Order(u(0).toInt, u(1), u(2), u(3).toDouble, u(4), u(5), u(6),u(7).toDouble,u(8)))
      .toDF()
    orderDf.registerTempTable("orders")
    orderDf.show()

    val itemlineDf=sparkContext.textFile(path+"/lineitem/lineitem.tbl")
        .map(_.split("\\|"))
        .map(u=>LineItem(u(0).toInt,u(1).toInt))
        .toDF()
    itemlineDf.registerTempTable("itemline")
    //join the three tables with Spark SQL to get (customer, part) pairs
    val customerPartDf=sqlContext
      .sql("SELECT c.id customer,i.part part FROM customer c,orders o,itemline i " +
        "WHERE c.id=o.customer and o.id=i.orders")
    //add a rating column; there is no explicit feedback, so every rating defaults to 10
    val resultDf = customerPartDf.withColumn("rating", lit(10.0))
    resultDf.show()

    //split into training and test sets (80/20)
    val Array(training, test) = resultDf.randomSplit(Array(0.8, 0.2))
    //train the ALS model
    val als = new ALS()
      .setMaxIter(1)
      .setUserCol("customer")
      .setItemCol("part")
      .setRatingCol("rating")
      .setRegParam(0.01)//regularization parameter
    val model = als.fit(training)
    //model.setColdStartStrategy("drop")
    //model.write.overwrite().save("./ColleborativeFilterModle")

    //produce predictions for the test set with the trained model
    val predictions = model.transform(test)
    predictions.show(false)
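    //Optional sanity check (a sketch, not part of the original post): measure RMSE on the test predictions.
    //In Spark 1.6, users/items unseen during training get NaN predictions, so those rows are dropped first.
    //With every rating fixed at 10.0 the number itself is not very meaningful; it mainly verifies the pipeline runs.
    val evaluator = new org.apache.spark.ml.evaluation.RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions.na.drop(Seq("prediction")))
    println(s"Test RMSE = $rmse")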
    //from Spark 2.3.0 onwards, the following APIs can generate top-N recommendations directly
    //model.recommendForUserSubset(user1,10).show(false)
    //model.recommendForAllUsers(10)
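    //A possible sketch (Spark 2.3.0+ only, not part of the original code) of how the `user1`
    //DataFrame above could be built: take a few distinct customer ids and request their top-10 parts.
    //val user1 = resultDf.select("customer").distinct().limit(3)
    //model.recommendForUserSubset(user1, 10).show(false)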

    sparkContext.stop()

  }

}

It is recommended to wrap the launch command in a script, which makes it much easier to run. The --master option can be set as needed: local mode is fine for testing, while submitting to a cluster requires standalone or yarn-client mode. /program/lxf/sql-1.0.jar is the path to the jar built from the Scala program, and the final argument is the parameter my program expects (the TPC-H data path).
start-yarn.sh

spark-submit \
  --master yarn-client \
  --class com.example.spark.MyColleborativeFilter \
  /program/lxf/sql-1.0.jar \
  hdfs://10.77.20.23:8020/tpch
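For quick local-mode testing (mentioned above), only the --master value needs to change; the jar path and data-path argument below are the same ones used in this post and should be adjusted to your own environment:

spark-submit \
  --master local[*] \
  --class com.example.spark.MyColleborativeFilter \
  /program/lxf/sql-1.0.jar \
  hdfs://10.77.20.23:8020/tpch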