
Machine Learning -- Supervised -- Support Vector Machine (SVM)

2021-11-10  小贝学生信

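A support vector machine (SVM) is a supervised machine learning algorithm that can be used for either classification or regression. These notes focus on the classification task.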

1. Basic intuition

1.1 Linearly separable -- hard margin

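Hard margin classifier (HMC): this is the simplest case, where the samples can be cleanly separated by a linear boundary. As the figure shows, the boundary is constructed in 3 steps.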

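The defining property of the HMC is that no sample may lie inside the margin: the separation must be clean. In other words, no sample may be closer to the decision boundary than the margin, let alone fall on the wrong side of the decision boundary (i.e. be misclassified).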
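To make the constraint concrete, here is a minimal base R sketch; the helper `hard_margin_check` and the toy data are hypothetical, not from the original post. With labels coded as -1/+1 and a candidate boundary w·x + b = 0, the hard margin requires y_i(w·x_i + b) >= 1 for every sample. The two margin lines then sit at distance 1/||w|| on either side of the boundary, so the full margin width is 2/||w||.

```r
# Hypothetical helper: checks the hard-margin constraints y_i * (w . x_i + b) >= 1
# for a candidate hyperplane (w, b); labels y must be coded as -1 / +1
hard_margin_check <- function(X, y, w, b) {
  f <- as.vector(X %*% w + b)          # decision values w . x_i + b
  list(
    decision     = f,
    feasible     = all(y * f >= 1),    # TRUE only if no sample is inside the margin
    margin_width = 2 / sqrt(sum(w^2))  # distance between the two margin lines
  )
}

# Toy, linearly separable data (two samples per class)
X <- matrix(c(-2,  0,
              -3,  1,
               2,  0,
               3, -1), ncol = 2, byrow = TRUE)
y <- c(-1, -1, 1, 1)

ok <- hard_margin_check(X, y, w = c(0.5, 0), b = 0)
ok$feasible       # TRUE: (-2, 0) and (2, 0) sit exactly on the margin lines
ok$margin_width   # 4

# Shifting the boundary pushes (-2, 0) inside the margin, violating the hard margin
bad <- hard_margin_check(X, y, w = c(0.5, 0), b = 0.5)
bad$feasible      # FALSE
```

The samples that lie exactly on the margin lines, here (-2, 0) and (2, 0), are the support vectors.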

1.2 Not perfectly linearly separable -- soft margin
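This section has no text in the post. As an assumption, the sketch below uses the standard textbook soft-margin formulation: each sample gets a slack ξ_i = max(0, 1 - y_i(w·x_i + b)), and training minimizes ½||w||² + C·Σξ_i. The cost `C` plays the same role as the `C` that caret tunes for `svmRadial` in the code section: a larger C penalizes margin violations more heavily. The helper name and toy data are hypothetical.

```r
# Hypothetical helper: slack variables and soft-margin objective for a candidate (w, b)
soft_margin <- function(X, y, w, b, C) {
  f  <- as.vector(X %*% w + b)
  # slack: 0 = on or beyond its margin line; between 0 and 1 = inside the margin
  # but correctly classified; > 1 = on the wrong side of the decision boundary
  xi <- pmax(0, 1 - y * f)
  list(slack = xi, objective = 0.5 * sum(w^2) + C * sum(xi))
}

X <- matrix(c(-2,   0,
              -0.5, 0,
               2,   0,
              -1,   0), ncol = 2, byrow = TRUE)
y <- c(-1, -1, 1, 1)

res <- soft_margin(X, y, w = c(0.5, 0), b = 0, C = 1)
res$slack      # sample 2 is inside the margin, sample 4 is misclassified
res$objective
```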

1.3 Not linearly separable -- the kernel trick
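The kernel trick replaces inner products between samples with a kernel function. This lets the SVM fit a linear boundary in an implicit higher-dimensional feature space without ever computing that space. The `svmRadial` method used in the code section relies on kernlab's Gaussian RBF kernel, k(x, x') = exp(-σ·||x - x'||²). Its σ is the `sigma` column in the caret results. Below is a minimal base R sketch of the kernel matrix; the function name is hypothetical.

```r
# Hypothetical helper: Gaussian RBF kernel matrix, k(x, z) = exp(-sigma * ||x - z||^2)
rbf_kernel <- function(X, Z = X, sigma = 1) {
  # squared Euclidean distances between every row of X and every row of Z
  d2 <- outer(rowSums(X^2), rowSums(Z^2), "+") - 2 * X %*% t(Z)
  exp(-sigma * pmax(d2, 0))  # pmax guards against tiny negative round-off
}

X <- matrix(c(0, 0,
              1, 0,
              0, 2), ncol = 2, byrow = TRUE)
K <- rbf_kernel(X, sigma = 0.5)
round(K, 3)   # 1 on the diagonal, smaller values for more distant pairs
```

A smaller σ makes the kernel decay more slowly with distance, which tends to give a smoother decision boundary.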

2. Hands-on code

(1) Example data: predicting whether an employee will leave (attrition)

library(modeldata)
data(attrition)
# initial dimension
dim(attrition)
## [1] 1470   31
library(dplyr)
# convert ordered factors to ordinary (unordered) factors
df <- attrition %>%
  mutate_if(is.ordered, factor, ordered = FALSE)
# Create training (70%) and test (30%) sets
set.seed(123) # for reproducibility
library(rsample)
churn_split <- initial_split(df, prop = 0.7, strata = "Attrition") # stratify on the outcome
churn_train <- training(churn_split)
churn_test <- testing(churn_split)
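`strata = "Attrition"` asks rsample to sample within each outcome class, so the training and test sets keep roughly the same share of leavers. Here is a simplified base R sketch of that idea for a factor outcome; the function name is hypothetical, and rsample's own implementation may differ in its details.

```r
# Hypothetical helper: draw `prop` of the rows separately within each class of y
stratified_split <- function(y, prop = 0.7, seed = 123) {
  set.seed(seed)
  by_class <- split(seq_along(y), y)          # row indices grouped by class
  train_idx <- unlist(lapply(by_class, function(i) {
    i[sample.int(length(i), size = round(prop * length(i)))]
  }))
  sort(unname(train_idx))
}

# Toy outcome with an 80/20 class imbalance
y <- factor(rep(c("No", "Yes"), times = c(80, 20)))
train_idx <- stratified_split(y, prop = 0.7)
prop.table(table(y[train_idx]))   # same 80/20 ratio as the full data
```

On the real split, `prop.table(table(churn_train$Attrition))` and `prop.table(table(churn_test$Attrition))` perform the same check.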

(2) Building the model with caret

library(caret)
set.seed(1111) # for reproducibility
# Control params for SVM
ctrl <- trainControl(
  method = "cv",
  number = 10,
  classProbs = TRUE,  # return class probabilities instead of only the predicted class labels
  summaryFunction = twoClassSummary # also needed for AUC/ROC
)
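`method = "cv", number = 10` means 10-fold cross-validation. The training data are split into 10 folds, and each fold is held out once while the model is fit on the other 9. The metric is then averaged over the 10 held-out folds. Below is a bare-bones sketch of the fold assignment. The helper is hypothetical; caret builds its folds with `createFolds()`, which additionally balances the outcome classes across folds.

```r
# Hypothetical helper: randomly assign each of n rows to one of k folds of (nearly) equal size
make_folds <- function(n, k = 10, seed = 1111) {
  set.seed(seed)
  sample(rep(seq_len(k), length.out = n))
}

folds <- make_folds(100, k = 10)
table(folds)   # 10 rows per fold

# Each resampling iteration holds out one fold and trains on the rest
for (i in seq_len(10)) {
  held_out <- which(folds == i)
  in_train <- which(folds != i)
  stopifnot(length(intersect(held_out, in_train)) == 0)  # the two sets never overlap
  # fit on rows `in_train`, then compute ROC / Sens / Spec on rows `held_out`
}
```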

churn_svm <- train(
  Attrition ~ .,
  data = churn_train,
  method = "svmRadial",
  preProcess = c("center", "scale"),
  trControl = ctrl,
  metric = "ROC", # area under ROC curve (AUC)
  tuneLength = 10) # try 10 values of C, i.e. from 2^-2 to 2^7

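# As shown below, the model with C = 4 has the highest cross-validated ROC (AUC)
# Note: twoClassSummary treats the first factor level, "No", as the event, so Sens is
# the rate at which actual "No" cases are predicted "No" and Spec the rate for "Yes"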
churn_svm$results %>% arrange(desc(ROC)) %>% head(1)
#     sigma C       ROC      Sens      Spec      ROCSD     SensSD     SpecSD
# 1 0.009522278 4 0.8234039 0.9791767 0.2738971 0.07462533 0.02019714 0.08679811
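The 2^-2 to 2^7 range follows from caret's default grid for `svmRadial`. Based on caret's model code as we understand it (an assumption worth checking against your installed version), that grid keeps σ fixed at a value estimated with `kernlab::sigest()` and sets C = 2^((1:tuneLength) - 3). An equivalent explicit grid can be passed through `train()`'s `tuneGrid` argument. This sketch reuses the σ reported in the output above.

```r
# Explicit equivalent of tuneLength = 10 for svmRadial (assumes sigma is held fixed)
svm_grid <- expand.grid(
  sigma = 0.009522278,      # value reported in churn_svm$results above
  C     = 2^((1:10) - 3)    # 0.25, 0.5, 1, 2, 4, 8, 16, 32, 64, 128
)
nrow(svm_grid)   # 10 candidate models

# Usage with the objects defined earlier (not run here):
# churn_svm2 <- train(Attrition ~ ., data = churn_train, method = "svmRadial",
#                     preProcess = c("center", "scale"), trControl = ctrl,
#                     metric = "ROC", tuneGrid = svm_grid)
```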

# Plot results
ggplot(churn_svm) 
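`metric = "ROC"` makes caret pick the C with the largest cross-validated area under the ROC curve. The AUC equals the probability that a randomly chosen positive sample gets a higher score than a randomly chosen negative one. It can be computed from ranks (the Mann-Whitney form). Below is a minimal sketch with a hypothetical helper name; caret itself obtains the AUC through the pROC package.

```r
# Hypothetical helper: ROC AUC from ranks; tied scores receive average ranks
auc_rank <- function(score, y, positive) {
  is_pos <- y == positive
  n_pos  <- sum(is_pos)
  n_neg  <- sum(!is_pos)
  r      <- rank(score)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

score <- c(0.10, 0.40, 0.35, 0.80)   # e.g. predicted P(Yes)
truth <- c("No", "No", "Yes", "Yes")
auc_rank(score, truth, positive = "Yes")   # 3 of the 4 (Yes, No) pairs are ordered correctly
```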

(3) Evaluating on the test set

pred <- predict(churn_svm, churn_test)  # predicted class labels (default type = "raw")
table(pred)
# pred
# No Yes 
# 415  27
table(churn_test$Attrition)
# No Yes 
# 370  72
caret::confusionMatrix(pred, churn_test$Attrition, positive="Yes")
# Accuracy : 0.871           
# 95% CI : (0.8362, 0.9008)
# No Information Rate : 0.8371          
# P-Value [Acc > NIR] : 0.02819         
# 
# Kappa : 0.3681          
# 
# Mcnemar's Test P-Value : 5.611e-09       
#                                           
#             Sensitivity : 0.29167         
#             Specificity : 0.98378         
#          Pos Pred Value : 0.77778         
#          Neg Pred Value : 0.87711         
#              Prevalence : 0.16290         
#          Detection Rate : 0.04751         
#    Detection Prevalence : 0.06109         
#       Balanced Accuracy : 0.63773         
#                                           
#        'Positive' Class : Yes
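The headline numbers can be reproduced by hand. Take the counts from `table(pred)` and `table(churn_test$Attrition)` together with the reported sensitivity and specificity. With `Yes` as the positive class, the four cells work out to 21 true positives, 51 false negatives, 6 false positives and 364 true negatives. The sketch below recomputes the statistics from those counts. It makes clear why the 87.1% accuracy is only modestly above the 83.7% no-information rate: most actual leavers (51 of 72) are predicted as `No`.

```r
# Cell counts back-calculated from the caret output above (positive class = "Yes")
tp <- 21; fn <- 51    # actual Yes: 72
fp <- 6;  tn <- 364   # actual No: 370
n  <- tp + fn + fp + tn

accuracy    <- (tp + tn) / n
nir         <- (fp + tn) / n          # always predicting the majority class "No"
sensitivity <- tp / (tp + fn)         # recall for "Yes"
specificity <- tn / (tn + fp)
ppv         <- tp / (tp + fp)         # precision
npv         <- tn / (tn + fn)
balanced    <- (sensitivity + specificity) / 2

# Cohen's kappa: observed agreement compared with agreement expected by chance
p_e        <- ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / n^2
kappa_stat <- (accuracy - p_e) / (1 - p_e)

round(c(accuracy = accuracy, nir = nir, sensitivity = sensitivity,
        specificity = specificity, ppv = ppv, npv = npv,
        balanced = balanced, kappa = kappa_stat), 4)
```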

(4) Measuring feature importance

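library(vip)
# Prediction wrapper: return the predicted probability of the "Yes" class
prob_yes <- function(object, newdata) {
  predict(object, newdata = newdata, type = "prob")[, "Yes"]
}
# Variable importance plot: permutation importance measured as the drop in AUC,
# averaged over nsim = 5 random shuffles of each feature
set.seed(2827) # for reproducibility
vip(churn_svm, method = "permute", nsim = 5, train = churn_train,
    target = "Attrition", metric = "auc", reference_class = "Yes",
    pred_wrapper = prob_yes)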
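`method = "permute"` measures importance by shuffling one feature at a time and recording how much the AUC drops, averaged over `nsim` shuffles. A feature the model ignores gives a drop of about zero. Below is a simplified base R sketch of that procedure. The helper names and the toy model are hypothetical, and vip's own implementation has more options.

```r
# Hypothetical helper: ROC AUC from ranks (positive class = "Yes")
auc <- function(y, score) {
  pos <- y == "Yes"
  (sum(rank(score)[pos]) - sum(pos) * (sum(pos) + 1) / 2) / (sum(pos) * sum(!pos))
}

# Hypothetical helper: mean drop in AUC when each feature is shuffled
perm_importance <- function(pred_fun, X, y, nsim = 5, seed = 2827) {
  set.seed(seed)
  baseline <- auc(y, pred_fun(X))
  sapply(names(X), function(v) {
    drops <- replicate(nsim, {
      Xp <- X
      Xp[[v]] <- sample(Xp[[v]])   # break the link between feature v and the outcome
      baseline - auc(y, pred_fun(Xp))
    })
    mean(drops)
  })
}

# Toy data: the outcome depends only on x1, and the "model" scores with x1 alone
x1 <- (1:200 - 100.5) / 100
X  <- data.frame(x1 = x1, x2 = rep(c(1, 2), 100))
y  <- ifelse(x1 > 0, "Yes", "No")
toy_model <- function(newdata) newdata$x1

perm_importance(toy_model, X, y)   # large for x1, exactly 0 for x2
```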
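library(pdp)
library(gridExtra) # provides grid.arrange()
# Partial dependence of P(Attrition = "Yes") on two features
features <- c("OverTime", "JobRole")
pdps <- lapply(features, function(x) {
  partial(churn_svm, pred.var = x, which.class = 2,  # class 2 = "Yes"
          prob = TRUE, plot = TRUE, plot.engine = "ggplot2") +
    coord_flip()
})
grid.arrange(grobs = pdps, nrow = 1)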
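A partial dependence plot shows the model's average prediction as one feature is varied while the other features keep their observed values. With `prob = TRUE` and `which.class = 2`, the y-axis is the average predicted probability of the second class, `Yes`. Below is a minimal base R sketch of the computation behind `partial()`; the helper name and toy model are hypothetical.

```r
# Hypothetical helper: partial dependence of a prediction function on one feature
partial_dependence <- function(pred_fun, X, var, grid = sort(unique(X[[var]]))) {
  yhat <- sapply(grid, function(value) {
    X_mod <- X
    X_mod[[var]] <- value      # set the feature to `value` for every row
    mean(pred_fun(X_mod))      # average the predictions over all rows
  })
  data.frame(value = grid, yhat = yhat)
}

# Toy additive model: the partial dependence on x1 is value + mean(x2)
X <- data.frame(x1 = c(1, 2, 3), x2 = c(0, 1, 2))
toy_model <- function(newdata) newdata$x1 + newdata$x2
partial_dependence(toy_model, X, "x1")
```

For a categorical feature such as `OverTime`, the grid is simply the set of its levels, and each level yields one averaged prediction.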
