Spark SQL 2. 用户自定义聚合函数(UDAF)

2021-05-20  本文已影响0人  caster

Spark 在 org.apache.spark.sql.functions 对象中为 DataFrame 定义了常用的聚合函数,如 count()、countDistinct()、avg()、max()、min() 等。这些函数是为 DataFrame 设计的;此外,Spark SQL 还在 Java 和 Scala API 中为其中一部分函数提供了类型安全的版本,供强类型的 Dataset 使用。除了这些内置函数,用户还可以创建自己的聚合函数。

1. 适用于无类型(untyped)DataFrame 的 UDAF(UserDefinedAggregateFunction)

自 Spark 3.0 起,UserDefinedAggregateFunction 已被标记为弃用(deprecated),官方建议改用强类型的 Aggregator(见第 2 节)。

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Untyped UDAF (Spark 2.x API) that computes the arithmetic mean of an integer column.
 *
 * Null inputs are skipped. When no non-null value has been aggregated the result is
 * SQL NULL, matching the built-in `avg`, instead of `0.0 / 0 = NaN`.
 *
 * NOTE: `UserDefinedAggregateFunction` is deprecated since Spark 3.0; prefer an
 * `Aggregator` registered through `functions.udaf`.
 */
class AggTest1 extends UserDefinedAggregateFunction {

  // Input: a single integer column.
  override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)

  // Intermediate buffer: running sum and count of non-null inputs.
  // Both are Long so that large groups cannot overflow a 32-bit Int.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Type of the final result.
  override def dataType: DataType = DoubleType

  // Deterministic: the same input always produces the same output.
  override def deterministic: Boolean = true

  // Reset the buffer. Long literals are required because the buffer fields are LongType.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer; null inputs are ignored.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1L
    }

  // Merge a partial buffer (e.g. from another partition) into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result: sum / count, or null when nothing was aggregated.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}

测试方法如下:

def main(args: Array[String]): Unit = {
  // Spark entry points this demo needs.
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().appName("test").master("local").getOrCreate()
  // Always release the local session, even if a step below throws.
  try {
    val lines = spark.sparkContext.parallelize(Array("i love china", "i love usa"))
    // Split each line into words: RDD[String]
    val words = lines.flatMap(_.split(" "))
    // Pair every word with a count of 1: RDD[(String, Int)]
    val wordPair = words.map((_, 1))
    // Sum the counts of identical words: RDD[(String, Int)]
    val wordCount = wordPair.reduceByKey(_ + _)
    // Convert each (word, count) pair into a Row so an explicit schema can be applied.
    val rowRdd: RDD[Row] = wordCount.map(wc => Row(wc._1, wc._2))
    val schema = StructType(Array(
      StructField("word", StringType, nullable = true),
      StructField("count", IntegerType, nullable = true)))
    // RDD[Row] + schema -> DataFrame
    val frame = spark.createDataFrame(rowRdd, schema)
    // Expose the DataFrame to SQL as the temporary view "res".
    frame.createOrReplaceTempView("res")
    // Register the UDAF under a distinct name: calling it "avg" would shadow the built-in
    // avg, and since both return the same value the output could not show which one ran.
    spark.udf.register("myAvg", new AggTest1())
    val results = spark.sql("select myAvg(count) from res")
    results.show()
  } finally {
    spark.stop()
  }
}

2. 适用于强类型 Dataset 的 UDAF(Aggregator)

结构更清晰:可以直接按字段名访问缓冲区,而不必像 UserDefinedAggregateFunction 那样通过 buffer(i) 按下标访问。

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Strongly typed UDAF declared as Aggregator[IN, BUF, OUT] = [Words, Average, Double].
// In early versions IN was the whole row object type (as here, Words). Since Spark 3.0 an
// Aggregator can also be used from SQL; for that, IN is usually the type of the column being
// aggregated. NOTE(review): with IN = Words, a SQL call must supply the columns of Words.
object AggTest2 extends Aggregator[Words, Average, Double] {

  /** Empty buffer: nothing summed, nothing counted. */
  override def zero: Average = Average(0, 0)

  /** Adds one input row to the buffer; updates `b` in place and returns that same instance. */
  override def reduce(b: Average, a: Words): Average = {
    b.sum = b.sum + a.count
    b.count = b.count + 1
    b
  }

  /** Combines two partial buffers by accumulating `b2` into `b1` and returning `b1`. */
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  /** Mean of the aggregated counts; yields NaN when no row was reduced (0.0 / 0). */
  override def finish(reduction: Average): Double = {
    val total: Double = reduction.sum.toDouble
    total / reduction.count
  }

  /** Encoder for the intermediate buffer (a case class, hence a product encoder). */
  override def bufferEncoder: Encoder[Average] = Encoders.product[Average]

  /** Encoder for the final Double result. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/**
 * Mutable aggregation buffer for [[AggTest2]]: running total and number of rows seen.
 *
 * The fields are `var` because the aggregator updates the buffer in place.
 * Both fields are Long so that summing/counting many rows cannot overflow a 32-bit Int;
 * Int arguments (e.g. `Average(0, 0)`) still widen automatically.
 */
case class Average(var sum: Long, var count: Long)
/** Input row type of the Dataset consumed by [[AggTest2]]: a word and its occurrence count. */
case class Words(word: String, count: Int)

测试方法如下:

def main(args: Array[String]): Unit = {
  // Spark names this demo needs (section 2 only imports Encoder/Encoders/Aggregator).
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.{Row, SparkSession, functions}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

  val spark = SparkSession.builder().appName("test").master("local").getOrCreate()
  // Always release the local session, even if a step below throws.
  try {
    val lines = spark.sparkContext.parallelize(Array("i love china", "i love usa"))
    // Split each line into words: RDD[String]
    val words = lines.flatMap(_.split(" "))
    // Pair every word with a count of 1: RDD[(String, Int)]
    val wordPair = words.map((_, 1))
    // Sum the counts of identical words: RDD[(String, Int)]
    val wordCount = wordPair.reduceByKey(_ + _)
    // Convert each (word, count) pair into a Row so an explicit schema can be applied.
    val rowRdd: RDD[Row] = wordCount.map(wc => Row(wc._1, wc._2))
    val schema = StructType(Array(
      StructField("word", StringType, nullable = true),
      StructField("count", IntegerType, nullable = true)))
    // Brings the implicit encoders needed by toDF / toDS / as[...] into scope.
    import spark.implicits._
    // RDD[Row] -> DataFrame -> typed Dataset[Words]
    val ds = spark.createDataFrame(rowRdd, schema).as[Words]

    // 1. DSL usage: before Spark 3.0 a typed Aggregator could only be used this way.
    ds.select(AggTest2.toColumn.name("avg")).show()

    // 2. SQL usage (Spark 3.0+): wrap the Aggregator with functions.udaf and register it.
    // AggTest2 is an `object`, so it is passed as-is (`new AggTest2()` does not compile).
    // Its input type is Words, so the SQL call passes one column per Words field.
    // A distinct name avoids shadowing the built-in avg.
    ds.createOrReplaceTempView("res")
    spark.udf.register("myAvg", functions.udaf(AggTest2))
    spark.sql("select myAvg(word, count) from res").show()
  } finally {
    spark.stop()
  }
}

3. 用户自定义函数(UDF)

// Section 3 demo: register a scalar UDF and call it from SQL.
val spark = SparkSession.builder().master("local").appName("test").getOrCreate()
import spark.implicits._

// NOTE(review): presumably JSON records with `name` and `age` fields — confirm against file/1.txt.
val df = spark.read.json("file/1.txt")
df.show()

// Scalar UDF: prefixes its input with "Name:".
val addName: String => String = name => s"Name:$name"
spark.udf.register("addName", addName)

df.createOrReplaceTempView("user")
// Call the registered UDF from SQL next to a plain column.
spark.sql("select addName(name),age from user").show()
spark.stop()
上一篇 下一篇

猜你喜欢

热点阅读