Collaborative filtering with Spark's built-in ALS algorithm
2018-10-26
匪_3f3e
Environment: Spark 1.6.0, Scala 2.11.4
The dataset used is the TPC-H benchmark dataset.
Step 1: read the data files and register the resulting DataFrames as temp tables (if the data is already stored in Hive, you can query it directly through a HiveContext).
Step 2: use the SQLContext to query (customer, purchased part) pairs; since there is no real rating information, every pair gets a default rating of 10.
Step 3: split the data into training and test sets and train the model.
Step 4: use the trained model to recommend parts for the test set.
MyColleborativeFilter.scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SQLContext
object MyColleborativeFilter {
case class Customer(id: Int, name: String, address: String, nation: String, phone: String, mktsegment:String,comment:String)
case class Order(id:Int,customer:String,status:String,totalPrice:Double,date:String,priority:String,clerk:String,shipPriority:Double,comment:String)
case class LineItem(orders:Int,part:Int)
def main(args: Array[String]): Unit = {
val path=args(0)
//build the SparkConf (on Spark 2.x you would create a SparkSession instead)
val conf = new SparkConf()//.setAppName("MyRs").setMaster("local")
//val sparkSession = new SparkSession(conf)
// sparkSession.sparkContext.setLogLevel("WARN")
//create the SparkContext and SQLContext
val sparkContext =new SparkContext(conf)
val sqlContext = new SQLContext(sparkContext)
import sqlContext.implicits._
//read the three TPC-H tables
val customerDf = sparkContext.textFile(path+"/customer/customer.tbl")
.map(_.split("\\|"))
.map(u => Customer(u(0).toInt, u(1), u(2), u(3), u(4), u(5), u(6))).toDF()
customerDf.show()
customerDf.registerTempTable("customer")
val orderDf=sparkContext.textFile(path+"/orders/orders.tbl")
.map(_.split("\\|"))
.map(u=>Order(u(0).toInt, u(1), u(2), u(3).toDouble, u(4), u(5), u(6),u(7).toDouble,u(8)))
.toDF()
orderDf.registerTempTable("orders")
orderDf.show()
val itemlineDf=sparkContext.textFile(path+"/lineitem/lineitem.tbl")
.map(_.split("\\|"))
.map(u=>LineItem(u(0).toInt,u(1).toInt))
.toDF()
itemlineDf.registerTempTable("itemline")
//query (customer, purchased part) pairs with Spark SQL
val customerPartDf=sqlContext
.sql("SELECT c.id customer,i.part part FROM customer c,orders o,itemline i " +
"WHERE c.id=o.customer and o.id=i.orders")
//add a rating column; every pair gets the default rating of 10
val resultDf=customerPartDf.withColumn("rating",customerPartDf("customer")*0+10.0)
resultDf.show()
//split into training and test sets
val Array(training,test) = resultDf.randomSplit(Array(0.8,0.2))
//train the ALS model
val als = new ALS()
.setMaxIter(1)
.setUserCol("customer")
.setItemCol("part")
.setRatingCol("rating")
.setRegParam(0.01)//regularization parameter
val model = als.fit(training)
//model.setColdStartStrategy("drop")
//model.write.overwrite().save("./ColleborativeFilterModle")
//generate recommendations (predictions) for the test set
val predictions = model.transform(test)
predictions.show(false)
//on Spark 2.3.0+ you can generate recommendations with the following calls
//model.recommendForUserSubset(user1,10).show(false)
//model.recommendForAllUsers(10)
sparkContext.stop()
}
}
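Beyond printing the predictions, you can get a rough sanity check of the model by computing RMSE on the test set with spark.ml's RegressionEvaluator. The sketch below is a minimal example, assuming the predictions DataFrame from the code above; it filters out the NaN predictions ALS produces for customers or parts that only appear in the test split (on Spark 2.2+ this is what coldStartStrategy = "drop" does). Since every rating here is the constant 10, the resulting number is only a sanity check, not a meaningful accuracy metric.

import org.apache.spark.ml.evaluation.RegressionEvaluator

//evaluation sketch (assumes the `predictions` DataFrame from the program above):
//drop NaN predictions caused by cold-start users/items, then compute RMSE
//between the rating column and the predicted rating
val nonNaN = predictions.filter(!predictions("prediction").isNaN)
val evaluator = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("rating")
  .setPredictionCol("prediction")
val rmse = evaluator.evaluate(nonNaN)
println(s"Test-set RMSE = $rmse")

On Spark 2.3.0+ you could instead call the commented-out recommendForAllUsers / recommendForUserSubset methods on the trained model to get top-N part recommendations per customer directly.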
It's best to wrap the launch commands in a script, which makes them much easier to run. The --master setting can be chosen as needed: local mode works for testing, while submitting to a cluster requires standalone or yarn-client mode. /program/lxf/sql-1.0.jar is the location of your Scala program's jar, and the argument after it is the parameter my program expects.
start-yarn.sh
spark-submit \
--master yarn-client \
--class com.example.spark.MyColleborativeFilter \
/program/lxf/sql-1.0.jar \
hdfs://10.77.20.23:8020/tpch