MachineLearning 7. 癌症诊断机器学习之回归树(
前 言
树方法精髓就是划分特征,从第一次分裂开始就要考虑如何最大程度改善RSS,然后持续进行“树权”分裂,直到树结束。后面的划分并不作用于全数据集,而仅作用于上次划分时落到这个分支之下的那部分数据。这个自顶向下的过程被称为“递归划分”。这个过程是贪婪的,贪婪的含义是指算法在每次分裂中都追求最大程度减少RSS,而不管以后的划分中表现如何。这样做可能会生成一个带有无效分支的树,尽管偏差很小,但是方差很大。为了避免这个问题,生成完整的树之后,你要对树进行剪枝,得到最优的解。这种方法的优点是可以处理高度非线性关系,但它还存在一些潜在的问题:一个观测被赋予所属终端节点的平均值,这会损害整体预测效果(高偏差)。相反,如果你一直对数据进行划分,树的层次越来越深,这样可以达到低偏差的效果,但是高方差又成了问题。和其他方法一样,你也可以用交叉验证来选择合适的深度。
基本原理
回归树(regression tree),顾名思义,就是用树模型做回归问题,每一片叶子都输出一个预测值。预测值一般是该片叶子所含训练集元素输出的均值,即 cm=ave(yi|xi∈leafm)。
CART 在分类问题和回归问题中的相同和差异:
相同:在分类问题和回归问题中,CART 都是一棵二叉树,除叶子节点外的所有节点都有且仅有两个子节点;所有落在同一片叶子中的输入都有同样的输出。
差异:在分类问题中,CART 使用基尼指数(Gini index)作为选择特征(feature)和划分(split)的依据;
在回归问题中,CART 使用 mse(mean square error)或者 mae(mean absolute error)作为选择 feature 和 split 的 criteria。在分类问题中,CART 的每一片叶子都代表的是一个 class;在回归问题中,CART 的每一片叶子表示的是一个预测值,取值是连续的。
实例解析
我们选取前列腺癌的数据,做回归树分析,之前乳腺癌是做分类树,而前列腺癌用来做回归树,差别就在于因变量的性质,是分类变量,还是连续变量.
FormatA data frame with 97 observations on the following 10 variables.
1. 软件安装
这里我们主要使用rpart和partykit两个软件包,其他都为数据处理过程中需要使用软件包,如下:
if (!require(rpart)) install.packages("rpart")
if (!require(partykit)) install.packages("partykit")
if (!require(caret)) install.packages("caret")
if (!require(rpart.plot)) install.packages("rpart.plot")
if (!require(ElemStatLearn)) install.packages("ElemStatLearn")
library(rpart) #classification and regression trees
library(partykit) #treeplots
library(caret) #tune hyper-parameters
library(rpart.plot)
library(ElemStatLearn)
2. 数据读取
数据来源《机器学习与R语言》书中,具体来自UCI机器学习仓库。地址:http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/ 下载wbdc.data和wbdc.names这两个数据集,数据经过整理,成为面板数据。查看数据结构,其中第一列为id列,无特征意义,需要删除。第二列diagnosis为响应变量,字符型,一般在R语言中分类任务都要求响应变量为因子类型,因此需要做数据类型转换。剩余的为预测变量,数值类型。查看数据维度,97个样本,9个特征(包括响应特征)。
data(prostate)
prostate$gleason <- ifelse(prostate$gleason == 6, 0, 1)
sum(is.na(prostate)) ###判断缺省值
## [1] 0
数据分布
整体数据的分布,如下:
library(reshape2)
library(ggplot2)
bc <- prostate[, -10]
bc.melt <- melt(bc)
head(bc.melt)
## variable value
## 1 lcavol -0.5798185
## 2 lcavol -0.9942523
## 3 lcavol -0.5108256
## 4 lcavol -1.2039728
## 5 lcavol 0.7514161
## 6 lcavol -1.0498221
ggplot(data = bc.melt, aes(x = variable, y = value)) + geom_boxplot(colour = "blue") +
geom_jitter(size = 0.5, colour = "gray", alpha = 0.5) + theme_bw()
3. 数据处理
比较每组数据的相关性,如下:
library(tidyverse)
corrplot::corrplot(cor(bc))
数据分割
当我们只有一套数据的时候,可以将数据分为训练集和测试集,具体怎么分割可以看公众号的专题:Topic 5. 样本量确定及分割
# 数据分割 训练数据
train_data <- subset(prostate, train == TRUE)[, 1:9]
# 测试数据
test_data = subset(prostate, train == FALSE)[, 1:9]
4. 实例操作
这里我们就是使用rpart软件包里面的函数 rpart完成教程。首先使用rpart函数建立一个分类树模型,然后绘制回归树,如下:
set.seed(123)
tree.pros <- rpart(lpsa ~ ., data = train_data)
length(tree.pros$frame$var[tree.pros$frame$var == "<leaf>"]) # 共有6个叶子
## [1] 6
绘制Rpart Fit的复杂性参数表,我们选择RSS最小的,集xerror最小是,分裂次数位5,如下:
tree.pros$cptable
## CP nsplit rel error xerror xstd
## 1 0.35852251 0 1.0000000 1.0195606 0.17963802
## 2 0.12295687 1 0.6414775 0.8742500 0.12878275
## 3 0.11639953 2 0.5185206 0.7949473 0.10419946
## 4 0.05350873 3 0.4021211 0.7904898 0.09821670
## 5 0.01032838 4 0.3486124 0.7044000 0.09115510
## 6 0.01000000 5 0.3382840 0.7321671 0.09381916
plotcp(tree.pros)
绘制回归树,如下:
plot(as.party(tree.pros))
也可以使用prp绘制
prp(tree.pros, type = 1, extra = 1, under = TRUE, split.font = 1, varlen = -10, box.col = ifelse(tree.pros$frame$var ==
"<leaf>", "gray", "white"))
可以利用rpart.plot软件包绘制回归树,如下:
rpart.plot(tree.pros)
测试集测试
party.tree.test <- predict(tree.pros, newdata = test_data)
rpart.resid <- party.tree.test - test_data$lpsa #calculate residual
mean(rpart.resid^2)
## [1] 0.6136057
剪枝通过剪枝获得更少的分支,如下:
cp <- min(tree.pros$cptable[5, ])
prune.tree.pros <- prune(tree.pros, cp = cp)
plot(as.party(prune.tree.pros))
测试集预测,并计算RSS
party.pros.test <- predict(prune.tree.pros, newdata = test_data)
rpart.resid <- party.pros.test - test_data$lpsa #calculate residual
mean(rpart.resid^2)
## [1] 0.5267748
交叉验证
set.seed(123)
cv.ct <- rpart(lpsa ~ ., data = train_data, xval = 10)
printcp(cv.ct)
##
## Regression tree:
## rpart(formula = lpsa ~ ., data = train_data, xval = 10)
##
## Variables actually used in tree construction:
## [1] age lcavol lweight
##
## Root node error: 96.281/67 = 1.437
##
## n= 67
##
## CP nsplit rel error xerror xstd
## 1 0.358523 0 1.00000 1.01956 0.179638
## 2 0.122957 1 0.64148 0.87425 0.128783
## 3 0.116400 2 0.51852 0.79495 0.104199
## 4 0.053509 3 0.40212 0.79049 0.098217
## 5 0.010328 4 0.34861 0.70440 0.091155
## 6 0.010000 5 0.33828 0.73217 0.093819
plotcp(cv.ct)
rpart.plot(cv.ct)
party.cv.test <- predict(cv.ct, newdata = test_data)
rpart.resid <- party.cv.test - test_data$lpsa #calculate residual
mean(rpart.resid^2)
## [1] 0.6136057
使用使得cost complexity最小的cp值来修剪树枝,如下:
# 使用使得cost complexity最小的cp值来修剪树枝
set.seed(123)
pruned.ct <- prune(cv.ct, cp = cv.ct$cptable[which.min(cv.ct$cptable[, "xerror"]),
"CP"])
length(pruned.ct$frame$var[pruned.ct$frame$var == "<leaf>"]) # 新模型的叶子数为5
## [1] 5
plotcp(pruned.ct)
剪枝后绘制回归树,如下:
prp(pruned.ct, type = 1, extra = 1, split.font = 1, varlen = -10)
预测结果我们在看下模型的性能,利用混合矩阵的方法,如下:
party.prune.test <- predict(pruned.ct, newdata = test_data)plot(as.party(pruned.ct))
rpart.resid <- party.prune.test - test_data$lpsa #calculate residual
mean(rpart.resid^2)
## [1] 0.5267748
绘制ROC曲线
### ROC
library(ROCR)
pred <- prediction(predictions = party.prune.test, labels = test_data$gleason)
perf <- performance(prediction.obj = pred, measure = "tpr", x.measure = "fpr")
perf
## A performance instance
## 'False positive rate' vs. 'True positive rate' (alpha: 'Cutoff')
## with 6 data points
plot(perf, colorize = TRUE, main = "ROC", lwd = 2, xlab = "True positive rate", ylab = "False positive rate",
box.lty = 7, box.lwd = 2, box.col = "gray")
abline(a = 0, b = 1, lty = 2, col = "gray")
结果解读
从准确性上来看位0.80,终结一下,如果需要做回归树,因变量需要选择连续型变量,并且比较的结果需要看RSS,从上面的结果可以看出未剪枝前RSS为0.614,分支为6,剪枝后RSS为0.527,分支为5.
References:
-
Breiman L., Friedman J. H., Olshen R. A., and Stone, C. J. (1984) Classification and Regression Trees. Wadsworth.
-
Mangasarian OL, Street WN, Wolberg WH. Breast cancer diagnosis and prognosis via linear programming. Operations Research. 1995; 43:570-577.
本文使用 文章同步助手 同步