Learning Spark [9] - MLlib Library - Tr
2021-02-19
屹然1ran
Tree-Based Models
Tree-based models, such as decision trees, gradient-boosted trees, and random forests, are relatively easy to interpret compared with regression models.
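Decision Trees
A decision tree model is built from a series of if-then-else rules and can be used for both classification and regression.
Figure: A decision tree model for deciding whether to accept a job offer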
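To make the if-then-else view concrete, here is a minimal Python sketch of a job-offer tree written by hand. The features (salary, commute time, free coffee) and thresholds are hypothetical stand-ins, not the rules from the figure; the point is only that each nested condition acts as an internal node and each `return` acts as a leaf.

```python
def accept_offer(salary: float, commute_hours: float, free_coffee: bool) -> bool:
    """Hand-written decision tree with hypothetical features and thresholds.

    Each if/else is an internal node that tests one feature;
    each return is a leaf holding the final prediction.
    """
    if salary >= 50_000:            # root: split on a continuous feature
        if commute_hours > 1.0:     # second split on another continuous feature
            return False            # leaf: decline
        if free_coffee:             # split on a boolean feature
            return True             # leaf: accept
        return False                # leaf: decline
    return False                    # leaf: decline


print(accept_offer(60_000, 0.5, True))   # True
print(accept_offer(60_000, 2.0, True))   # False
print(accept_offer(40_000, 0.5, True))   # False
```

A learned tree has the same shape; the difference is that the training algorithm chooses which feature to test at each node and where to put each threshold.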
# decision tree
from pyspark.ml import Pipeline
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.feature import VectorAssembler, StringIndexer

# trainDF is the training DataFrame prepared in the earlier sections; 'price' is the label
dt = DecisionTreeRegressor(labelCol='price')

# index the string (categorical) columns; rows with labels unseen during fitting are skipped
categoricalCols = [field for (field, dataType) in trainDF.dtypes if dataType == 'string']
indexOutputCols = [x + 'Index' for x in categoricalCols]
stringIndexer = StringIndexer(inputCols=categoricalCols,
                              outputCols=indexOutputCols,
                              handleInvalid='skip')

# filter for just the numeric columns, excluding the label
numericCols = [field for (field, dataType) in trainDF.dtypes
               if dataType == 'double' and field != 'price']

# combine the StringIndexer outputs and the numeric columns into one feature vector
assembleInputs = indexOutputCols + numericCols
vecAssembler = VectorAssembler(inputCols=assembleInputs, outputCol='features')

# maxBins (default 32) must be at least the number of distinct values
# of every categorical feature, so raise it to 40
dt.setMaxBins(40)

stages = [stringIndexer, vecAssembler, dt]
pipeline = Pipeline(stages=stages)
pipelineModel = pipeline.fit(trainDF)

# the fitted tree is the last stage of the pipeline model
dtModel = pipelineModel.stages[-1]
print(dtModel.toDebugString)
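The pipeline above relies on two preprocessing steps. The pure-Python sketch below imitates them on plain dicts so their effect is visible without a Spark cluster: `fit_string_indexer` assigns indices by descending frequency (StringIndexer's default `stringOrderType='frequencyDesc'`; the alphabetical tie-breaking is assumed to match recent Spark versions), and `assemble` concatenates the indexed categorical columns and the numeric columns in the same order as `assembleInputs`, dropping rows with unseen labels the way `handleInvalid='skip'` does. The helper names and the `room_type`/`bedrooms` rows are hypothetical; this is an illustration, not Spark's implementation.

```python
from collections import Counter


def fit_string_indexer(values):
    """Map each distinct string to a float index, most frequent first.

    Ties are broken alphabetically (assumed to mirror recent Spark
    versions of StringIndexer with stringOrderType='frequencyDesc').
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}


def assemble(rows, categorical_cols, numeric_cols, indexers):
    """Build one feature list per row: indexed categorical columns first,
    then numeric columns (the order of indexOutputCols + numericCols).

    Rows containing a label the indexer never saw are dropped,
    imitating handleInvalid='skip'.
    """
    vectors = []
    for row in rows:
        if any(row[c] not in indexers[c] for c in categorical_cols):
            continue  # unseen category: skip the whole row
        vec = [indexers[c][row[c]] for c in categorical_cols]
        vec += [float(row[c]) for c in numeric_cols]
        vectors.append(vec)
    return vectors


# Hypothetical rows standing in for trainDF and a later test set
train = [
    {"room_type": "Entire home", "bedrooms": 2.0},
    {"room_type": "Private room", "bedrooms": 1.0},
    {"room_type": "Entire home", "bedrooms": 3.0},
]
test = train + [{"room_type": "Hotel room", "bedrooms": 1.0}]

indexers = {"room_type": fit_string_indexer([r["room_type"] for r in train])}
vectors = assemble(test, ["room_type"], ["bedrooms"], indexers)
print(indexers)  # {'room_type': {'Entire home': 0.0, 'Private room': 1.0}}
print(vectors)   # [[0.0, 2.0], [1.0, 1.0], [0.0, 3.0]]
```

In Spark, `StringIndexer` also attaches metadata marking its output columns as categorical, and `VectorAssembler` carries that metadata into the `features` vector. That is why the tree can split on sets of category indices (for example `feature 5 in {1.0,2.0}`) instead of numeric thresholds, and why `maxBins` has to be at least as large as the largest number of categories.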
DecisionTreeRegressionModel: uid=DecisionTreeRegressor_901986020659, depth=5, numNodes=47, numFeatures=33
  If (feature 12 <= 2.5)
   If (feature 12 <= 1.5)
    If (feature 5 in {1.0,2.0})
     If (feature 4 in {0.0,1.0,3.0,5.0,9.0,10.0,11.0,13.0,14.0,16.0,18.0,24.0})
      If (feature 3 in {0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0,32.0,33.0,34.0})
       Predict: 104.23992784125075
      Else (feature 3 not in
...
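When reading this dump, `feature N` refers to position N in the assembled `features` vector. Because every column in `assembleInputs` here is a scalar (an index produced by `StringIndexer`, or a double), position N should correspond to `assembleInputs[N]`. The sketch below is a hypothetical helper, `label_features`, built only on Python's `re` module, that rewrites those references into column names; the column names and the shortened dump in the example are made up for illustration.

```python
import re


def label_features(debug_string, feature_names):
    """Replace each 'feature N' in a toDebugString dump with feature_names[N].

    Indices outside feature_names are left unchanged.
    """
    def repl(match):
        idx = int(match.group(1))
        if idx < len(feature_names):
            return feature_names[idx]
        return match.group(0)

    return re.sub(r"\bfeature (\d+)", repl, debug_string)


# Hypothetical column names and a shortened dump in the same format as above
names = ["room_typeIndex", "neighbourhoodIndex", "bedrooms"]
dump = """DecisionTreeRegressionModel: depth=2, numFeatures=3
  If (feature 2 <= 2.5)
   If (feature 0 in {0.0})
    Predict: 150.0"""
print(label_features(dump, names))
```

Against the real model, the call would be `print(label_features(dtModel.toDebugString, assembleInputs))`. Note that `numFeatures=33` is not touched, since the pattern only matches the lowercase word `feature` followed by a space and a number.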