tidymodels学习实录-3（Evaluate your model with resampling）

2022-05-13  本文已影响0人  灵活胖子的进步之路
library(tidymodels) # for the rsample package, along with the rest of tidymodels
library(modeldata)  # for the cells data

# Load the cell image data and inspect the balance of the outcome classes:
# the row count per class and each class's share of all rows.
data(cells, package = "modeldata")
mutate(count(cells, class), prop = n / sum(n))

结果变量 class 的分布情况：为二分类变量
# Split the data 75/25 into training and test sets. The split is stratified
# on the outcome `class` so both sets keep similar class proportions.
# The `case` column is dropped before splitting.
set.seed(123)
cell_split <- initial_split(
  select(cells, -case),
  prop = 0.75,
  strata = class
)
cell_train <- training(cell_split)
cell_test <- testing(cell_split)

# Class proportions in the training set
mutate(count(cell_train, class), prop = n / sum(n))


# Class proportions in the test set
mutate(count(cell_test, class), prop = n / sum(n))
训练集和测试集中 class 的比例基本一致（分层抽样的结果）

以下用随机森林分类模型构建模型

# Model specification: a random forest with 1000 trees, using the ranger
# engine in classification mode. Only the model type, engine and mode are
# declared here; no loss function is configured.
rf_mod <- 
  rand_forest(trees = 1000) %>% 
  set_engine("ranger") %>% 
  set_mode("classification")

# Fit the specification on the training set, predicting `class` from all
# other columns. The seed is set first so the random forest fit is
# reproducible.
set.seed(234)
rf_fit <- 
  rf_mod %>% 
  fit(class ~ ., data = cell_train)
rf_fit
模型拟合结果概要
# Score the training set. Put the hard class prediction (.pred_class), the
# class probabilities (e.g. .pred_PS) and the observed `class` side by side.
rf_training_pred <- bind_cols(
  predict(rf_fit, cell_train),
  predict(rf_fit, cell_train, type = "prob"),
  select(cell_train, class)
)
rf_training_pred
训练集上预测结果情况
# Training-set performance. ROC AUC comes from the PS-class probability, and
# accuracy comes from the hard class predictions. These are computed on the
# training-set predictions.
roc_auc(rf_training_pred, truth = class, .pred_PS)

accuracy(rf_training_pred, truth = class, .pred_class)
训练集AUC及准确率情况
# Score the held-out test set the same way as the training set, then
# compute its ROC AUC and accuracy.
rf_testing_pred <- bind_cols(
  predict(rf_fit, cell_test),
  predict(rf_fit, cell_test, type = "prob"),
  select(cell_test, class)
)

roc_auc(rf_testing_pred, truth = class, .pred_PS)

accuracy(rf_testing_pred, truth = class, .pred_class)
测试集AUC及准确率情况
# 10-fold cross-validation on the training set. The training data are split
# into 10 folds; each fold in turn is held out for assessment while the
# model is fitted on the remaining 9.
set.seed(345)
folds <- vfold_cv(cell_train, v = 10)

# Bundle the model specification and the formula into a workflow so the
# same model can be refitted on every fold.
rf_wf <- 
  workflow() %>%
  add_model(rf_mod) %>%
  add_formula(class ~ .)

# Fit the workflow on each fold's analysis set and evaluate it on the
# matching assessment set.
set.seed(456)
rf_fit_rs <- 
  rf_wf %>% 
  fit_resamples(folds)

# Summarise the performance metrics across the 10 resamples.
collect_metrics(rf_fit_rs)
交叉验证拟合结果
上一篇下一篇

猜你喜欢

热点阅读