Spark SQL Internals: How Aggregate Is Implemented
Classification of aggregate functions
- Declarative aggregate functions: aggregate functions that can be built directly from Catalyst expressions. They are the simplest kind; the most common aggregates, such as count, sum, and avg, are declarative.
- Imperative aggregate functions: aggregate functions that must explicitly implement several methods to manipulate the data in the aggregation buffer (AggBuffer). Imperative aggregates are less common; examples include the cardinality estimator hyperLogLogPlus and the pivot transformation pivotFirst.
- Typed imperative aggregate functions: the most flexible kind, which allows a user-defined object to serve as the aggregation buffer. Aggregates involving user-defined types all fall into this category, e.g. collect_list, collect_set, and percentile.
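The contract an imperative aggregate implements can be sketched in plain Python. This is a simplified illustration, not Spark's actual Scala API; the class and method names mirror the initialize/update/merge/eval lifecycle that a typed imperative aggregate such as collect_list follows, with a user-defined object (here a plain list) as the buffer.

```python
# Plain-Python sketch (not Spark's real API) of the imperative
# aggregate lifecycle: initialize / update / merge / eval over a
# user-defined aggregation buffer, as collect_list-style aggregates do.

class CollectListSketch:
    def initialize(self):
        # Create an empty aggregation buffer for a new group.
        return []

    def update(self, buffer, value):
        # Fold one input row into the buffer.
        buffer.append(value)
        return buffer

    def merge(self, buffer1, buffer2):
        # Combine two partial buffers (e.g. from different partitions).
        buffer1.extend(buffer2)
        return buffer1

    def eval(self, buffer):
        # Produce the final result from the finished buffer.
        return buffer

agg = CollectListSketch()
b1 = agg.initialize()
for v in [1, 2]:
    b1 = agg.update(b1, v)
b2 = agg.update(agg.initialize(), 3)
result = agg.eval(agg.merge(b1, b2))
print(result)  # [1, 2, 3]
```

Declarative aggregates express the same four phases as Catalyst expressions instead of method implementations, which is why the engine can codegen them directly.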
Aggregation buffers and aggregation modes
- Partial mode: aggregates local data first, e.g. computing a sum within each partition.
- PartialMerge mode: used when several kinds of aggregate functions are evaluated together, e.g. sum combined with countDistinct. The result of merging the buffers at this stage is still an intermediate result.
- Final mode: aggregates the results held in the aggregation buffers once more, e.g. computing the sum of the per-partition sums.
- Complete mode: there is no intermediate aggregation step; all values of a group participate in a single aggregation pass, e.g. a SortAggregate over functions that do not support Partial mode.
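The Partial and Final modes above can be sketched in plain Python (a simulation of the semantics, not Spark code), assuming the input arrives as two partitions of (group, value) rows:

```python
# Sketch of Partial vs Final mode for sum: Partial mode aggregates
# within each partition; Final mode merges the per-partition buffers.
from collections import defaultdict

partitions = [
    [(1, 1), (1, 2), (2, 1)],  # partition 0
    [(2, 2), (3, 1), (3, 2)],  # partition 1
]

# Partial mode: one aggregation buffer per group, per partition.
partial_results = []
for rows in partitions:
    buf = defaultdict(int)
    for group, value in rows:
        buf[group] += value
    partial_results.append(dict(buf))

# Final mode: merge the partial buffers for each group.
final = defaultdict(int)
for part in partial_results:
    for group, partial_sum in part.items():
        final[group] += partial_sum

print(dict(final))  # {1: 3, 2: 3, 3: 3}
```

Complete mode would skip the first loop entirely and feed every row of a group into a single buffer in one pass.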
In the current implementation this set of modes is not a particularly clean classification; see the comment in AggregationIterator:
/**
* The following combinations of AggregationMode are supported:
* - Partial
* - PartialMerge (for single distinct)
* - Partial and PartialMerge (for single distinct)
* - Final
* - Complete (for SortAggregate with functions that does not support Partial)
* - Final and Complete (currently not used)
*
* TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression
* could have a flag to tell it's final or not.
*/
planAggregateWithoutDistinct: implementing aggregation without distinct
- Step 1 (Partial mode): compute the Partial result of the aggregation
  - groupingExpressions: the group-by column (a)
  - aggregateExpressions: the aggregation (partial_sum(cast(b#11 as bigint)))
- Step 2 (Final mode): compute the Final result of the aggregation
  - groupingExpressions: the group-by column (a)
  - aggregateExpressions: the aggregation (sum(b))
create temporary view data as select * from values
(1, 1),
(1, 2),
(2, 1),
(2, 2),
(3, 1),
(3, 2)
as data(a, b);
explain select sum(b) from data group by a;
== Physical Plan ==
*HashAggregate(keys=[a#10], functions=[sum(cast(b#11 as bigint))])
+- Exchange hashpartitioning(a#10, 200)
   +- *HashAggregate(keys=[a#10], functions=[partial_sum(cast(b#11 as bigint))])
      +- LocalTableScan [a#10, b#11]
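The two-stage plan above can be simulated in plain Python (a sketch of the semantics, not Spark code) using the view's data: a partial_sum HashAggregate per input partition, an Exchange that hash-partitions the partial results by the grouping key a, and a final HashAggregate on each reducer:

```python
# Simulation of: partial HashAggregate -> Exchange hashpartitioning(a)
# -> final HashAggregate, for: select sum(b) from data group by a
from collections import defaultdict

data = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)]
input_partitions = [data[:3], data[3:]]
NUM_REDUCERS = 2  # stands in for the plan's 200 shuffle partitions

# Stage 1: HashAggregate with partial_sum within each partition.
partials = []
for rows in input_partitions:
    buf = defaultdict(int)
    for a, b in rows:
        buf[a] += b
    partials.append(buf)

# Exchange hashpartitioning(a): route each (key, partial sum) to the
# reducer owning that key, then merge there (the final HashAggregate).
reducers = [defaultdict(int) for _ in range(NUM_REDUCERS)]
for buf in partials:
    for a, s in buf.items():
        reducers[hash(a) % NUM_REDUCERS][a] += s

result = {a: s for r in reducers for a, s in r.items()}
print(sorted(result.items()))  # [(1, 3), (2, 3), (3, 3)]
```

Because sum's partial buffers are tiny (one number per key), the shuffle moves only one row per (partition, key) pair instead of the raw input.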
planAggregateWithOneDistinct: implementing aggregation with a distinct
explain select sum(b), sum(distinct b) from data group by a;
== Physical Plan ==
*HashAggregate(keys=[a#10], functions=[sum(cast(b#11 as bigint)), sum(distinct cast(b#11 as bigint)#94L)])   // step 4
+- Exchange hashpartitioning(a#10, 200)
   +- *HashAggregate(keys=[a#10], functions=[merge_sum(cast(b#11 as bigint)), partial_sum(distinct cast(b#11 as bigint)#94L)])   // step 3
      +- *HashAggregate(keys=[a#10, cast(b#11 as bigint)#94L], functions=[merge_sum(cast(b#11 as bigint))])   // step 2
         +- Exchange hashpartitioning(a#10, cast(b#11 as bigint)#94L, 200)
            +- *HashAggregate(keys=[a#10, cast(b#11 as bigint) AS cast(b#11 as bigint)#94L], functions=[partial_sum(cast(b#11 as bigint))])   // step 1
               +- LocalTableScan [a#10, b#11]
- Step 1 (Partial mode): compute the Partial result of the non-distinct aggregate
  - groupingExpressions: the group-by column (a) + the column used by distinct (b)
  - aggregateExpressions: the non-distinct aggregate (partial_sum(b))
  - resultExpressions = group-by column (a) + distinct column (b) + Partial result of the non-distinct aggregate (sum(b))
- Step 2 (PartialMerge mode): after shuffling on (a, b), merge the partial results of the non-distinct aggregate
  - groupingExpressions: the group-by column (a) + the column used by distinct (b)
  - aggregateExpressions: the non-distinct aggregate (merge_sum(b))
  - resultExpressions = group-by column (a) + distinct column (b) + merged result of the non-distinct aggregate (sum(b))
- Step 3 (PartialMerge + Partial modes): keep merging the non-distinct aggregate and compute the Partial result of the distinct aggregate
  - groupingExpressions: the group-by column (a)
  - aggregateExpressions: the non-distinct aggregate (merge_sum(b)) + the distinct aggregate (partial_sum(distinct cast(b#11 as bigint)#94L))
  - resultExpressions = group-by column (a) + merged result of the non-distinct aggregate + Partial result of the distinct aggregate
- Step 4 (Final mode): compute the Final results of both aggregates
  - groupingExpressions: the group-by column (a)
  - aggregateExpressions: the non-distinct aggregate (Final mode) + the distinct aggregate (Final mode)
  - resultExpressions = group-by column (a) + final result of the non-distinct aggregate (sum(b)) + final result of the distinct aggregate (sum(distinct b))
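The four steps above can be simulated in plain Python (a sketch of the semantics, not Spark code). The key trick is that grouping by (a, b) in steps 1-2 deduplicates b within each group, so step 3 can compute sum(distinct b) as an ordinary sum over the surviving keys:

```python
# Simulation of the four-step plan for:
#   select sum(b), sum(distinct b) from data group by a
from collections import defaultdict

data = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)]
partitions = [data[:3], data[3:]]

# Step 1 (Partial): group by (a, b), partial_sum(b) within each partition.
step1 = []
for rows in partitions:
    buf = defaultdict(int)
    for a, b in rows:
        buf[(a, b)] += b
    step1.append(buf)

# Step 2 (PartialMerge): after shuffling on (a, b), merge_sum(b).
step2 = defaultdict(int)
for buf in step1:
    for key, s in buf.items():
        step2[key] += s

# Step 3 (PartialMerge + Partial): group by a alone; keep merging the
# non-distinct sum, and start partial_sum(distinct b) - each distinct
# (a, b) key now contributes its b exactly once.
step3 = defaultdict(lambda: [0, 0])  # a -> [sum(b), sum(distinct b)]
for (a, b), s in step2.items():
    step3[a][0] += s
    step3[a][1] += b

# Step 4 (Final): finalize both aggregates for each group a.
result = {a: (sums[0], sums[1]) for a, sums in step3.items()}
print(sorted(result.items()))  # [(1, (3, 3)), (2, (3, 3)), (3, (3, 3))]
```

In this sample every (a, b) pair is already unique, so sum(b) and sum(distinct b) coincide; with duplicated b values per group, step 2's grouping key would collapse them and only the distinct sum would shrink.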
References:
- 《SparkSQL内核剖析》 (Spark SQL Internals), the Aggregation chapter: https://blog.csdn.net/renq_654321/article/details/94925717