R语言与统计分析数据科学与R语言

60-R语言中的神经网络

2020-03-27  本文已影响0人  wonphen

《深度学习精要(基于R语言)》学习笔记

1、什么是深度学习

机器学习主要用于开发和使用那些从原始数据中学习、总结出来的用于进行预测的算法。
深度学习是一种强大的多层架构,可以用于模式识别、信号检测以及分类或预测等多个领域。
神经网络包括一系列的神经元,或者叫作节点,它们彼此连结并处理输入。神经元之间的连结经过加权处理,权重取决于从数据中学习、总结出的使用函数。一组神经元的激活和权重(从数据中自适应地学习)可以提供给其他的神经元,其中一些最终神经元的激活就是预测。
经常选择的激活函数是sigmoid函数以及双曲正切函数tanh,因为径向基函数是有效的函数逼近,所以有时也会用到它们。
权重是从每个隐藏单元到每个输出的路径,对第i个的输出通过(w_i)表示。如创建隐藏层的权重,这些权重也是从数据中学习得到的。分类会经常使用一种最终变换,softmax函数。线性回归经常使用恒等(identity)函数,它返回输入值。权重必须从数据中学习得到,权重为零或接近零基本上等同于放弃不必要的关系。

R中神经网络相关包:

2、初始化h2o

> library(pacman)
> p_load(h2o)
> cl <- h2o.init(
+   # 最大使用内存
+   max_mem_size = "4G",
+   # 线程数
+   nthreads = 4
+ )
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         2 hours 27 minutes 
##     H2O cluster timezone:       Asia/Shanghai 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.28.0.4 
##     H2O cluster version age:    1 month  
##     H2O cluster name:           H2O_started_from_R_Admin_qxq246 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.40 GB 
##     H2O cluster total cores:    4 
##     H2O cluster allowed cores:  4 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     H2O API Extensions:         Amazon S3, Algos, AutoML, Core V3, TargetEncoder, Core V4 
##     R Version:                  R version 3.6.3 (2020-02-29)

一旦集群完成初始化,可以使用R或本地主机(127.0.0.1:54321)提供的Web接口与它连接。

3、上传数据到h2o集群

如果数据集已经加载到R,使用as.h2o()函数:

> h2oiris <- as.h2o(droplevels(iris))
> h2oiris
> 
> # 检查因子变量水平
> h2o.levels(h2oiris, "Species")

如果数据没有载入R,可以直接导入到h2o中:

> # 直接导入文件
> h2o.kaoshi <- h2o.importFile(path = "./ks/CommViolPredUnnormalizedData.csv")
> h2o.kaoshi

也可以直接导入网络上的文件:

> h2o.bin <- h2o.importFile(path = "http://www.ats.ucla.edu/stat/data/binary.csv")
> h2o.bin


导入基于图片识别手写体数字,数据集的每一列(即特征),表示图像的一个像素。每张图像都经过标准化处理,转化成同样的大小,所以所有图像的像素个数都相同。第一列包含真实的数据标签,其余各列是黑暗像素的值,它用于分类。

> digits.train <- read.csv("./data_set/digit-recognizer/train.csv")
> dim(digits.train)
## [1] 42000   785
> head(colnames(digits.train), 5)
## [1] "label"  "pixel0" "pixel1" "pixel2" "pixel3"
> tail(colnames(digits.train), 5)
## [1] "pixel779" "pixel780" "pixel781" "pixel782" "pixel783"
> head(digits.train[, 1:5])
##   label pixel0 pixel1 pixel2 pixel3
## 1     1      0      0      0      0
## 2     0      0      0      0      0
## 3     1      0      0      0      0
## 4     4      0      0      0      0
## 5     0      0      0      0      0
## 6     0      0      0      0      0
> class(digits.train[, 1])
## [1] "integer"
> # 将label列转换为因子型,让R知道这是一个分类问题
> digits.train$label <- factor(digits.train$label, levels = 0:9)
> 
> # 笔记本电脑性能有限,减少数据集行数
> digit.x <- digits.train[1:5000, -1]
> digit.y <- digits.train[1:5000, 1]
> 
> # 查看label列数字分布情况
> p_load(magrittr, ggplot2, caret)
> table(digit.y) %>% as.data.frame %>% 
+     ggplot(aes(x = Freq, y = digit.y)) + 
+     geom_bar(stat = "identity") + 
+     labs(x = "", y = "") + theme_bw()
image.png

