Machine learning -- supervised -- Multivariate Adaptive Regression Splines (MARS)

Posted 2021-11-11 by 小贝学生信
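MARS can be thought of as a piecewise linear model. When the relationship between a feature x and the response y is strongly nonlinear, MARS searches for a suitable number n of cut points (knots) that split the range of x into n+1 regions, each approximated by a linear fit. The pieces are built from hinge functions: a knot at c contributes the pair max(0, x - c) and max(0, c - x).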
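A minimal base-R sketch of the hinge idea, using simulated data and a single knot at x = 4 that is assumed known (MARS itself would search for the knot). Regressing y on the two hinge functions recovers the two linear pieces: the intercept is the fitted value at the knot, and the two coefficients are the slopes to the left (with the sign flipped) and to the right of it. The helper names hinge_pos() and hinge_neg() are illustrative only.

```r
# Hinge basis functions for one knot (illustrative helpers)
hinge_pos <- function(x, knot) pmax(0, x - knot)  # max(0, x - knot)
hinge_neg <- function(x, knot) pmax(0, knot - x)  # max(0, knot - x)

set.seed(42)
x <- seq(0, 10, length.out = 200)
# Simulated response: slope +1 left of x = 4, slope -2 right of it
y <- ifelse(x < 4, x, 4 - 2 * (x - 4)) + rnorm(length(x), sd = 0.3)

knot_at <- 4  # assumed known here; MARS would search for it
fit <- lm(y ~ hinge_neg(x, knot_at) + hinge_pos(x, knot_at))
round(coef(fit), 2)
# Expect roughly: intercept ~ 4 (value at the knot),
# hinge_neg ~ -1 (left slope +1, sign flipped), hinge_pos ~ -2
```

With n knots the same construction yields n+1 linear segments joined continuously at the knots.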

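1. Multiple regression and MARS hyperparameters

Background
Hyperparameters: degree, the maximum degree of interaction between hinge functions, and nprune, the number of terms retained in the final model.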
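To see how nprune interacts with pruning: earth's backward pass scores each candidate model size with generalized cross-validation (GCV), and nprune caps the largest size allowed. The sketch below uses the GCV formula as given in the earth package documentation, assuming penalty = 2 (earth's default when degree = 1), and a made-up RSS sequence for nested models; it is an illustration, not earth's pruning code.

```r
# GCV in the form given by the earth documentation (assumed):
#   enp = nterms + penalty * (nterms - 1) / 2   (effective number of parameters)
#   GCV = RSS / (N * (1 - enp / N)^2)
gcv <- function(rss, n, nterms, penalty = 2) {
  enp <- nterms + penalty * (nterms - 1) / 2
  rss / (n * (1 - enp / n)^2)
}

# Hypothetical RSS for nested models with 1..7 terms (intercept included):
# RSS always falls, but with diminishing returns
rss_by_size <- c(500, 200, 120, 100, 95, 93, 92)
n_obs <- 50
scores <- sapply(seq_along(rss_by_size),
                 function(k) gcv(rss_by_size[k], n_obs, nterms = k))
which.min(scores)               # size preferred by GCV with no cap
# [1] 4
which.min(scores[seq_len(3)])   # with a hypothetical cap of nprune = 3
# [1] 3
```

GCV keeps falling while the RSS gains outweigh the penalty for extra terms, then rises; a smaller nprune simply removes the larger sizes from consideration.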

2. Hands-on code

Example data: predicting house prices (Ames Housing data)
ames <- AmesHousing::make_ames()
dim(ames)
## [1] 2930   81

set.seed(123)
library(rsample)
# 70/30 train/test split, stratified on the outcome Sale_Price
split <- initial_split(ames, prop = 0.7, 
                       strata = "Sale_Price")
ames_train  <- training(split)
dim(ames_train)
# [1] 2049   81
ames_test   <- testing(split)
dim(ames_test)
# [1] 881  81
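strata = "Sale_Price" makes the split stratified on the outcome. For a numeric outcome, rsample bins it into quantile groups (quartiles by default) and samples within each bin, so the training and test sets cover the same price range. Below is a simplified base-R stand-in with a hypothetical helper stratified_split() and simulated prices; rsample additionally handles edge cases such as very small strata, which this sketch ignores.

```r
# Hypothetical helper: stratified split on a numeric outcome.
# Bins y into quantile groups, then samples `prop` of each bin.
stratified_split <- function(y, prop = 0.7, breaks = 4) {
  cuts <- unique(quantile(y, probs = seq(0, 1, length.out = breaks + 1)))
  bins <- cut(y, breaks = cuts, include.lowest = TRUE)
  idx_by_bin <- split(seq_along(y), bins)
  train_idx <- unlist(lapply(idx_by_bin, function(i) {
    i[sample.int(length(i), size = round(prop * length(i)))]
  }), use.names = FALSE)
  sort(train_idx)
}

set.seed(123)
price <- rlnorm(1000, meanlog = 12, sdlog = 0.4)  # simulated sale prices
train_idx <- stratified_split(price)
length(train_idx)         # 700: 175 from each of the 4 bins
summary(price[train_idx])  # training distribution
summary(price[-train_idx]) # test distribution, similar by construction
```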
library(caret)
library(dplyr)  # provides the %>% pipe used below
# create a tuning grid
hyper_grid <- expand.grid(
  degree = 1:3,
  nprune = seq(2, 100, length.out = 10) %>% floor()
)
# 3 degree values x 10 nprune values = 30 combinations
head(hyper_grid)
#   degree nprune
# 1      1      2
# 2      2      2
# 3      3      2
# 4      1     12
# 5      2     12
# 6      3     12
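The nprune values come from flooring 10 evenly spaced numbers between 2 and 100, and expand.grid() takes the Cartesian product with the first column varying fastest, which is why head() cycles through degree 1, 2, 3 before nprune changes:

```r
nprune_values <- floor(seq(2, 100, length.out = 10))
nprune_values
#  [1]   2  12  23  34  45  56  67  78  89 100
grid <- expand.grid(degree = 1:3, nprune = nprune_values)
nrow(grid)
# [1] 30
```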
# Cross-validated model
set.seed(1111) # for reproducibility
cv_mars <- train(
  x = subset(ames_train, select = -Sale_Price),
  y = ames_train$Sale_Price,
  method = "earth",
  metric = "RMSE",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = hyper_grid
)
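Conceptually, trainControl(method = "cv", number = 10) means: for each row of hyper_grid, fit on 9 folds, compute RMSE on the held-out fold, repeat for all 10 folds, and report the mean (RMSE) and standard deviation (RMSESD) across folds; bestTune is the row with the lowest mean RMSE. The sketch below shows that loop in base R for a single candidate model, with lm() standing in for earth and simulated data; kfold_rmse() is a hypothetical helper, not a caret function.

```r
# Hypothetical helper: k-fold cross-validated RMSE for one model formula
kfold_rmse <- function(formula, data, k = 10) {
  folds <- sample(rep(seq_len(k), length.out = nrow(data)))  # random fold ids
  fold_rmse <- sapply(seq_len(k), function(f) {
    fit <- lm(formula, data = data[folds != f, , drop = FALSE])  # train on k-1 folds
    held_out <- data[folds == f, , drop = FALSE]
    pred <- predict(fit, newdata = held_out)
    obs <- held_out[[all.vars(formula)[1]]]                      # response column
    sqrt(mean((obs - pred)^2))
  })
  c(RMSE = mean(fold_rmse), RMSESD = sd(fold_rmse))
}

set.seed(1111)
d <- data.frame(x = runif(300, 0, 10))
d$y <- 3 + 2 * d$x + rnorm(300, sd = 1)  # true noise sd = 1
kfold_rmse(y ~ x, d)
```

Since the simulated noise has standard deviation 1, the cross-validated RMSE should land close to 1.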
# View results
cv_mars$bestTune
#   nprune degree
# 5     45      1
# Cross-validated performance of the best hyperparameter combination
cv_mars$results %>%
  dplyr::filter(nprune == cv_mars$bestTune$nprune, degree == cv_mars$bestTune$degree)
#   degree nprune     RMSE  Rsquared      MAE   RMSESD RsquaredSD    MAESD
# 1      1     45 26435.26 0.8903344 17013.83 4390.809 0.03478498 1583.466

# Visualize CV RMSE across the tuning grid
ggplot(cv_mars)
# Evaluate the tuned model on the held-out test set
pred <- predict(cv_mars, ames_test)
caret::RMSE(pred, ames_test$Sale_Price)
# [1] 23703.75
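caret::RMSE() is the root mean squared error, sqrt(mean((obs - pred)^2)), in the same units as Sale_Price (dollars). A base-R equivalent on three toy values, with rmse() as an illustrative helper:

```r
rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))
rmse(obs = c(100, 200, 300), pred = c(110, 190, 330))
# sqrt((10^2 + 10^2 + 30^2) / 3) = sqrt(366.67), about 19.15
```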
library(vip)
# Variable importance from earth's GCV and RSS criteria
p1 <- vip(cv_mars, num_features = 40, geom = "point", value = "gcv") + ggtitle("GCV")
p2 <- vip(cv_mars, num_features = 40, geom = "point", value = "rss") + ggtitle("RSS")
gridExtra::grid.arrange(p1, p2, ncol = 2)
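vip() here reports earth's own importance criteria, which, roughly, accumulate the GCV or RSS improvements across the pruning-pass model subsets that include a variable. That bookkeeping is not reproduced below. As a plainly simpler stand-in, this base-R sketch scores importance by how much the RSS of an lm() fit rises when a variable's term is dropped, scaled so the top variable is 100, on simulated data.

```r
set.seed(7)
n <- 500
d <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
# Simulated truth: x1 linear and strong, x2 only via a hinge at 0.5, x3 unused
d$y <- 5 * d$x1 + 2 * pmax(0, d$x2 - 0.5) + rnorm(n, sd = 0.2)

rss <- function(formula) sum(resid(lm(formula, data = d))^2)
rss_full <- rss(y ~ x1 + pmax(0, x2 - 0.5) + x3)

# RSS increase when each variable's term is removed from the full model
# (drop-term importance, NOT earth's evimp algorithm)
rss_increase <- c(
  x1 = rss(y ~ pmax(0, x2 - 0.5) + x3),
  x2 = rss(y ~ x1 + x3),
  x3 = rss(y ~ x1 + pmax(0, x2 - 0.5))
) - rss_full

importance <- 100 * rss_increase / max(rss_increase)  # top variable = 100
sort(round(importance, 1), decreasing = TRUE)
```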
library(pdp)
# Construct partial dependence plots
p1 <- partial(cv_mars, pred.var = "Gr_Liv_Area", grid.resolution = 10) %>%
  autoplot()
p2 <- partial(cv_mars, pred.var = "Year_Built", grid.resolution = 10) %>%
  autoplot()
p3 <- partial(cv_mars, pred.var = c("Gr_Liv_Area", "Year_Built"),
              grid.resolution = 10) %>%
  plotPartial(levelplot = FALSE, zlab = "yhat", drape = TRUE, colorkey = TRUE,
              screen = list(z = -20, x = -60))
# Display plots side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)
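A partial dependence curve can be computed by brute force: for each grid value of the chosen predictor, set that predictor to the value in every row of the training data, predict, and average the predictions. A base-R sketch of that computation, with a hypothetical helper partial_dependence(), an lm() model containing one hinge term, and simulated data:

```r
# Hypothetical helper: brute-force partial dependence on one predictor
partial_dependence <- function(model, data, var, grid_values) {
  sapply(grid_values, function(v) {
    tmp <- data
    tmp[[var]] <- rep(v, nrow(tmp))         # set var to v in every row
    mean(predict(model, newdata = tmp))     # average prediction
  })
}

set.seed(99)
d <- data.frame(x1 = runif(200, 0, 10), x2 = runif(200, 0, 10))
d$y <- 3 * pmax(0, d$x1 - 5) + d$x2 + rnorm(200)
fit <- lm(y ~ pmax(0, x1 - 5) + x2, data = d)

grid_x1 <- seq(0, 10, length.out = 10)  # analogous to grid.resolution = 10
pd <- partial_dependence(fit, d, "x1", grid_x1)
plot(grid_x1, pd, type = "b", xlab = "x1", ylab = "average prediction")
```

The resulting curve is flat below the knot at x1 = 5 and rises linearly above it.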

MARS can also be used for classification problems; that is not covered in this post.
