临床数据的的机器模型构建

MachineLearning 6. 肿瘤诊断机器学习之分类树(

2022-05-08  本文已影响0人  桓峰基因


前   言

树方法精髓就是划分特征,从第一次分裂开始就要考虑如何最大程度改善RSS,然后持续进行“树权”分裂,直到树结束。后面的划分并不作用于全数据集,而仅作用于上次划分时落到这个分支之下的那部分数据。这个自顶向下的过程被称为“递归划分”。这个过程是贪婪的,贪婪的含义是指算法在每次分裂中都追求最大程度减少RSS,而不管以后的划分中表现如何。这样做可能会生成一个带有无效分支的树,尽管偏差很小,但是方差很大。为了避免这个问题,生成完整的树之后,你要对树进行剪枝,得到最优的解。这种方法的优点是可以处理高度非线性关系,但它还存在一些潜在的问题:一个观测被赋予所属终端节点的平均值,这会损害整体预测效果(高偏差)。相反,如果你一直对数据进行划分,树的层次越来越深,这样可以达到低偏差的效果,但是高方差又成了问题。和其他方法一样,你也可以用交叉验证来选择合适的深度。

基本原理

回归树(regression tree),顾名思义,就是用树模型做回归问题,每一片叶子都输出一个预测值。预测值一般是该片叶子所含训练集元素输出的均值,即 cm=ave(yi|xi∈leafm)。

CART 在分类问题和回归问题中的相同和差异:

相同:在分类问题和回归问题中,CART 都是一棵二叉树,除叶子节点外的所有节点都有且仅有两个子节点;所有落在同一片叶子中的输入都有同样的输出。

差异在分类问题中,CART 使用基尼指数(Gini index)作为选择特征(feature)和划分(split)的依据;在回归问题中,CART 使用 mse(mean square error)或者 mae(mean absolute error)作为选择 feature 和 split 的 criteria。在分类问题中,CART 的每一片叶子都代表的是一个 class;在回归问题中,CART 的每一片叶子表示的是一个预测值,取值是连续的。

实例解析

本文叙述了线性规划在医学上的两个应用。具体来说,利用基于线性规划的机器学习技术,提高乳腺癌诊断和预后的准确性和客观性。首次应用于乳腺癌诊断利用单个细胞的特征,从微创细针抽吸获得,以区分乳腺肿块的良恶性。这使得不需要手术活检就能做出准确的诊断。威康森大学医院目前运行的诊断系统对569例患者的样本进行了培训,对131例后续患者的诊断具有100%的时间正确性。第二个应用,最近已经投入临床实践,是一种构建一个表面的方法,可以预测肿瘤切除后乳腺癌何时可能复发。这为医生和患者提供了更好的信息来计划治疗,并可能消除对预后外科手术的需要。预测方法的新特点是能够处理癌症没有复发的病例(审查数据)以及癌症在特定时间复发的病例。该预后系统的预期误差为13.9 ~ 18.3个月,优于其他可用技术的预后准确性。

1. 软件安装

这里我们主要使用rpart和partykit两个软件包,其他都为数据处理过程中需要使用软件包,如下:

if (!require(rpart)) install.packages("rpart")
if (!require(partykit)) install.packages("partykit")
if (!require(caret)) install.packages("caret")
if (!require(rpart.plot)) install.packages("rpart.plot")

library(rpart) #classification and regression trees
library(partykit) #treeplots
library(caret) #tune hyper-parameters
library(rpart.plot)

2. 数据读取

数据来源《机器学习与R语言》书中,具体来自UCI机器学习仓库。地址:http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/ 下载wbdc.data和wbdc.names这两个数据集,数据经过整理,成为面板数据。查看数据结构,其中第一列为id列,无特征意义,需要删除。第二列diagnosis为响应变量,字符型,一般在R语言中分类任务都要求响应变量为因子类型,因此需要做数据类型转换。剩余的为预测变量,数值类型。查看数据维度,568个样本,32个特征(包括响应特征)。

BreastCancer <- read.csv("wisc_bc_data.csv", stringsAsFactors = FALSE)
dim(BreastCancer)
## [1] 568 32
table(BreastCancer$diagnosis)
##
## B M
## 357 211
sum(is.na(data))
## [1] 0

数据分布

比较恶性和良性之间的差距,如下:

library(reshape2)
library(ggplot2)
bc <- BreastCancer[, -1]
bc.melt <- melt(bc, id.var = "diagnosis")
head(bc.melt)
## diagnosis variable value
## 1 M radius_mean 20.57
## 2 M radius_mean 19.69
## 3 M radius_mean 11.42
## 4 M radius_mean 20.29
## 5 M radius_mean 12.45
## 6 M radius_mean 18.25
ggplot(data = bc.melt, aes(x = diagnosis, y = log(value + 1), fill = diagnosis)) +
geom_boxplot() + theme_bw() + facet_wrap(~variable, ncol = 8)

3. 数据处理

我们将整个数据进行分割,分为训练集和测试集,并保证其正负样本的比例,如下:

library(tidyverse)
data <- select(BreastCancer, -1) %>%
mutate_at("diagnosis", as.factor)
corrplot::corrplot(cor(data[, -1]))

数据分割

当我们只有一套数据的时候,可以将数据分为训练集和测试集,具体怎么分割可以看公众号的专题:Topic 5. 样本量确定及分割

# 数据分割 install.packages('sampling')
library(sampling)
set.seed(123)
# 每层抽取70%的数据
train_id <- strata(data, "diagnosis", size = rev(round(table(data$diagnosis) * 0.7)))$ID_unit
# 训练数据
train_data <- data[train_id, ]
# 测试数据
test_data <- data[-train_id, ]

# 查看训练、测试数据中正负样本比例
prop.table(table(train_data$diagnosis))
##
## B M
## 0.6281407 0.3718593

prop.table(table(test_data$diagnosis))
##
## B M
## 0.6294118 0.3705882

4. 实例操作

这里我们就是使用rpart软件包里面的函数 rpart完成教程。首先使用rpart函数建立一个分类树模型,然后绘制回归树,如下:

set.seed(123)
tree.pros <- rpart(diagnosis ~ ., data = train_data, method = "class")
printcp(tree.pros)
##
## Classification tree:
## rpart(formula = diagnosis ~ ., data = train_data, method = "class")
##
## Variables actually used in tree construction:
## [1] concave_points_mean texture_worst
##
## Root node error: 148/398 = 0.37186
##
## n= 398
##
## CP nsplit rel error xerror xstd
## 1 0.831081 0 1.00000 1.00000 0.065147
## 2 0.040541 1 0.16892 0.20270 0.035586
## 3 0.010000 2 0.12838 0.18243 0.033897
rpart.plot(tree.pros)

绘制Rpart Fit的复杂性参数表,如下:

tree.pros$cptable
## CP nsplit rel error xerror xstd
## 1 0.83108108 0 1.0000000 1.0000000 0.06514748
## 2 0.04054054 1 0.1689189 0.2027027 0.03558617
## 3 0.01000000 2 0.1283784 0.1824324 0.03389734
plotcp(tree.pros)

将各种对象强制转换为类方对象的函数,如下:

cp <- min(tree.pros$cptable[3, ])
prune.tree.pros <- prune(tree.pros, cp = cp)
plot(as.party(tree.pros))

交叉验证

ctrl <- trainControl(method = "cv", number = 10)

# CV bagged model
bagged_cv <- train(diagnosis ~ ., data = train_data, method = "treebag", trControl = ctrl,
importance = TRUE)
bagged_cv
## Bagged CART
##
## 398 samples
## 30 predictor
## 2 classes: 'B', 'M'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 359, 358, 358, 359, 358, 358, ...
## Resampling results:
##
## Accuracy Kappa
## 0.9448718 0.8830605
# plot most important variables
plot(varImp(bagged_cv), 20)

预测结果

我们在看下模型的性能,利用混合矩阵的方法,如下:

party.pros.test <- predict(prune.tree.pros, newdata = test_data, type = "class")
# rpart.resid <- party.pros.test - test_data$diagnosis #calculate residual
table(party.pros.test, test_data$diagnosis)
##
## party.pros.test B M
## B 100 12
## M 7 51
confusionMatrix(party.pros.test, test_data$diagnosis, positive = "B")
## Confusion Matrix and Statistics
##
## Reference
## Prediction B M
## B 100 12
## M 7 51
##
## Accuracy : 0.8882
## 95% CI : (0.831, 0.9314)
## No Information Rate : 0.6294
## P-Value [Acc > NIR] : 2.453e-14
##
## Kappa : 0.7564
##
## Mcnemar's Test P-Value : 0.3588
##
## Sensitivity : 0.9346
## Specificity : 0.8095
## Pos Pred Value : 0.8929
## Neg Pred Value : 0.8793
## Prevalence : 0.6294
## Detection Rate : 0.5882
## Detection Prevalence : 0.6588
## Balanced Accuracy : 0.8721
##
## 'Positive' Class : B
##

绘制ROC曲线

### ROC
library(ROSE)
roc.curve(party.pros.test, test_data$diagnosis, main = "ROC curve of Party", col = 2,
lwd = 2, lty = 2)
## Area under the curve (AUC): 0.886

结果解读

从准确性上来看仅为为 0.8882, 与之前的KNN(0.947), SVM(0.947)相比准确性低很多,所以在做乳腺癌这套数据时,选择回归树效果不加,目前还是 KNN或者SVM这两个机器学习方法算是最优的选择,后续我们还将使用这套数据。从整体上来说,回归树总体流程类似于分类树,分枝时穷举每一个特征的每一个阈值,来寻找最优切分特征j和最优切分点s,衡量的方法是平方误差最小化。分枝直到达到预设的终止条件(如叶子个数上限)就停止。当然,处理具体问题时,单一的回归树肯定是不够用的。可以利用集成学习中的boosting框架,对回归树进行改良升级,得到的新模型就是提升树(Boosting Decision Tree),在进一步,可以得到梯度提升树(Gradient Boosting Decision Tree,GBDT),再进一步可以升级到XGBoost。

References:

  1. Breiman L., Friedman J. H., Olshen R. A., and Stone, C. J. (1984) Classification and Regression Trees. Wadsworth.

  2. Mangasarian OL, Street WN, Wolberg WH. Breast cancer diagnosis and prognosis via linear programming. Operations Research. 1995; 43:570-577.

本文使用 文章同步助手 同步

上一篇下一篇

猜你喜欢

热点阅读