4、nnet包训练预测模型

使用caret包训练模型:

> set.seed(123)
> digit.ml <- train(x=digit.x,y=digit.y,
+                   method="nnet",
+                   tuneGrid=expand.grid(
+                     # 5个隐藏神经元
+                     .size=5,
+                     # 衰变率
+                     .decay=0.1
+                   ),
+                   trControl=trainControl(method="none"),
+                   # 最大权重数量
+                   MaxNWts=10000,
+                   # 最大迭代次数
+                   maxit=100)

生成数据的一组预测,查看柱状图:

> pred.ml <- predict(digit.ml)
>
> pred.ml %>% table %>% as.data.frame %>% ggplot(aes(x = Freq, y = .)) + geom_bar(stat = "identity") + 
+     labs(x = "", y = "") + theme_bw()
预测值柱状图

跟训练集数据柱状图对比,很明显模型不是最优的。
通过混淆矩阵检查模型性能:

> caret::confusionMatrix(pred.ml, digit.y)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1   2   3   4   5   6   7   8   9
##          0 440   0   7  15  11 144  13   8  18   3
##          1   0 424  10   1   0   6   1   2   5   2
##          2   0   0   0   0   0   0   0   0   0   0
##          3   0   0   0   0   0   0   0   0   0   0
##          4   0   0   0   0   0   0   0   0   0   0
##          5   0   0   0   0   0   0   0   0   0   0
##          6  52  20 480 425  98 250 500  22 338  34
##          7   2 114  48  39 367  69   2 473 116 439
##          8   0   0   0   0   0   0   0   0   0   0
##          9   0   0   0   0   1   0   0   1   0   0
## 
## Overall Statistics
##                                          
##                Accuracy : 0.3674         
##                  95% CI : (0.354, 0.3809)
##     No Information Rate : 0.1116         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.295          
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4
## Sensitivity            0.8907   0.7599    0.000    0.000   0.0000
## Specificity            0.9514   0.9939    1.000    1.000   1.0000
## Pos Pred Value         0.6677   0.9401      NaN      NaN      NaN
## Neg Pred Value         0.9876   0.9705    0.891    0.904   0.9046
## Prevalence             0.0988   0.1116    0.109    0.096   0.0954
## Detection Rate         0.0880   0.0848    0.000    0.000   0.0000
## Detection Prevalence   0.1318   0.0902    0.000    0.000   0.0000
## Balanced Accuracy      0.9210   0.8769    0.500    0.500   0.5000
##                      Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity            0.0000   0.9690   0.9348   0.0000   0.0000
## Specificity            1.0000   0.6166   0.7339   1.0000   0.9996
## Pos Pred Value            NaN   0.2253   0.2834      NaN   0.0000
## Neg Pred Value         0.9062   0.9942   0.9901   0.9046   0.9044
## Prevalence             0.0938   0.1032   0.1012   0.0954   0.0956
## Detection Rate         0.0000   0.1000   0.0946   0.0000   0.0000
## Detection Prevalence   0.0000   0.4438   0.3338   0.0000   0.0004
## Balanced Accuracy      0.5000   0.7928   0.8343   0.5000   0.4998

No Information Rate(无信息率)指不考虑任何信息而仅仅通过猜测来决定最频繁的类的准确度期望。在情形“1”中,它在11.16%的时间中发生。P值(P-Value [Acc > NIR])检验了观测准确度(Accuracy : 0.3674)是否显著不同于无信息率(11.16%)。
Class: 0的灵敏度(Sensitivity)可以解释为:89.07%的数字0被正确地预测为0。特异度(Specificity)可以解释为:95.14%的预测为非数字0被预测为不是数字0。
检出率(Detection Rate)是真阳性的百分比,而最后的检出预防度(detection prevalence)是预测为阳性的实例比例,不管它们是否真的为阳性。
平衡准确度(balanced accuracy)是灵敏度和特异度的平均值。

接下来我们通过增加神经元的个数来提升模型的性能,其代价是模型的复杂性会显著增加:

