PySpark: Commonly Used Snippets

2020-08-17  LeonTung
DataFrame column operations
Split a column into an array: withColumn('catenew', split(col('cates'), ','))
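A minimal runnable sketch of the split call, assuming a hypothetical DataFrame with a comma-separated cates column:

from pyspark.sql.functions import split, col

df = spark.createDataFrame([('a,b,c',)], ['cates'])  # hypothetical sample data
df = df.withColumn('catenew', split(col('cates'), ','))  # catenew becomes array<string>
df.show(truncate=False)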
Print feature importances
from itertools import chain

rfModel = model_pipe.stages[-1]  # the fitted tree model is the last pipeline stage
attrs = sorted(
    (attr['idx'], attr['name']) for attr in
    chain(*df_test_result.schema['features_asb'].metadata['ml_attr']['attrs'].values())  # features_asb is the VectorAssembler output column
)
feature_weight = [(idx, name, float(rfModel.featureImportances[idx])) for idx, name in attrs]
df_weight = spark.createDataFrame(feature_weight, ['idx', 'feature', 'weight'])
df_weight.orderBy(df_weight.weight.desc()).show(df_weight.count(), truncate=False)
Parse the predicted probability
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType

split_udf = F.udf(lambda value: float(value[1]), FloatType())  # the probability vector holds numpy.float64; cast to a Python float
df_result = df_result.withColumn('proba', split_udf('probability')).select('member_id', 'prediction', F.round('proba', 3).alias('proba'))
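On Spark 3.0+ the same extraction can be done without a UDF via vector_to_array (a sketch reusing the df_result above):

from pyspark.ml.functions import vector_to_array

df_result = df_result.withColumn('proba', F.round(vector_to_array('probability')[1], 3))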
Model hyperparameter tuning

see the Databricks docs

from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

pipeline = Pipeline(stages=[assembler, gbdt])  # assembler/gbdt defined earlier (VectorAssembler, GBTClassifier)
paramGrid = (ParamGridBuilder()
             .addGrid(gbdt.maxDepth, [3, 5, 7])
             .addGrid(gbdt.maxIter, [15, 20, 25])
             .build())  # parameter search grid
cv = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid,
                    evaluator=BinaryClassificationEvaluator(), numFolds=3)
cvModel = cv.fit(df_train)
df_test_result = cvModel.transform(df_test)
gbdtModel = cvModel.bestModel.stages[-1]  # the fitted GBT model from the best pipeline
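To see how each parameter combination scored, the fitted CrossValidatorModel exposes the param maps and their average metrics (a small sketch):

for params, metric in zip(cvModel.getEstimatorParamMaps(), cvModel.avgMetrics):
    print({p.name: v for p, v in params.items()}, round(metric, 4))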

Initialize Spark

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('pspredict').enableHiveSupport().config('spark.driver.memory', '8g').getOrCreate()  # e.g. in Jupyter
spark.sparkContext.setLogLevel('ERROR')
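With Hive support enabled, Hive tables can be queried directly; a sketch using a hypothetical database/table name:

df = spark.sql('select member_id, country_nm from dw.member_info limit 10')  # hypothetical table
df.show()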

Common missing-value filling (examples sketched after the list)

(1) replace(to_replace, values, subset)
(2) replace('', 'unknown', 'country_nm')
(3) replace(['a', 'b'], ['c', 'd'], 'country_nm'): maps a->c, b->d in country_nm (subset may be a list); the old and new values must share the same type, and the replacement cannot be None
(4) replace({-1: 14}, 'stature'): maps -1 -> 14 in stature; when a dict is passed, the values argument is ignored, and all dict entries must share one type (string and None cannot be mixed)
(5) fillna('haha'): replaces null with 'haha'; non-string columns are skipped
(6) fillna('xx', [columns_name]): fills na with 'xx' in all listed columns
(7) fillna({'f1': 24, 'f2': 'hah'}): fills each column with its own value
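A minimal sketch exercising a few of these calls on a hypothetical DataFrame:

df = spark.createDataFrame([('', None, -1), ('CN', 'Bob', 170)],
                           ['country_nm', 'name', 'stature'])  # hypothetical sample data
df = df.replace('', 'unknown', 'country_nm')        # '' -> 'unknown' in country_nm
df = df.replace({-1: 14}, subset='stature')         # -1 -> 14 in stature
df = df.fillna({'name': 'unknown', 'stature': 0})   # per-column defaults
df.show()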

StringIndexer with multiple columns

PySpark's StringIndexer takes a single inputCol (multi-column inputCols only arrived in Spark 3.0), so build one indexer per column with a list comprehension:
from pyspark.ml.feature import StringIndexer
indexer = [StringIndexer(inputCol=x, outputCol='{}_idx'.format(x), handleInvalid='keep') for x in feature_index]
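The list of indexers can then be spliced into a Pipeline's stages (Pipeline imported as in the tuning example above):

pipeline = Pipeline(stages=indexer + [assembler, gbdt])
model_pipe = pipeline.fit(df_train)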