Learning Spark [9] - MLlib Library - Tr

2021-02-19  屹然1ran

Tree-Based Models

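Compared with regression models, tree-based models such as decision trees, gradient-boosted trees, and random forests are relatively easy to interpret: it is straightforward to explain how they arrive at a given prediction.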

Decision Tree

A decision tree model is built from a series of if-then-else rules and can be used for both classification and regression problems.


Figure: a decision tree model for deciding whether to accept a job offer
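To make the if-then-else structure concrete, here is a minimal sketch of a job-offer tree written as nested Python conditionals. The features (salary, commute time, free coffee) and their thresholds are hypothetical illustrations, not read from the original figure; a trained decision tree is equivalent to this kind of rule set, except that the conditions are learned from data.

```python
# Hypothetical job-offer decision tree written as plain if-then-else rules.
# Features and thresholds are illustrative assumptions, not from the figure.
def accept_offer(salary: float, commute_minutes: float, free_coffee: bool) -> str:
    if salary > 50000:              # root split
        if commute_minutes > 60:    # second split on the "yes" branch
            return "decline"
        if free_coffee:             # third split
            return "accept"
        return "decline"
    return "decline"                # low salary: leaf reached immediately


print(accept_offer(60000, 30, True))   # follows salary -> commute -> coffee
```

Each call walks exactly one root-to-leaf path, which is why a tree's prediction can be explained by listing the conditions along that path.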
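# decision tree
# assumes trainDF is the training DataFrame prepared earlier in the chapter
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor

dt = DecisionTreeRegressor(labelCol='price')

# string columns are categorical: StringIndexer maps each category to a numeric index
# (trees can split on category indices directly, so no one-hot/dummy encoding is needed)
categoricalCols = [field for (field, dataType) in trainDF.dtypes if dataType == 'string']
indexOutputCols = [x + 'Index' for x in categoricalCols]

stringIndexer = StringIndexer(inputCols=categoricalCols,
                              outputCols=indexOutputCols,
                              handleInvalid='skip')  # drop rows with unseen categories

# filter for just numeric columns, excluding the label 'price'
numericCols = [field for (field, dataType) in trainDF.dtypes
               if dataType == 'double' and field != 'price']

# combine the StringIndexer output columns and the numeric columns into one feature vector
assembleInputs = indexOutputCols + numericCols
vecAssembler = VectorAssembler(inputCols=assembleInputs, outputCol='features')

stages = [stringIndexer, vecAssembler, dt]
pipeline = Pipeline(stages=stages)

# maxBins must be at least the number of distinct values in every categorical feature;
# the default (32) is too small for this data, so raise it before fitting
dt.setMaxBins(40)
pipelineModel = pipeline.fit(trainDF)

# the fitted tree is the last stage of the pipeline model
dtModel = pipelineModel.stages[-1]
print(dtModel.toDebugString)

Output: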
DecisionTreeRegressionModel: uid=DecisionTreeRegressor_901986020659, depth=5, numNodes=47, numFeatures=33
  If (feature 12 <= 2.5)
   If (feature 12 <= 1.5)
    If (feature 5 in {1.0,2.0})
     If (feature 4 in {0.0,1.0,3.0,5.0,9.0,10.0,11.0,13.0,14.0,16.0,18.0,24.0})
      If (feature 3 in {0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0,32.0,33.0,34.0})
       Predict: 104.23992784125075
      Else (feature 3 not in
...
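The debug string refers to features only by position (`feature 12`, `feature 3 in {...}`). Because VectorAssembler concatenates its input columns in order, position i should correspond to `assembleInputs[i]`, provided each input is a single scalar column, as the StringIndexer outputs and double columns are here. The sketch below relies on that assumption; `feature_index_to_name` and `annotate` are hypothetical helpers, and the column names in the test are made up. Note that the `in {...}` sets on categorical splits list StringIndexer category indices, not the original strings.

```python
import re


def feature_index_to_name(assemble_inputs):
    """Map VectorAssembler output positions to input column names.

    Assumes every input column is a scalar (one slot each). A vector-valued
    input column would occupy several slots and shift all later indices.
    """
    return dict(enumerate(assemble_inputs))


def annotate(debug_string, names):
    """Replace each 'feature <i>' in a toDebugString dump with its column name."""
    # 'numFeatures=33' is not matched: the pattern requires lowercase
    # 'feature' followed by a space and digits.
    return re.sub(r"feature (\d+)", lambda m: names[int(m.group(1))], debug_string)


# Hypothetical column order, standing in for indexOutputCols + numericCols
names = feature_index_to_name(["room_typeIndex", "bed_typeIndex", "bedrooms"])
print(annotate("If (feature 2 <= 2.5)", names))
```

In the pipeline above, this would be applied as `annotate(dtModel.toDebugString, feature_index_to_name(assembleInputs))`.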