> set.seed(123)
> digit.ml2 <- train(x=digit.x,y=digit.y,
+                   method="nnet",
+                   tuneGrid=expand.grid(
+                     # 10个隐藏神经元
+                     .size=10,
+                     # 衰变率
+                     .decay=0.1
+                   ),
+                   trControl=trainControl(method="none"),
+                   # 最大权重数量
+                   MaxNWts=50000,
+                   # 最大迭代次数
+                   maxit=100)
>
> pred.ml2 <- predict(digit.ml2)
> caret::confusionMatrix(xtabs(~pred.ml2 + digit.y))
## Confusion Matrix and Statistics
## 
##         digit.y
## pred.ml2   0   1   2   3   4   5   6   7   8   9
##        0 403   0  32   1   3  37  10  11   6   4
##        1   0 496  13   2   1   1   6   7   8   3
##        2  17   5 312   6  12  26  92   2  12   5
##        3   0   1   2 386   2  86   2   2  45   8
##        4   0   0  44   2 418  32   6  60   6 355
##        5  30   2   5  12   2 200  29   5  26  10
##        6  14   0   9   0   2  16 366   0   2   0
##        7   5   4   7  18  33  16   1 364  47  83
##        8  25  50 121  53   2  55   4  55 325  10
##        9   0   0   0   0   2   0   0   0   0   0
## 
## Overall Statistics
##                                           
##                Accuracy : 0.654           
##                  95% CI : (0.6406, 0.6672)
##     No Information Rate : 0.1116          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.6155          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4
## Sensitivity            0.8158   0.8889   0.5725   0.8042   0.8763
## Specificity            0.9769   0.9908   0.9603   0.9673   0.8883
## Pos Pred Value         0.7949   0.9236   0.6380   0.7228   0.4529
## Neg Pred Value         0.9797   0.9861   0.9483   0.9790   0.9855
## Prevalence             0.0988   0.1116   0.1090   0.0960   0.0954
## Detection Rate         0.0806   0.0992   0.0624   0.0772   0.0836
## Detection Prevalence   0.1014   0.1074   0.0978   0.1068   0.1846
## Balanced Accuracy      0.8964   0.9398   0.7664   0.8857   0.8823
##                      Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity            0.4264   0.7093   0.7194   0.6813   0.0000
## Specificity            0.9733   0.9904   0.9524   0.9171   0.9996
## Pos Pred Value         0.6231   0.8949   0.6298   0.4643   0.0000
## Neg Pred Value         0.9425   0.9673   0.9679   0.9647   0.9044
## Prevalence             0.0938   0.1032   0.1012   0.0954   0.0956
## Detection Rate         0.0400   0.0732   0.0728   0.0650   0.0000
## Detection Prevalence   0.0642   0.0818   0.1156   0.1400   0.0004
## Balanced Accuracy      0.6999   0.8499   0.8359   0.7992   0.4998

隐藏神经元的数量从5个增加到10个,样本内性能的总准确度从36.74% 提升到了 65.4%。我们继续增加隐藏神经元的数量:

> set.seed(123)
> digit.ml3 <- train(x=digit.x,y=digit.y,
+                   method="nnet",
+                   tuneGrid=expand.grid(
+                     # 40个隐藏神经元
+                     .size=40,
+                     # 衰变率
+                     .decay=0.1
+                   ),
+                   trControl=trainControl(method="none"),
+                   # 最大权重数量
+                   MaxNWts=50000,
+                   # 最大迭代次数
+                   maxit=100)
>
> pred.ml3 <- predict(digit.ml2)
> caret::confusionMatrix(xtabs(~pred.ml3 + digit.y))
## Confusion Matrix and Statistics
## 
##         digit.y
## pred.ml3   0   1   2   3   4   5   6   7   8   9
##        0 403   0  32   1   3  37  10  11   6   4
##        1   0 496  13   2   1   1   6   7   8   3
##        2  17   5 312   6  12  26  92   2  12   5
##        3   0   1   2 386   2  86   2   2  45   8
##        4   0   0  44   2 418  32   6  60   6 355
##        5  30   2   5  12   2 200  29   5  26  10
##        6  14   0   9   0   2  16 366   0   2   0
##        7   5   4   7  18  33  16   1 364  47  83
##        8  25  50 121  53   2  55   4  55 325  10
##        9   0   0   0   0   2   0   0   0   0   0
## 
## Overall Statistics
##                                           
##                Accuracy : 0.654           
##                  95% CI : (0.6406, 0.6672)
##     No Information Rate : 0.1116          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.6155          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4
## Sensitivity            0.8158   0.8889   0.5725   0.8042   0.8763
## Specificity            0.9769   0.9908   0.9603   0.9673   0.8883
## Pos Pred Value         0.7949   0.9236   0.6380   0.7228   0.4529
## Neg Pred Value         0.9797   0.9861   0.9483   0.9790   0.9855
## Prevalence             0.0988   0.1116   0.1090   0.0960   0.0954
## Detection Rate         0.0806   0.0992   0.0624   0.0772   0.0836
## Detection Prevalence   0.1014   0.1074   0.0978   0.1068   0.1846
## Balanced Accuracy      0.8964   0.9398   0.7664   0.8857   0.8823
##                      Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity            0.4264   0.7093   0.7194   0.6813   0.0000
## Specificity            0.9733   0.9904   0.9524   0.9171   0.9996
## Pos Pred Value         0.6231   0.8949   0.6298   0.4643   0.0000
## Neg Pred Value         0.9425   0.9673   0.9679   0.9647   0.9044
## Prevalence             0.0938   0.1032   0.1012   0.0954   0.0956
## Detection Rate         0.0400   0.0732   0.0728   0.0650   0.0000
## Detection Prevalence   0.0642   0.0818   0.1156   0.1400   0.0004
## Balanced Accuracy      0.6999   0.8499   0.8359   0.7992   0.4998

