PySpark UDF examples

2022-01-17  本文已影响0人  hehehehe

Defining a UDF inline with a lambda:

# Derive ``hn_id3`` as the first element of the ``ids`` array column.
# Guard against null/empty arrays: a bare ``x[0]`` raises inside the Python
# worker (TypeError on None, IndexError on []) and fails the whole task,
# so such rows get a null ``hn_id3`` instead.
df4 = df3.withColumn(
    "hn_id3",
    udf(lambda x: x[0] if x else None, StringType())(df3.ids),
)

# Group rows by ``house_num_normal`` and collect every row's lon/lat into
# arrays, plus one representative ``id2`` value per group.
# NOTE(review): collect_list output order depends on row order after the
# shuffle. ``lons`` and ``lats`` come from separate aggregate expressions, so
# their element-wise alignment is assumed rather than guaranteed. Confirm
# this, or collect a struct(lon, lat) instead.
group_df = df1.groupBy(['house_num_normal']).agg(fn.collect_list('lon').alias('lons'),
                                                 fn.collect_list('lat').alias('lats'),
                                                 fn.collect_list('id2')[0].alias('id'))
# DataFrames are immutable: withColumn returns a NEW DataFrame, so the result
# must be assigned, otherwise the centroid column is silently discarded.
# NOTE(review): get_centroid is defined further down in this file; it must
# already be defined when this statement executes.
group_df = group_df.withColumn('centroid', get_centroid(group_df.lons, group_df.lats))

@udf(returnType=StringType())
def get_wkt(lons: list, lats: list):
    """Return the WKT of a MultiPoint built from paired longitudes/latitudes.

    ``lons`` and ``lats`` are paired positionally with ``zip`` (so surplus
    entries in the longer list are ignored), and every value is passed
    through ``float()`` before constructing each ``Point(lon, lat)``.
    """
    coordinate_pairs = zip(lons, lats)
    multi_point = MultiPoint(
        [Point(float(x), float(y)) for x, y in coordinate_pairs]
    )
    return multi_point.wkt


@udf(returnType=ArrayType(StringType()))
def lonlat_split(address: str):
    """Split ``address`` on commas after dropping its final character.

    The last character is removed unconditionally. It is presumably a trailing
    separator; confirm against the data. The pieces are returned as strings.
    """
    without_last_char = address[:-1]
    pieces = without_last_char.split(",")
    return pieces


@udf(returnType=ArrayType(FloatType()))
def lonlat_to_list(lon: float, lat: float):
    """Pack a longitude/latitude pair into a two-element ``[lon, lat]`` array."""
    pair = [lon, lat]
    return pair


@udf(returnType=StringType())
def points_hull(points):
    """Return the WKT of the convex hull of ``points``.

    ``points`` is passed directly to ``MultiPoint``. It is presumably an array
    of coordinate pairs (e.g. the ``[lon, lat]`` arrays ``lonlat_to_list``
    produces); confirm against the caller.
    """
    hull = MultiPoint(points).convex_hull
    return hull.wkt


@udf(returnType=StringType())
def polygon_centroid(polygon_wkt):
    """Parse a WKT geometry string and return the WKT of its centroid."""
    geometry = wkt.loads(polygon_wkt)
    centroid = geometry.centroid
    return centroid.wkt

@udf(returnType=StringType())
def get_centroid(lons: list, lats: list):
    """Return the WKT of the centroid of the points given by ``lons``/``lats``.

    Coordinates are paired positionally with ``zip`` and each value is
    converted with ``float()`` before building ``Point(lon, lat)``.
    """
    point_list = [Point(float(x), float(y)) for x, y in zip(lons, lats)]
    centroid = MultiPoint(point_list).centroid
    return centroid.wkt


@udf(returnType=FloatType())
def get_distance(confidence_list: list, lons: list, lats: list):
    """Return the largest geodesic distance (metres) from the main point to any other point.

    The "main" point is the first position whose confidence equals 5. Its
    coordinates are ``lons[idx]`` / ``lats[idx]``. Every other position is
    measured against it with ``distance.geodesic`` (latitude first, altitude 0).

    Returns:
        The maximum distance in metres as a float. Returns ``0.0`` when the
        main point is the only entry. Returns ``-1.0`` when no entry has
        confidence 5, or when ``confidence_list`` is null or empty.
    """
    # A null array column arrives as None; ``5 in None`` would raise TypeError.
    if not confidence_list or 5 not in confidence_list:
        return -1.0

    main_idx = confidence_list.index(5)
    main_lon, main_lat = lons[main_idx], lats[main_idx]
    # Start from 0.0, not 0: a FloatType UDF that returns a Python int yields
    # null in Spark, so a single-point group would otherwise become null.
    max_dist = 0.0
    for i, _ in enumerate(confidence_list):
        if i == main_idx:
            continue
        dist = distance.geodesic((float(main_lat), float(main_lon), 0),
                                 (float(lats[i]), float(lons[i]), 0)).meters
        if dist > max_dist:
            max_dist = dist
    return max_dist

# Query hn_table3, keeping only rows with hn_id_len > 1 (presumably groups
# with more than one house-number id; confirm against the table definition).
# hn_id_list and confidence_list are also flattened into comma-separated
# strings with concat_ws. confidence_list, lon_list and lat_list are also
# selected unchanged.
df2 = spark.sql("""
        select group_key,concat_ws(",",hn_id_list) as hn_id_str,concat_ws(",",confidence_list) as confidence_str,confidence_list,
        lon_list, lat_list 
        from hn_table3
        where hn_id_len > 1
    """)
上一篇 下一篇

猜你喜欢

热点阅读