Machine Learning -- Supervised -- KNN (K-Nearest Neighbors)

2021-11-08  小贝学生信

KNN is a very simple supervised machine learning algorithm; it can be used for both classification and regression tasks.

1. A simple understanding of KNN

1.1 Algorithm steps
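
As a rough illustration of the usual KNN procedure (a sketch, not code from the original post): compute the distance from the query point to every training point, keep the k closest, and take a majority vote of their labels (for regression, the mean of their values would be used instead). The helper name `knn_predict_one`, the toy data, and the choice of Euclidean distance are assumptions made for this example.

```r
# Minimal KNN classifier for a single query point (illustrative sketch).
# knn_predict_one is a hypothetical helper, not part of any package.
knn_predict_one <- function(train_x, train_y, query, k = 3) {
  # 1. Euclidean distance from the query to every training row
  diffs <- sweep(train_x, 2, query)   # subtract the query from each row
  dists <- sqrt(rowSums(diffs^2))
  # 2. Indices of the k closest training points
  nearest <- order(dists)[seq_len(k)]
  # 3. Majority vote among their labels
  votes <- table(train_y[nearest])
  names(votes)[which.max(votes)]
}

# Toy data: two well-separated groups
train_x <- matrix(c(1, 1,
                    1, 2,
                    2, 1,
                    6, 6,
                    7, 6,
                    6, 7), ncol = 2, byrow = TRUE)
train_y <- c("No", "No", "No", "Yes", "Yes", "Yes")

knn_predict_one(train_x, train_y, query = c(1.5, 1.5), k = 3)
knn_predict_one(train_x, train_y, query = c(6.5, 6.5), k = 3)
```

The first query sits inside the "No" group and the second inside the "Yes" group, so each is labelled by its own neighbours.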

1.2 Characteristics

2. R code example

Example data -- predicting whether an employee leaves
library(dplyr)      # provides %>% and mutate_if()
library(modeldata)
data(attrition)
# initial dimension
dim(attrition)
# [1] 1470   31
# Convert ordered factors to ordinary (unordered) factors
attrit <- attrition %>% mutate_if(is.ordered, factor, ordered = FALSE)
churn_split <- rsample::initial_split(attrit, prop = .7, 
                                      strata = "Attrition")
churn_train <- rsample::training(churn_split)
churn_test <- rsample::testing(churn_split)

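step1: Data preprocessing

Two points need attention: (1) standardize the numeric variables, because KNN is distance-based and features on larger scales would otherwise dominate the distance; (2) convert the categorical variables to numeric form.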

library(recipes)
blueprint <- recipe(Attrition ~ ., data = churn_train) %>%
  step_nzv(all_nominal()) %>%  # drop near-zero-variance variables
  step_integer(contains("Satisfaction")) %>% # encode factor levels as integers
  step_integer(WorkLifeBalance) %>% # encode factor levels as integers
  step_integer(JobInvolvement) %>% # encode factor levels as integers
  step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>% # one-hot encode the remaining categorical predictors
  step_center(all_numeric(), -all_outcomes()) %>% # center to mean 0
  step_scale(all_numeric(), -all_outcomes()) # scale to standard deviation 1
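
To see why the centering and scaling steps matter for a distance-based method, here is a small base-R sketch. The three "employees" and their values are made up for illustration; `scale()` is used here as a stand-in for what `step_center()` and `step_scale()` do inside the recipe.

```r
# Three hypothetical employees: income in currency units, satisfaction on a 1-4 scale
x <- rbind(a = c(income = 5000, satisfaction = 1),
           b = c(income = 5100, satisfaction = 4),
           c = c(income = 6000, satisfaction = 1))

# Raw distances: income dominates, so a looks far closer to b than to c
d_raw <- as.matrix(dist(x))

# After centering and scaling each column (mean 0, sd 1),
# the satisfaction gap between a and b counts as well
d_std <- as.matrix(dist(scale(x)))

d_raw["a", c("b", "c")]
d_std["a", c("b", "c")]
```

On the raw data b is roughly ten times closer to a than c is; after standardization the two distances are of similar size, because the large difference in satisfaction is no longer drowned out by the income scale.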

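step2: Finding the best k

(1) Grid search over candidate k values

k is usually chosen to be odd: in a binary classification task an odd k cannot produce a tied vote. (The evenly spaced grid below also contains even values.)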

# Create a hyperparameter grid search
str(floor(seq(1, nrow(churn_train)/3, length.out = 20)))
# num [1:20] 1 18 36 54 72 90 108 126 144 162 ...
hyper_grid <- expand.grid(
  k = floor(seq(1, nrow(churn_train)/3, length.out = 20)))
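
If one wants the grid to follow the odd-k rule of thumb, a possible variant is to bump every even candidate up by one. This is only a sketch: `n_train` is an assumed training-set size and `hyper_grid_odd` is a hypothetical name; in the workflow above `nrow(churn_train)` would be used instead.

```r
n_train <- 1028  # assumed number of training rows, for illustration only
k_grid <- floor(seq(1, n_train / 3, length.out = 20))

# Bump every even value up by one so that all candidates are odd
k_odd <- ifelse(k_grid %% 2 == 0, k_grid + 1, k_grid)

hyper_grid_odd <- expand.grid(k = k_odd)
head(hyper_grid_odd$k)
```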

(2) Cross-validation setup

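Compared with the plain cv method used earlier, repeatedcv can be understood as running k-fold cross-validation several (n) times and then averaging over the n runs to evaluate model performance.
The settings below run 10-fold cross-validation 5 times: in each run the data is split into 10 folds, and each fold is held out once for validation while the model is fit on the other 9 (10 fits per run, 50 resamples in total).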

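# Create a resampling method
library(caret)   # trainControl(), train(), confusionMatrix()
cv <- trainControl(
  method = "repeatedcv",
  number = 10,   # folds per repeat
  repeats = 5,   # number of repeats
  classProbs = TRUE,                  # class probabilities are needed for ROC
  summaryFunction = twoClassSummary)  # reports ROC, Sens, Spec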
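
The resampling scheme behind `number = 10, repeats = 5` can be sketched in base R as follows. This is a simplified illustration with a made-up row count; caret builds its own folds internally and may do so differently (for example by balancing the outcome classes across folds).

```r
set.seed(1)      # arbitrary seed so the sketch is reproducible
n <- 100         # hypothetical number of training rows
number <- 10     # folds per repeat
repeats <- 5     # how many times the 10-fold split is redone

# One list element per resample: the row indices held out for validation
holdouts <- list()
for (r in seq_len(repeats)) {
  # Randomly assign each row to one of `number` folds
  fold_id <- sample(rep(seq_len(number), length.out = n))
  for (f in seq_len(number)) {
    holdouts[[sprintf("Fold%02d.Rep%d", f, r)]] <- which(fold_id == f)
  }
}

length(holdouts)            # 50 resamples = number * repeats
table(unlist(holdouts))[1]  # each row is held out once per repeat
```

Each of the 50 resamples yields one performance estimate, and the reported metric for a given k is the average over all of them.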

(3) Determining the best k
# Fit knn model and perform grid search
knn_grid <- train(
  blueprint,
  data = churn_train,
  method = "knn",
  trControl = cv,
  tuneGrid = hyper_grid,
  metric = "ROC")

knn_grid$bestTune
# 198
knn_grid$results[knn_grid$results$k==198,]
#     k       ROC Sens Spec      ROCSD SensSD SpecSD
# 12 198 0.8041737    1    0 0.05500748      0      0
ggplot(knn_grid)

step3: Predicting on the test set

pred <- predict(knn_grid, newdata = churn_test)
confusionMatrix(pred, churn_test$Attrition)
# Confusion Matrix and Statistics
# 
#           Reference
# Prediction  No Yes
#        No  370  72
#        Yes   0   0
#
#                Accuracy : 0.8371
#                  95% CI : (0.7993, 0.8703)
#     No Information Rate : 0.8371
#     P-Value [Acc > NIR] : 0.5314
#
#                   Kappa : 0
#
#  Mcnemar's Test P-Value : <2e-16
#                                           
#             Sensitivity : 1.0000          
#             Specificity : 0.0000          
#          Pos Pred Value : 0.8371          
#          Neg Pred Value :    NaN          
#              Prevalence : 0.8371          
#          Detection Rate : 0.8371          
#    Detection Prevalence : 1.0000          
#       Balanced Accuracy : 0.5000          
#                                           
#        'Positive' Class : No 
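
The headline statistics in this output can be reproduced by hand from the four counts, which makes it easier to see what they say. The sketch below uses only base R; the variable names are chosen for this example.

```r
# Counts copied from the confusion matrix above (rows = prediction, columns = reference)
cm <- matrix(c(370, 72,
                 0,  0),
             nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("No", "Yes"),
                             Reference  = c("No", "Yes")))

accuracy    <- sum(diag(cm)) / sum(cm)               # (370 + 0) / 442
sensitivity <- cm["No", "No"] / sum(cm[, "No"])      # positive class is "No"
specificity <- cm["Yes", "Yes"] / sum(cm[, "Yes"])   # share of true "Yes" recovered
balanced    <- (sensitivity + specificity) / 2
nir         <- max(colSums(cm)) / sum(cm)            # no-information rate

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, balanced = balanced, nir = nir), 4)
```

Accuracy equals the no-information rate and specificity is 0: with k = 198 the model labels every test employee as "No", so the 0.8371 accuracy only reflects the class imbalance.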