Flink-sql自定义UDAF函数

2021-08-17  本文已影响0人  wudl

1. 在 Flink SQL 中使用自定义函数

1.1 官网文档对自定义函数有很详细的说明

https://ci.apache.org/projects/flink/flink-docs-release-1.12/dev/table/functions/udfs.html#scalar-functions

1.1.1 官网上的例子(聚合函数 Aggregate Functions 一节):

import org.apache.flink.table.api.*;
import org.apache.flink.table.functions.AggregateFunction;
import static org.apache.flink.table.api.Expressions.*;

/**
 * Mutable accumulator of structured type for the aggregate function.
 *
 * <p>Both fields rely on Java's default initialisation, so a fresh instance starts at zero.
 * {@code WeightedAvg} adds {@code value * weight} to {@link #sum} and {@code weight} to
 * {@link #count}.
 */
public static class WeightedAvgAccumulator {
  /** Running total; zero for a new accumulator. */
  public long sum;

  /** Running count; zero for a new accumulator. */
  public int count;
}

/**
 * Aggregate function that takes (value BIGINT, weight INT), stores intermediate results in a
 * structured type of {@code WeightedAvgAccumulator}, and returns the weighted average as BIGINT.
 *
 * <p>Rows in which either argument is NULL are skipped by {@code accumulate} and
 * {@code retract}, so a NULL cell no longer fails with a {@code NullPointerException} on
 * unboxing.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccumulator> {

  /**
   * Creates a fresh accumulator.
   *
   * @return a new accumulator instance
   */
  @Override
  public WeightedAvgAccumulator createAccumulator() {
    return new WeightedAvgAccumulator();
  }

  /**
   * Computes the result from the accumulated state.
   *
   * @param acc the accumulator to read
   * @return {@code acc.sum / acc.count} using integer division, or {@code null} when the
   *     accumulated count is zero
   */
  @Override
  public Long getValue(WeightedAvgAccumulator acc) {
    if (acc.count == 0) {
      return null;
    }
    return acc.sum / acc.count;
  }

  /**
   * Adds one input row to the accumulator.
   *
   * @param acc the accumulator to update
   * @param iValue the value; the row is ignored when this is {@code null}
   * @param iWeight the weight; the row is ignored when this is {@code null}
   */
  public void accumulate(WeightedAvgAccumulator acc, Long iValue, Integer iWeight) {
    // Boxed arguments may carry SQL NULL; unboxing them would throw.
    if (iValue == null || iWeight == null) {
      return;
    }
    acc.sum += iValue * iWeight;
    acc.count += iWeight;
  }

  /**
   * Removes one previously accumulated row from the accumulator.
   *
   * @param acc the accumulator to update
   * @param iValue the value; the row is ignored when this is {@code null}
   * @param iWeight the weight; the row is ignored when this is {@code null}
   */
  public void retract(WeightedAvgAccumulator acc, Long iValue, Integer iWeight) {
    // Mirror accumulate(): rows that were skipped there must be skipped here too.
    if (iValue == null || iWeight == null) {
      return;
    }
    acc.sum -= iValue * iWeight;
    acc.count -= iWeight;
  }

  /**
   * Folds the state of several accumulators into {@code acc}.
   *
   * @param acc the accumulator that receives the merged state
   * @param it the accumulators whose sums and counts are added to {@code acc}
   */
  public void merge(WeightedAvgAccumulator acc, Iterable<WeightedAvgAccumulator> it) {
    for (WeightedAvgAccumulator a : it) {
      acc.count += a.count;
      acc.sum += a.sum;
    }
  }

  /**
   * Resets the accumulator to its initial, empty state.
   *
   * @param acc the accumulator to clear
   */
  public void resetAccumulator(WeightedAvgAccumulator acc) {
    acc.count = 0;
    acc.sum = 0L;
  }
}

// Usage sketch: three ways of invoking WeightedAvg on the table "MyTable", grouped by myField.
// The "..." argument is a placeholder from the original example, not compilable code.
TableEnvironment env = TableEnvironment.create(...);

// call function "inline" without registration in Table API
// (the function is passed as a class reference instead of a registered name)
env
  .from("MyTable")
  .groupBy($("myField"))
  .select($("myField"), call(WeightedAvg.class, $("value"), $("weight")));

// register function under the name "WeightedAvg"
env.createTemporarySystemFunction("WeightedAvg", WeightedAvg.class);

// call registered function in Table API (looked up by the name registered above)
env
  .from("MyTable")
  .groupBy($("myField"))
  .select($("myField"), call("WeightedAvg", $("value"), $("weight")));

