公众号-科研私家菜学习记录(5)

2021-08-06  本文已影响0人  明眸意海

交叉验证与模型选择

  1. 示例
library(ISLR)
data('Auto') ##载入数据集
head(Auto)
set.seed(111)

n=nrow(Auto) ## 共392个样本
train=sample(n,n/2) ## 选择50%的样本作为训练集
test=(-train) ## 50%的样本作为测试集


lm.fit=lm(mpg~horsepower, data=Auto, subset=train) 
## 使用训练集拟合线性回归模型
## 以horsepower来拟合mpg的值
mean((Auto[test,'mpg']-predict(lm.fit, newdata=Auto[test,]))^2) ## 计算均方误差

lm.fit2=lm(mpg~poly(horsepower,2), data=Auto, subset=train) 
## 使用训练集拟合多项式回归
mean((Auto[test,'mpg']-predict(lm.fit2, newdata=Auto[test,]))^2) ## 均方误差

简单交叉验证

MSE=matrix(NA,10,10) ## 建10行10列的Matrix

for(seed in 1:10){
  set.seed(seed)
  train=sample(n,n/2)
  test=(-train)
  for(degree in 1:10){
    lm.fit=lm(mpg~poly(horsepower,degree), data=Auto,subset=train)
    MSE[seed,degree]=mean((Auto[test,'mpg']-predict(lm.fit,newdata=Auto[test,]))^2)  
  }
}
## 结果可视化
plot(MSE[1,],
     ylim=range(MSE),
     type='l',
     lwd=2,
     col=rainbow(10)[1],
     xlab='degree',
     ylab='the estimated test MSE')

for(seed in 2:10){
  points(MSE[seed,],
         type='l',
         lwd=2,
         col=ggsci::pal_npg()(10)[seed])
} 

K-折交叉验证

library(boot)
cv.error.10=rep(NA,10)
for(degree in 1:10){
  glm.fit=glm(mpg ~ poly(horsepower,degree), data=Auto)
  set.seed(1234)
  cv.error.10[degree] <- cv.glm(Auto,glm.fit,K=10)$delta[1]
}
cv.error.10

plot(cv.error.10,type='b',xlab='degree',col=ggsci::pal_npg()(10))

留一交叉验证

loocv.error=rep(NA,10)
for(degree in 1:10){
  glm.fit=glm(mpg ~ poly(horsepower,degree), data=Auto)
  loocv.error[degree]=cv.glm(Auto,glm.fit,K = nrow(Auto))$delta[1]
}
loocv.error 

plot(loocv.error,type='b',xlab='degree',col=ggsci::pal_npg()(10))
上一篇 下一篇

猜你喜欢

热点阅读