duckdb 性能测试

2024-09-05  本文已影响0人  hehehehe


def udf_concurrency(n_rows: int = 500_000, threads: int = 10) -> None:
    """Benchmark a Python geometry UDF under DuckDB, plain Python and Polars.

    Registers ``A.geom_z_diff`` and ``A.udf2`` as DuckDB scalar functions on the
    module-level ``conn``. It then builds a Polars DataFrame of ``n_rows``
    identical rows and prints how long each of four approaches takes to
    evaluate ``geom_z_diff`` over column ``c``:

    * ``duckdb1`` - a SQL query calling the registered UDF;
    * ``py``      - a plain list comprehension over the row dicts;
    * ``py-eval`` - ``eval`` of a pre-compiled expression once per row;
    * ``df``      - Polars ``map_elements`` over a struct column, using the
      same compiled expression.

    Relies on names defined elsewhere in the module: ``conn`` (presumably a
    ``duckdb`` connection - confirm), ``pl`` (polars), ``geoutil``,
    ``random``, ``LineString`` and ``time``.

    Args:
        n_rows: Number of rows in the benchmark data.
        threads: Value passed to DuckDB's ``PRAGMA threads``.

    Raises:
        ValueError: If ``n_rows`` is not positive. An empty DataFrame has no
            columns, so the SQL query could not reference ``a``/``c``.
    """
    if n_rows <= 0:
        raise ValueError(f"n_rows must be positive, got {n_rows}")

    class A(object):
        """Namespace for the scalar functions under test."""

        @staticmethod
        def gen_linestring():
            """Return the WKT of a 5-vertex LineString with random.uniform(0, 100) coords."""
            points = [(random.uniform(0, 100), random.uniform(0, 100)) for _ in range(5)]
            return LineString(points).wkt

        @staticmethod
        def geom_z_diff(geom_wkt: str) -> float:
            """Return max(z) - min(z) over the vertices of ``geom_wkt``.

            Returns 0.0 when ``geoutil.get_coords_from_wkt`` yields no
            coordinates or reports that the geometry has no Z dimension.
            """
            coords, has_z, _geom_type = geoutil.get_coords_from_wkt(geom_wkt)
            if not (coords and has_z):
                return 0.0
            z_vals = [coord[2] for coord in coords]
            return max(z_vals) - min(z_vals)

        @staticmethod
        def list_intersect(a: list, b: list) -> list:
            """Return the distinct elements common to ``a`` and ``b`` (order unspecified).

            Not registered with DuckDB below; kept as an alternative workload.
            """
            return list(set(a).intersection(b))

        @staticmethod
        def udf2(a: int) -> str:
            """Return the string ``"<a>_<a>"``."""
            return f"{a}_{a}"

    # null_handling="special": NULL inputs are passed to the Python function.
    # exception_handling="return_null": an exception raised by the function
    # becomes NULL instead of failing the query.
    conn.create_function("geom_z_diff", A.geom_z_diff,
                         null_handling="special", exception_handling="return_null")
    conn.create_function("udf2", A.udf2,
                         null_handling="special", exception_handling="return_null")

    wkt = 'LINESTRING Z(116.21126367 39.86818626 0.0000011, 116.21167992 39.86805255 0.00005, 116.21143061 39.86831677 0.00005, 116.21125012 39.86857083 0.00005)'
    # Every element is the same dict object; the benchmarks below only read it.
    datas = [{"a": 1, "b": 1, "c": wkt, "d": 'd1'}] * n_rows
    # NOTE: do not rename `df` or `t` - the SQL strings below refer to these
    # Python variables by name (DuckDB replacement scans).
    df = pl.from_dicts(datas)
    print(f"{df.shape=}")
    conn.execute(f"PRAGMA threads={int(threads)};")

    # --- DuckDB: the relation from conn.sql() is lazy; show() on the count
    # query is what actually runs the UDF, so it stays inside the timing. ---
    start = time.perf_counter()
    t = conn.sql("select a,geom_z_diff(c) as c from df")
    conn.sql("select count(c) from t").show()
    print(f"duckdb1:{time.perf_counter() - start}")

    # --- Plain Python baseline. ---
    start = time.perf_counter()
    _ = [A.geom_z_diff(row['c']) for row in datas]
    print(f"py:{time.perf_counter() - start}")

    # --- eval() of a compiled expression, once per row. ---
    start = time.perf_counter()
    geom_z_diff = A.geom_z_diff
    # The expression resolves `geom_z_diff` and `row` from whatever namespace
    # it is evaluated in. The string is a constant, not external input.
    compiled_expr = compile("geom_z_diff(row['c'])", '<string>', 'eval')
    for row in datas:
        # With no explicit namespaces, eval() reads this function's locals.
        eval(compiled_expr)
    print(f"py-eval:{time.perf_counter() - start}")

    # --- Polars map_elements; each struct value arrives as a dict-like row. ---
    start = time.perf_counter()
    df2 = df.with_columns(
        pl.struct(["c"]).map_elements(
            lambda row: eval(compiled_expr, {"geom_z_diff": geom_z_diff, "row": row}),
            return_dtype=pl.Float64,
        ).alias("bb")
    )
    # Stop the clock before printing, so formatting the preview is not timed.
    elapsed = time.perf_counter() - start
    print(df2.head(10))
    print(f"df:{elapsed}")

if __name__ == '__main__':
    # Script entry point: run the UDF benchmark and print its total wall time
    # in seconds. Other experiments (postgres_scan, get_df, null_test) were
    # invoked from here in earlier versions; presumably they are defined
    # elsewhere in the full script - confirm before re-enabling.
    # perf_counter() is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    udf_concurrency()
    print(time.perf_counter() - start)

上一篇 下一篇

猜你喜欢

热点阅读