// call registered function in SQL; `value` is written in backticks in the query text
env.sqlQuery(
  "SELECT myField, WeightedAvg(`value`, weight) FROM MyTable GROUP BY myField"
);

2. 自己实现一个 UDAF

package com.wudl.flink.sql;

import com.wudl.flink.bean.WaterSensor;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;

import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;

/**
 * Demo job that registers a user-defined aggregate function (UDAF) and calls it from Flink SQL.
 *
 * <p>Despite the "UDTF" in the class name, the function defined here ({@link MyAvg}) extends
 * {@code AggregateFunction}: it reduces the rows of each group to a single average.
 *
 * @author wudl
 * @since 2021-08-17
 */
public class Flink_Sql_Function_UDTF {

    /**
     * Reads comma-separated records from a socket, converts them to a dynamic table and prints
     * the per-{@code id} average of {@code vc} computed by {@link MyAvg}.
     *
     * @param args unused
     * @throws Exception kept in the signature for backward compatibility
     */
    public static void main(String[] args) throws Exception {
        // 1. Obtain the execution environments.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);

        // 2. Read lines from the socket and turn each one into a WaterSensor bean.
        //    Every line must hold three comma-separated fields: a string, a long and an int.
        SingleOutputStreamOperator<WaterSensor> waterSensorDs = env.socketTextStream("192.168.1.180", 9999)
                .map(line -> {
                    String[] fields = line.split(",");
                    return new WaterSensor(fields[0], Long.parseLong(fields[1]), Integer.parseInt(fields[2]));
                });

        // 3. Convert the stream into a dynamic table.
        Table table = tableEnv.fromDataStream(waterSensorDs);

        // 4. Register the function before using it.
        tableEnv.createTemporarySystemFunction("MyAvg", MyAvg.class);

        // 4.1 Table API alternative to the SQL query below:
        //     table.groupBy($("id")).select($("id"), call("MyAvg", $("vc"))).execute().print();

        // 4.2 SQL variant; the Table object is inlined into the query text via its toString().
        tableEnv.sqlQuery("select id , MyAvg(vc) from " + table + " group by id")
                .execute()
                .print();

        // No trailing env.execute(): Table.execute() above already submits the job, and print()
        // consumes its results. Calling env.execute() afterwards would try to launch a second,
        // sink-less job from the DataStream operators registered on env.
    }

    /**
     * Aggregate function returning the arithmetic mean of the accumulated {@code vc} values as a
     * {@code Double}. Intermediate state is kept in a {@link SumCount}.
     */
    public static class MyAvg extends AggregateFunction<Double, SumCount> {

        /**
         * Creates a fresh accumulator via {@code SumCount}'s no-argument constructor.
         *
         * @return a new accumulator
         */
        @Override
        public SumCount createAccumulator() {
            return new SumCount();
        }

        /**
         * Adds one input value to the accumulator.
         *
         * @param acc the accumulator to update
         * @param vc the value to add; ignored when {@code null}
         */
        public void accumulate(SumCount acc, Integer vc) {
            // A NULL cell arrives as a null Integer; unboxing it would throw.
            if (vc == null) {
                return;
            }
            acc.setVcSum(acc.getVcSum() + vc);
            acc.setCount(acc.getCount() + 1);
        }

        /**
         * Computes the average from the accumulated state.
         *
         * @param sumCount the accumulator to read
         * @return sum divided by count as a double, or {@code null} when nothing has been
         *     accumulated (instead of the NaN that 0 / 0 would yield)
         */
        @Override
        public Double getValue(SumCount sumCount) {
            if (sumCount.getCount() == 0) {
                return null;
            }
            // Multiplying by 1D forces floating-point division.
            return sumCount.getVcSum() * 1D / sumCount.getCount();
        }
    }

}

用到的累加器类 SumCount

package com.wudl.flink.sql;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Accumulator bean for the custom UDAF, holding two int fields: {@code vcSum} and
 * {@code count}.
 *
 * <p>Constructors and accessors are not written by hand; the class carries Lombok's
 * {@code @Data}, {@code @AllArgsConstructor} and {@code @NoArgsConstructor} annotations.
 *
 * @author wudl
 * @since 2021-08-17 23:33
 */

@Data
@AllArgsConstructor
@NoArgsConstructor
public class SumCount {

    // Running total of the vc values; MyAvg.accumulate adds each vc to it.
    private int vcSum;
    // Number of values accumulated; MyAvg.accumulate increments it by one per value.
    private int count;
}

打印结果


Flink-sql-udtf.png
上一篇下一篇

猜你喜欢

热点阅读