sklearn中随机测试数据:sklearn包中SVM算法库的使
目录
- SVM相关知识点回顾
1.1. SVM与SVR
1.2. 核函数
- sklearn中SVM相关库的简介
- 2.1. 分类库与回归库
- 2.2. 高斯核调参
- 2.2.1. 需要调节的参数
- 2.2.2. 调参方法:网格搜索
- 编程实现
这是《西瓜书带学训练营·实战任务》系列的第三篇笔记
1. SVM相关知识点回顾
1.1. SVM与SVR
-
SVM分类算法
其原始形式是:
其中m为样本个数,我们的样本为(x1,y1),(x2,y2),...,(xm,ym)。w,b是我们的分离超平面的w∙ϕ(xi)+b=0系数, ξi为第i个样本的松弛系数, C为惩罚系数。ϕ(xi)为低维到高维的映射函数
通过拉格朗日函数以及对偶化后的形式为:
-
SVR回归算法
其中m为样本个数,我们的样本为(x1,y1),(x2,y2),...,(xm,ym)。w,b是我们的回归超平面的w∙xi+b=0系数, ξ∨i,ξ∧i为第i个样本的松弛系数, C为惩罚系数,ϵ为损失边界,到超平面距离小于ϵ的训练集的点没有损失。ϕ(xi)为低维到高维的映射函数。
1.2. 核函数
在scikit-learn中,内置的核函数一共有4种:
-
线性核函数(Linear Kernel)表达式为:K(x,z)=x∙z,就是普通的内积
-
多项式核函数(Polynomial Kernel)是线性不可分SVM常用的核函数之一,表达式为:K(x,z)=(γx∙z+r)d ,其中,γ,r,d都需要自己调参定义
-
高斯核函数(Gaussian Kernel),在SVM中也称为径向基核函数(Radial Basis Function,RBF),它是 libsvm 默认的核函数,当然也是 scikit-learn 默认的核函数。表达式为:K(x,z)=exp(−γ||x−z||2), 其中,γ大于0,需要自己调参定义
-
Sigmoid核函数(Sigmoid Kernel)也是线性不可分SVM常用的核函数之一,表达式为:K(x,z)=tanh(γx∙z+r), 其中,γ,r都需要自己调参定义
一般情况下,对非线性数据使用默认的高斯核函数会有比较好的效果,如果你不是SVM调参高手的话,建议使用高斯核来做数据分析。
2. sklearn中SVM相关库的简介
scikit-learn SVM算法库封装了libsvm 和 liblinear 的实现,仅仅重写了算法了接口部分
2.1. 分类库与回归库
-
分类算法库
包括SVC, NuSVC,和LinearSVC 3个类
对于SVC, NuSVC,和LinearSVC 3个分类的类,SVC和 NuSVC差不多,区别仅仅在于对损失的度量方式不同,而LinearSVC从名字就可以看出,他是线性分类,也就是不支持各种低维到高维的核函数,仅仅支持线性核函数,对线性不可分的数据不能使用
-
回归算法库
包括SVR, NuSVR,和LinearSVR 3个类
同样的,对于SVR, NuSVR,和LinearSVR 3个回归的类, SVR和NuSVR差不多,区别也仅仅在于对损失的度量方式不同。LinearSVR是线性回归,只能使用线性核函数
2.2. 高斯核调参
2.2.1. 需要调节的参数
-
SVM分类模型
如果是SVM分类模型,这两个超参数分别是惩罚系数C和RBF核函数的系数γ
惩罚系数C
它在优化函数里主要是平衡支持向量的复杂度和误分类率这两者之间的关系,可以理解为正则化系数
-
当C比较大时,我们的损失函数也会越大,这意味着我们不愿意放弃比较远的离群点。这样我们会有更加多的支持向量,也就是说支持向量和超平面的模型也会变得越复杂,也容易过拟合
-
当C比较小时,意味我们不想理那些离群点,会选择较少的样本来做支持向量,最终的支持向量和超平面的模型也会简单
scikit-learn中默认值是1
C越大,泛化能力越差,易出现过拟合现象;C越小,泛化能力越好,易出现过欠拟合现象
BF核函数的参数γ
RBF 核函数K(x,z)=exp(−γ||x−z||2) γ>0
γ主要定义了单个样本对整个分类超平面的影响
-
当γ比较小时,单个样本对整个分类超平面的影响比较小,不容易被选择为支持向量
-
当γ比较大时,单个样本对整个分类超平面的影响比较大,更容易被选择为支持向量,或者说整个模型的支持向量也会多
scikit-learn中默认值是
1/样本特征数
γ越大,训练集拟合越好,泛化能力越差,易出现过拟合现象
如果把惩罚系数C和RBF核函数的系数γ一起看,当C比较大, γ比较大时,我们会有更多的支持向量,我们的模型会比较复杂,容易过拟合一些。如果C比较小 , γ比较小时,模型会变得简单,支持向量的个数会少
-
-
SVM回归模型
SVM回归模型的RBF核比分类模型要复杂一点,因为此时我们除了惩罚系数C和RBF核函数的系数γ之外,还多了一个损失距离度量ϵ
对于损失距离度量ϵ,它决定了样本点到超平面的距离损失
-
当 ϵ 比较大时,损失较小,更多的点在损失距离范围之内,而没有损失,模型较简单
-
当 ϵ 比较小时,损失函数会较大,模型也会变得复杂
scikit-learn中默认值是0.1
如果把惩罚系数C,RBF核函数的系数γ和损失距离度量ϵ一起看,当C比较大, γ比较大,ϵ比较小时,我们会有更多的支持向量,我们的模型会比较复杂,容易过拟合一些。如果C比较小 , γ比较小,ϵ比较大时,模型会变得简单,支持向量的个数会少
-
2.2.2. 调参方法:网格搜索
对于SVM的RBF核,我们主要的调参方法都是交叉验证。具体在scikit-learn中,主要是使用网格搜索,即GridSearchCV类
from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(SVC(), param_grid={"C":[0.1, 1, 10], "gamma": [1, 0.1, 0.01]}, cv=4)
grid.fit(X, y)
将GridSearchCV类用于SVM RBF调参时要注意的参数有:
estimator:即我们的模型,此处我们就是带高斯核的SVC或者SVR
param_grid:即我们要调参的参数列表。 比如我们用SVC分类模型的话,那么param_grid可以定义为{"C":[0.1, 1, 10], "gamma": [0.1, 0.2, 0.3]},这样我们就会有9种超参数的组合来进行网格搜索,选择一个拟合分数最好的超平面系数
cv:S折交叉验证的折数,即将训练集分成多少份来进行交叉验证。默认是3。如果样本较多的话,可以适度增大cv的值
3. 编程实现
-
生成测试数据
from sklearn.datasets import make_circles from sklearn.preprocessing import StandardScaler # 生成一些随机数据用于后续分类 X, y = make_circles(noise=0.2, factor=0.5, random_state=1) # 生成时加入了一些噪声 X = StandardScaler().fit_transform(X) # 把数据归一化
生成的随机数据可视化结果如下:
-
调参
接着采用网格搜索的策略进行RBF核函数参数搜索
from sklearn.model_selection import GridSearchCV grid = GridSearchCV(SVC(), param_grid={"C":[0.1, 1, 10], "gamma": [1, 0.1, 0.01]}, cv=4) # 总共有9种参数组合的搜索空间 grid.fit(X, y) print("The best parameters are %s with a score of %0.2f" % (grid.best_params_, grid.best_score_)) 输出为: The best parameters are {'C': 10, 'gamma': 0.1} with a score of 0.91
可以对9种参数组合训练的结果进行可视化,观察分类的效果:
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max,0.02), np.arange(y_min, y_max, 0.02)) for i, C in enumerate((0.1, 1, 10)): for j, gamma in enumerate((1, 0.1, 0.01)): plt.subplot() clf = SVC(C=C, gamma=gamma) clf.fit(X,y) Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) # Put the result into a color plot Z = Z.reshape(xx.shape) plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm) plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.xticks(()) plt.yticks(()) plt.xlabel(" gamma=" + str(gamma) + " C=" + str(C)) plt.show()
|
1.000 | 0.100 | 0.001 |
---|---|---|---|
0.1 | |||
1 | |||
10 |
参考资料: