spark SQL 2.用户自定义聚合函数(UDAF)
2021-05-20 本文已影响0人
caster
Spark 在 org.apache.spark.sql.functions 类中定义了 DataFrame 的常见聚合函数,如count(),countDistinct(),avg(),max(),min() 等。虽然这些功能是专门为 DataFrame 设计的,Spark SQL 也在 java 和 scala 中设计了其中一些函数的类型安全版本用于强类型限制的Dataset。此外,用户还可以创建自己的聚合函数。
1. 适用于未定义类型的DataFrame的UDFA
后续版本不再支持(3.0),SQL建议使用强类型即DataSet
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class AggTest1 extends UserDefinedAggregateFunction{
// 聚合函数输入数据结构
override def inputSchema: StructType = StructType(StructField("input", IntegerType)::Nil)
// 缓存区数据结构,用于计算
override def bufferSchema: StructType = StructType(StructField("sum", IntegerType)::StructField("count", IntegerType)::Nil)
// 聚合函数输出值数据结构
override def dataType: DataType = DoubleType
// 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
override def deterministic: Boolean = true
// 初始化缓冲区
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0
buffer(1) = 0
//buffer.update(0,0)
//buffer.update(1,0)
}
// 给聚合函数传入一条新数据进行处理
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (input.isNullAt(0)) return
buffer(0) = buffer.getInt(0) + input.getInt(0)
buffer(1) = buffer.getInt(1) + 1
//buffer.update(0,buffer.getInt(0) + input.getInt(0))
//buffer.update(1,buffer.getInt(1) + 1)
}
// 合并聚合函数缓冲区(分布式)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
//buffer1.update(0,buffer1.getInt(0) + buffer2.getInt(0))
//buffer1.update(1,buffer1.getInt(1) + buffer2.getInt(1))
}
// 计算最终结果
override def evaluate(buffer: Row): Any = {
buffer.getInt(0).toDouble / buffer.getInt(1)
}
}
测试方法如下:
def main(args: Array[String]): Unit = {
val sc = SparkSession.builder().appName("test").master("local").getOrCreate()
val lines = sc.sparkContext.parallelize(Array("i love china","i love usa"))
//进行分词操作
//:RDD[String]
val words = lines.flatMap(_.split(" "))
//Map操作:每个单词记一次数
//:RDD[(String,Int)]
val wordPair = words.map((_,1))
//key相同单词计数合并
//:RDD[(String,Int)]
val wordCount = wordPair.reduceByKey(_+_)
//RDD转为Row类型
val rowRdd:RDD[Row] = wordCount.map(wc => Row(wc._1, wc._2))
val schema = StructType(Array(StructField("word", StringType, nullable = true),StructField("count", IntegerType, nullable = true)))
//RDD[Row]转为DataFrame
val frame = sc.createDataFrame(rowRdd, schema)
//注册视图用于sql查询
frame.createOrReplaceTempView("res")
//注册UDFA
sc.udf.register("avg", new AggTest1())
val results = sc.sql("select avg(count) from res")
results.show()
sc.stop()
}
2. 适用于强类型限制的Dataset的UDFA
条理清晰,字段直观调用,不需要buffer(i)方式
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
//定义in buffer out类型
//早起版本in设置为列对象类型,新版本支持SQL查询,in设置为需要聚合字段类型
object AggTest2 extends Aggregator[Words,Average,Double]{
//初始化缓冲区buffer
override def zero: Average = Average(0,0)
//处理一条新的记录,聚合后更新缓冲区
override def reduce(b: Average, a: Words): Average = {
b.sum += a.count
b.count += 1
b
}
//合并多个buffer
override def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
//计算平均值
override def finish(reduction: Average): Double = {
reduction.sum.toDouble / reduction.count
}
//缓冲区编码器
override def bufferEncoder: Encoder[Average] = Encoders.product
//输出编码器
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
//缓冲器数据类型,参数值需要可改
case class Average(var sum: Int, var count: Int) {
}
//输入DataSet的数据类型
case class Words(word: String, count: Int) {
}
测试方法如下:
def main(args: Array[String]): Unit = {
val sc = SparkSession.builder().appName("test").master("local").getOrCreate()
val lines = sc.sparkContext.parallelize(Array("i love china","i love usa"))
//进行分词操作
//:RDD[String]
val words = lines.flatMap(_.split(" "))
//Map操作:每个单词记一次数
//:RDD[(String,Int)]
val wordPair = words.map((_,1))
//key相同单词计数合并
//:RDD[(String,Int)]
val wordCount = wordPair.reduceByKey(_+_)
//RDD转为Row类型
val rowRdd:RDD[Row] = wordCount.map(wc => Row(wc._1, wc._2))
val schema = StructType(Array(StructField("word", StringType, nullable = true),StructField("count", IntegerType, nullable = true)))
//用于隐式转化toDF,toDS,as[]等
import sc.implicits._
//RDD[Row]转为DataFrame
val ds = sc.createDataFrame(rowRdd, schema).as[Words]
//1. 旧版本用法,早起强类型UDAF只能用DSL
ds.select(AggTest2.toColumn.name("avg")).show()
//2. 新版本用法,可以将强类型UDAF使用于SQL
ds.createOrReplaceTempView("res")
sc.udf.register("avg", functions.udaf(new AggTest2()))
sc.sql("select avg(count) from res")
sc.stop()
}
3. 用户自定义函数(UDF)
val spark = SparkSession.builder.master("local").appName("test").getOrCreate()
import spark.implicits._
val df = spark.read.json("file/1.txt")
df.show
//注册简单的UDF
spark.udf.register("addName",(name:String)=>{
"Name:"+name
})
df.createOrReplaceTempView("user")
//使用注册的udf
spark.sql("select addName(name),age from user").show()
spark.stop()