数据挖掘

fp_growth频繁项集和关联规则Spark ML调用实现

2020-11-22  本文已影响0人  xiaogp

摘要:关联规则置信度支持度提升度规则集数据挖掘Spark

关联规则

关联规则是基于统计的无监督学习方法,它基于序列挖掘频繁出现因素组合的模式,进而可以推断出如果出现了A,B,还可能出现C的规则,可以使用的场景包括二分类中需要找到规则集,在推荐中做关联推荐等。
关联规则的研究对象是事件序列,目的是找到频繁事件组合(项集),用支持度来衡量出现的频数强度,一个频繁项集内部也分为前项后项,为了描述前项的出现推断后项的能力强弱引出置信度,即等于在前项出现的情况下后项出现的比例,再引出提升度,即因为前项的出现导致后项比随机出现概率提升的倍数。

Spark ML代码实现

算法接受DataFrame输入,指定输入序列字段,支持度,置信度,freqItemsets输出频繁项集,associationRules输出关联规则,序列字段由groupBy+collect_list构造得到,相当于将一个对象的所有元素聚合成一个序列,transform可以对新dataframe做预测,推荐频繁项集内未出现的元素

import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.ml.fpm.FPGrowthModel

object FpGrowthExample {
  val spark: SparkSession = SparkSession.builder().appName("FpGrowthExample").master("yarn").getOrCreate()
  import spark.implicits._

  def main(args: Array[String]): Unit = {
    val df = spark.read.format("csv").option("header", true).load("/user/test/data.txt")
    // 过滤热门词
    val df2 = df.filter(!$"label_value".isin("其他", "有登记联系方式")).distinct()
    df2.cache()
    val ts = df2.groupBy("ent_name").agg(collect_list("label_value").alias("label_value_list"))

    // 定义模型阈值
    val model = new FPGrowth()
      .setItemsCol("label_value_list")
      .setMinConfidence(0.5)
      .setMinSupport(0.1)
      .fit(ts)

    model.write.overwrite().save("/user/test/SparkMLModel/fpgrowth")

    // 载入模型
    val model2 = FPGrowthModel.load("/user/test/SparkMLModel/fpgrowth")

    // 查看频繁项集
    val freq = model2.freqItemsets
    // 只查看组合多于1个的项集
    val myfunc1 = udf((x: Any) => {
      val tmp = x.asInstanceOf[scala.collection.mutable.WrappedArray[String]]
      tmp.size > 1
    })
    val freq2 = freq.filter(myfunc1($"items"))

    // 查看置信度(关联规则)
    val conf = model2.associationRules

    // 输出格式
    val myfunc2 = udf((x: Any) => x.asInstanceOf[scala.collection.mutable.WrappedArray[String]](0))
    val myfunc3 = udf((x: Any) => x.asInstanceOf[scala.collection.mutable.WrappedArray[String]].mkString("+"))
    val conf2 = conf.withColumn("antecedent", myfunc3($"antecedent")).withColumn("consequent", myfunc2($"consequent"))
    // 加入提升度
    val df3 = df2.groupBy("label_value").count()
    val count = df2.select($"ent_name").distinct().count()
    val df4 = df3.withColumn("base", $"count" / count)
    val conf3 = conf2.join(df4, $"consequent" === $"label_value", "left")
    val conf4 = conf3.withColumn("lift", $"confidence" / $"base").select($"antecedent", $"consequent", $"confidence", $"lift")
      .sort($"confidence".desc)
    val freq3 = freq2.withColumn("items", myfunc3($"items"))

    // 输出
    conf4.repartition(1).write.format("csv").mode("overwrite").save("/user/test/fpgrowth/conf")
    freq3.repartition(1).write.format("csv").mode("overwrite").save("/user/test/fpgrowth/freq")

    conf4.show(10, false)
    spark.stop()
  }
}

频繁项集:显示序列和频数


频繁项集.png

关联规则:其中antecedent代表前项,consequent代表后项,confidencelift分别是置信度和提升度

关联规则.png
上一篇下一篇

猜你喜欢

热点阅读