增加到40个神经元后准确度跟10个神经元的一样,还是65.4%。如果是商业问题,还需要继续调节神经元的数量和衰变率。但是作为学习,模型对数字9的表现比较差,对其他数字都还行。

5、RSNNS包训练预测模型

RSNNS包提供了使用斯图加特神经网络仿真器(Stuttgart Neural Network Simulator , SNNS)模型的接口,但是,对基本的、单隐藏层的、前馈的神经网络,我们可以使用mlp()这个更为方便的封装函数,它的名称表示多层感知器(multi-layer perceptron)。
RSNNS包要求输入为矩阵、响应变量为一个哑变量的矩阵,因此每个可能的类表示成矩阵列中的 0/1 编码。

> p_load(RSNNS)
> digit.y.mat <- decodeClassLabels(digit.y);head(digit.y.mat)
##      0 1 2 3 4 5 6 7 8 9
## [1,] 0 1 0 0 0 0 0 0 0 0
## [2,] 1 0 0 0 0 0 0 0 0 0
## [3,] 0 1 0 0 0 0 0 0 0 0
## [4,] 0 0 0 0 1 0 0 0 0 0
## [5,] 1 0 0 0 0 0 0 0 0 0
## [6,] 1 0 0 0 0 0 0 0 0 0

通过decodeClassLabels()函数可以很方便的将数据转换为哑变量矩阵。

> set.seed(123)
> digit.ml4 <- mlp(as.matrix(digit.x),
+                  digit.y.mat,
+                  # 40个隐藏神经元
+                  size=40,
+                  learnFunc="Rprop",
+                  maxit=60)
> 
> # 返回一个矩阵,每列代表单个数字,值为对应概率
> pred.ml4 <- fitted.values(digit.ml4);head(pred.ml4)
##              [,1]         [,2]         [,3]         [,4]         [,5]
## [1,] 2.056155e-09 9.958593e-01 0.0052626291 0.0001699293 0.0002909484
## [2,] 9.887955e-01 2.212749e-06 0.0004935671 0.0069587487 0.0003329390
## [3,] 2.844076e-08 9.919980e-01 0.0046710004 0.0026906121 0.0002593570
## [4,] 7.048693e-06 3.588733e-04 0.6521131396 0.1404305398 0.0200663172
## [5,] 9.500439e-01 1.362774e-04 0.0010998574 0.0007325016 0.0128032062
## [6,] 9.660823e-01 1.292595e-07 0.0018860876 0.0018403936 0.0150223514
##              [,6]        [,7]         [,8]        [,9]       [,10]
## [1,] 3.280851e-03 0.000226795 1.846895e-02 0.267083108 0.000890466
## [2,] 7.944548e-05 0.001482799 1.073702e-03 0.007864795 0.017809818
## [3,] 2.077550e-02 0.022846304 6.169527e-02 0.074271396 0.022611069
## [4,] 1.612732e-04 0.035801634 6.758716e-05 0.001063690 0.070019640
## [5,] 4.495332e-05 0.006684554 5.744500e-03 0.001458221 0.051504783
## [6,] 3.631760e-03 0.007321015 6.434801e-04 0.007034578 0.137325883
> # 将结果转换为数字标签的单个向量
> pred.ml4.2 <- encodeClassLabels(pred.ml4);head(pred.ml4.2,20)
## [1]  2  1  2  3  1  1  8  4  5  4  9 10  2  4  4  2  3  1  8  6

