fp_growth频繁项集和关联规则Spark ML调用实现
2020-11-22 本文已影响0人
xiaogp
摘要:关联规则
,置信度
,支持度
,提升度
,规则集
,数据挖掘
,Spark
关联规则
关联规则是基于统计的无监督学习方法,它基于序列挖掘频繁出现因素组合的模式,进而可以推断出如果出现了A,B,还可能出现C的规则,可以使用的场景包括二分类中需要找到规则集,在推荐中做关联推荐等。
关联规则的研究对象是事件序列,目的是找到频繁事件组合(项集),用支持度
来衡量出现的频数强度,一个频繁项集内部也分为前项
和后项
,为了描述前项的出现推断后项的能力强弱引出置信度
,即等于在前项出现的情况下后项出现的比例,再引出提升度
,即因为前项的出现导致后项比随机出现概率提升的倍数。
Spark ML代码实现
算法接受DataFrame输入,指定输入序列字段,支持度,置信度,freqItemsets
输出频繁项集,associationRules
输出关联规则,序列字段由groupBy
+collect_list
构造得到,相当于将一个对象的所有元素聚合成一个序列,transform
可以对新dataframe做预测,推荐频繁项集内未出现的元素
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.ml.fpm.FPGrowthModel
object FpGrowthExample {
val spark: SparkSession = SparkSession.builder().appName("FpGrowthExample").master("yarn").getOrCreate()
import spark.implicits._
def main(args: Array[String]): Unit = {
val df = spark.read.format("csv").option("header", true).load("/user/test/data.txt")
// 过滤热门词
val df2 = df.filter(!$"label_value".isin("其他", "有登记联系方式")).distinct()
df2.cache()
val ts = df2.groupBy("ent_name").agg(collect_list("label_value").alias("label_value_list"))
// 定义模型阈值
val model = new FPGrowth()
.setItemsCol("label_value_list")
.setMinConfidence(0.5)
.setMinSupport(0.1)
.fit(ts)
model.write.overwrite().save("/user/test/SparkMLModel/fpgrowth")
// 载入模型
val model2 = FPGrowthModel.load("/user/test/SparkMLModel/fpgrowth")
// 查看频繁项集
val freq = model2.freqItemsets
// 只查看组合多于1个的项集
val myfunc1 = udf((x: Any) => {
val tmp = x.asInstanceOf[scala.collection.mutable.WrappedArray[String]]
tmp.size > 1
})
val freq2 = freq.filter(myfunc1($"items"))
// 查看置信度(关联规则)
val conf = model2.associationRules
// 输出格式
val myfunc2 = udf((x: Any) => x.asInstanceOf[scala.collection.mutable.WrappedArray[String]](0))
val myfunc3 = udf((x: Any) => x.asInstanceOf[scala.collection.mutable.WrappedArray[String]].mkString("+"))
val conf2 = conf.withColumn("antecedent", myfunc3($"antecedent")).withColumn("consequent", myfunc2($"consequent"))
// 加入提升度
val df3 = df2.groupBy("label_value").count()
val count = df2.select($"ent_name").distinct().count()
val df4 = df3.withColumn("base", $"count" / count)
val conf3 = conf2.join(df4, $"consequent" === $"label_value", "left")
val conf4 = conf3.withColumn("lift", $"confidence" / $"base").select($"antecedent", $"consequent", $"confidence", $"lift")
.sort($"confidence".desc)
val freq3 = freq2.withColumn("items", myfunc3($"items"))
// 输出
conf4.repartition(1).write.format("csv").mode("overwrite").save("/user/test/fpgrowth/conf")
freq3.repartition(1).write.format("csv").mode("overwrite").save("/user/test/fpgrowth/freq")
conf4.show(10, false)
spark.stop()
}
}
频繁项集:显示序列和频数
频繁项集.png
关联规则:其中antecedent
代表前项,consequent
代表后项,confidence
和lift
分别是置信度和提升度