特征变量选择

R语言:lasso建模和预测

2021-01-17  本文已影响0人  胡童远

导读:

本文使用clustlasso包中的lasso相关函数(lasso_cv、extract_best_model、predict_clustlasso等)进行建模和预测;该包中的clustlasso方法也可以进行类似的建模和预测。

clustlasso安装:
https://www.jianshu.com/p/2aed75aeca91

clustlasso包(含lasso部分)使用文档(vignette):
https://gitlab.com/biomerieux-data-science/clustlasso/-/blob/master/vignettes/vignette.pdf

1 加载包和数据

# Load the package and the example dataset shipped with it.
library(clustlasso)

# Fix the random seed so the train/test split below is reproducible;
# `seed` is also passed to lasso_cv() later in the script.
seed <- 42
set.seed(seed)

# Locate the example dataset inside the installed package. With
# mustWork = TRUE, system.file() raises an error when the file is missing
# instead of returning "" and letting load() fail with a confusing message.
input.file <- system.file(
  "data", "NG-dataset.Rdata",
  package = "clustlasso", mustWork = TRUE
)

# load() creates the saved objects in the current (global) environment and
# returns their names; the rest of this script relies on X, y and meta.
loaded.objects <- load(input.file)
stopifnot(all(c("X", "y", "meta") %in% loaded.objects))

2 按群体结构(pop_structure)分层,随机抽取20%的样本作为测试集

# Hold out 20% of the samples as a test set.
test.frac <- 0.2
# Stratify by origin / population structure: split() groups the row
# indices of `meta` by meta$pop_structure, one index vector per group.
ind.by.struct <- split(seq_len(nrow(meta)), meta$pop_structure)
# From each group, draw round(test.frac * group size) indices at random.
# lapply() always returns a list (sapply() may silently simplify to a vector
# or a matrix depending on the group sizes). Indexing with sample.int()
# avoids sample()'s special case where a single number x is treated as 1:x;
# for groups of length > 1 it consumes the RNG exactly like sample(idx, n).
ind.sample <- lapply(ind.by.struct, function(idx) {
  idx[sample.int(length(idx), round(test.frac * length(idx)))]
})

3 划分测试集(test set)和训练集(train set)

# Flatten the per-group sampled indices into one vector of test-row indices.
ind.test <- unlist(ind.sample, use.names = FALSE)
# Logical mask of test rows. Training rows are selected with !is.test rather
# than -ind.test: if ind.test were empty, X[-integer(0), ] would select zero
# rows and silently produce an empty training set.
is.test <- seq_len(nrow(X)) %in% ind.test
# Test set, rows in sampled order. drop = FALSE keeps the matrix /
# data.frame shape even when only one row or column is selected.
X.test <- X[ind.test, , drop = FALSE]
y.test <- y[ind.test]
meta.test <- meta[ind.test, , drop = FALSE]
# Training set: every row not held out for testing, in original order.
X.train <- X[!is.test, , drop = FALSE]
y.train <- y[!is.test]
meta.train <- meta[!is.test, , drop = FALSE]

4 建模和交叉验证

# 1. Cross-validation process
# Cross-validation settings passed to lasso_cv().
n.folds <- 10
n.lambda <- 100
n.repeat <- 3
# Run cross-validation on the training set, passing the population structure
# as `subgroup` and reusing the script-wide seed.
cv.res.lasso <- lasso_cv(
  X.train, y.train,
  subgroup = meta.train$pop_structure,
  n.lambda = n.lambda, n.folds = n.folds, n.repeat = n.repeat,
  seed = seed, verbose = FALSE
)

# Write the overall cross-validation plots to cv.pdf: one page, one row of
# three panels. finally = dev.off() closes the PDF device even if plotting
# fails, so later plots are not redirected into a half-written file.
pdf("cv.pdf", width = 15)
tryCatch({
  par(mfcol = c(1, 3))
  show_cv_overall(
    cv.res.lasso,
    modsel.criterion = "balanced.accuracy.best",
    best.eps = 1
  )
}, finally = dev.off())

5 选择最佳模型

# 2. Selecting the best model
# Plot the cross-validation summary of the selected model into cv_best.pdf,
# using a 1 x 3 layout with relative column widths 0.3 / 0.3 / 0.4.
# layout()'s arguments are spelled out in full (`widths` / `heights`)
# instead of relying on partial argument matching.
# finally = dev.off() closes the PDF device even if plotting fails.
pdf("cv_best.pdf", width = 15)
tryCatch({
  layout(
    matrix(c(1, 2, 3), nrow = 1, byrow = TRUE),
    widths = c(0.3, 0.3, 0.4),
    heights = 1
  )
  perf.best.lasso <- show_cv_best(
    cv.res.lasso,
    modsel.criterion = "balanced.accuracy.best",
    best.eps = 1,
    method = "lasso"
  )
}, finally = dev.off())
# Print the cross-validation performance of the best model.
print(perf.best.lasso)
# Extract the selected model (same criterion and best.eps as above) for
# prediction on the test set.
best.model.lasso <- extract_best_model(
  cv.res.lasso,
  modsel.criterion = "balanced.accuracy.best",
  best.eps = 1
)

6 模型预测和性能评估

# 3. Making predictions and measuring performance
# Predict the held-out test samples with the selected model. The result
# holds the predicted labels in $preds and, presumably, predicted
# probabilities in $probs.
preds.lasso <- predict_clustlasso(X.test, best.model.lasso)
# Compute performance of the predictions against the true test labels.
perf.lasso <- compute_perf(preds.lasso$preds, preds.lasso$probs, y.test)
# Print the performance table, transposed.
print(t(perf.lasso$perf))
# Write the test-set ROC and precision/recall curves side by side to
# predict.pdf. finally = dev.off() closes the device even if plotting fails.
pdf("predict.pdf", width = 15)
tryCatch({
  par(mfcol = c(1, 2))
  plot(perf.lasso$roc.curves[[1]], lwd = 2,
       main = "lasso - test set ROC curve")
  grid()
  plot(perf.lasso$pr.curves[[1]], lwd = 2,
       main = "lasso - test set precision / recall curve")
  grid()
}, finally = dev.off())

参考:
【机器学习】Cross-Validation(交叉验证)详解
Lasso regression(稀疏学习,R)
lasso_cv

上一篇下一篇

猜你喜欢

热点阅读