预测结果的值为1-10,但是实际值为0-9,所以在生成混淆矩阵时,需要先减去1:

> caret::confusionMatrix(xtabs(~I(pred.ml4.2 - 1) + digit.y))
## Confusion Matrix and Statistics
## 
##                  digit.y
## I(pred.ml4.2 - 1)   0   1   2   3   4   5   6   7   8   9
##                 0 448   0   4   5   3   6   9   5   2   3
##                 1   0 533   3   2   0   0   0   8  17   5
##                 2   6   7 484  23   5  16   5   9  16   5
##                 3   3   1  13 389   3  29   0   3  28   9
##                 4   6   2   4   4 393   8   9   8  21  41
##                 5   9   4   5  24   8 361   4   2  32   5
##                 6  12   3   8   1   5   9 487   1   9   4
##                 7   2   2   5   3   3   0   0 443   5  20
##                 8   4   6  12  19   6  31   1   3 335   4
##                 9   4   0   7  10  51   9   1  24  12 382
## 
## Overall Statistics
##                                           
##                Accuracy : 0.851           
##                  95% CI : (0.8408, 0.8608)
##     No Information Rate : 0.1116          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8344          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            0.9069   0.9552   0.8881   0.8104   0.8239   0.7697
## Specificity            0.9918   0.9921   0.9793   0.9803   0.9772   0.9795
## Pos Pred Value         0.9237   0.9384   0.8403   0.8138   0.7923   0.7952
## Neg Pred Value         0.9898   0.9944   0.9862   0.9799   0.9813   0.9762
## Prevalence             0.0988   0.1116   0.1090   0.0960   0.0954   0.0938
## Detection Rate         0.0896   0.1066   0.0968   0.0778   0.0786   0.0722
## Detection Prevalence   0.0970   0.1136   0.1152   0.0956   0.0992   0.0908
## Balanced Accuracy      0.9493   0.9737   0.9337   0.8954   0.9006   0.8746
##                      Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity            0.9438   0.8755   0.7023   0.7992
## Specificity            0.9884   0.9911   0.9810   0.9739
## Pos Pred Value         0.9035   0.9172   0.7957   0.7640
## Neg Pred Value         0.9935   0.9861   0.9690   0.9787
## Prevalence             0.1032   0.1012   0.0954   0.0956
## Detection Rate         0.0974   0.0886   0.0670   0.0764
## Detection Prevalence   0.1078   0.0966   0.0842   0.1000
## Balanced Accuracy      0.9661   0.9333   0.8416   0.8865

RSNNS包的学习算法使用了相同数目的隐藏神经元,计算结果的性能却有极大提高。
函数I()有两个作用:
1.在对data.frame的调用中将对象包含在I()中来保护它,防止字符向量到factor的转换和名称的删除,并确保矩阵作为单列插入。
2.在formula函数中,它被用来禁止将“+”、“-”、“*”和“^”等运算符解释为公式运算符,因此它们被用作算术运算符。

6、从神经网络生成预测

从RSNNS包返回的预测值(pred.ml4)中可以看到,一个观测可能有40%的概率成为“5”,20%的概率成为“6”,等等。最简单的方法就是基于高预测概率来对观测进行分类。RSNNS包有一种称为赢者通吃(winner takes all,WTA)的方法,只要没有关系就选择概率最高的类,最高的概率高于用户定义的阈值(这个阈值可以是0),而其他类的预测概率都低于最大值减去另一个用户定义的阈值,否则观测的分类就不明了。如果这两个阈值都是0(缺省),那么最大值必然存在并且唯一。这种方法的优点是它提供了某种质量控制。
但是在实际应用中,比如一个医学背景下,我们收集了病人的多种生物指标和基因信息,用来分类确定他们是否健康,是否有患癌症的风险,是否有患心脏病的风险,即使有40%的患癌概率也需要病人进一步做检查,即便他健康的概率是60%。RSNNS包中还提供一种分类方法称为“402040”,如果一个值高于用户定义的阈值,而所有的其他值低于用户定义的另一个阈值。如果多个值都高于第一个阈值,或者任何值都不低于第二个阈值,我们就把观测定性为未知的。这样做的目的是再次给出了某种质量控制。

> # 缺省情况,赢者通吃
> table(encodeClassLabels(pred.ml4, method = "WTA", l = 0, h = 0))
## 
##   1   2   3   4   5   6   7   8   9  10 
## 485 568 576 478 496 454 539 483 421 500
> # 402040方法,0.4-0.6之间定义为未知
> table(encodeClassLabels(pred.ml4, method = "402040", l = 0.4, h = 0.6))
## 
##    0    1    2    3    4    5    6    7    8    9   10 
## 1222  439  528  427  325  344  299  458  409  215  334

“0”分类表示未知的预测。

7、对预测结果的解释

通常来说,过拟合指模型在训练集上的性能优于测试集。过拟合发生在模型正好拟合了训练数据的噪声部分的时候。因为考虑了噪声,它似乎更准确,但一个数据集和下一个数据集的噪声不同,这种准确度不能运用于除了训练数据之外的任何数据 — 它没有一般化。
使用RSNNS模型对样本外数据预测:

> pred.ml4.test <- predict(digit.ml4, as.matrix(train[5001:10000, -1]))
> caret::confusionMatrix(xtabs(~I(encodeClassLabels(pred.ml4.test) - 1) + 
+     train[5001:10000, 1]))
## Confusion Matrix and Statistics
## 
##                                        train[5001:10000, 1]
## I(encodeClassLabels(pred.ml4.test) - 1)   0   1   2   3   4   5   6   7   8
##                                       0 440   0   2   7   6  10   2   5   3
##                                       1   0 506   6   4   6   2   5  10  29
##                                       2   9   9 422  33  12  16  12  14  17
##                                       3   4   2  35 371   5  33   1  11  23
##                                       4   9   2   9   3 361  15   8  11  28
##                                       5   9   2   4  49  11 316   8   2  33
##                                       6  10   3   6   8   9  11 444   3   4
##                                       7   3   0   6  11   5   1   0 425   8
##                                       8  10  10   7  27   6  28   4   6 316
##                                       9   3   3   3  16  69   5   3  46  12
##                                        train[5001:10000, 1]
## I(encodeClassLabels(pred.ml4.test) - 1)   9
##                                       0   3
##                                       1   3
##                                       2   3
##                                       3   8
##                                       4  53
##                                       5   5
##                                       6   5
##                                       7  31
##                                       8   7
##                                       9 399
## 
## Overall Statistics
##                                          
##                Accuracy : 0.8            
##                  95% CI : (0.7886, 0.811)
##     No Information Rate : 0.1074         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.7777         
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            0.8853   0.9423   0.8440   0.7013   0.7367   0.7231
## Specificity            0.9916   0.9854   0.9722   0.9727   0.9694   0.9730
## Pos Pred Value         0.9205   0.8862   0.7715   0.7525   0.7234   0.7198
## Neg Pred Value         0.9874   0.9930   0.9825   0.9649   0.9713   0.9735
## Prevalence             0.0994   0.1074   0.1000   0.1058   0.0980   0.0874
## Detection Rate         0.0880   0.1012   0.0844   0.0742   0.0722   0.0632
## Detection Prevalence   0.0956   0.1142   0.1094   0.0986   0.0998   0.0878
## Balanced Accuracy      0.9384   0.9639   0.9081   0.8370   0.8531   0.8481
##                      Class: 6 Class: 7 Class: 8 Class: 9
## Sensitivity            0.9117   0.7974   0.6681   0.7718
## Specificity            0.9869   0.9854   0.9768   0.9643
## Pos Pred Value         0.8827   0.8673   0.7506   0.7138
## Neg Pred Value         0.9904   0.9761   0.9657   0.9734
## Prevalence             0.0974   0.1066   0.0946   0.1034
## Detection Rate         0.0888   0.0850   0.0632   0.0798
## Detection Prevalence   0.1006   0.0980   0.0842   0.1118
## Balanced Accuracy      0.9493   0.8914   0.8224   0.8680

模型在第一个5000行上的准确度为85.1%,在第二个5000行上的准确度减少为80%,损失超过5%,换句话说,使用训练数据来评价模型性能导致了过度乐观的准确度估计,过度估计是5%。
这个问题我们后面再处理。

上一篇 下一篇

猜你喜欢

热点阅读