Feedforward Neural Network
前言
前馈神经网络的目标是近似某个函数 。后面在围绕解释的也就是这个话题。
从XOR问题说起
xor(异或)运算应该都不陌生,两个二进制值 和 的运算,当恰好有一个值为1时,结果就是1,其余是0。我们把这个规则(运算)当成我们要学习的函数 ,我们的模型给出一个函数 ,我们希望能够不断调整 让 能够 接近 。
在这个例子中。,
我们可以把这个问题当作回归问题(当然当作分类也是可以的),并使用均方误差作为损失函数,那么损失函数就可以表示为
然后我们确定 的具体模型,假设是一个线性模型(简单一点),那么 就包含 和 两个参数,模型也就是
然后通过求解最小化的 来得出 和 的值。
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
def plot_data(data, labels):
one = data[labels == 1, :]
zero = data[labels == 0, :]
plt.scatter(one[:, 0], one[:, 1],
color='red', marker='+', s=200)
plt.scatter(zero[:, 0], zero[:, 1],
color='blue', marker='.',s=200)
one = np.array([[1, 0], [0, 1]])
zero = np.array([[0, 0], [1, 1]])
data = np.concatenate([one, zero])
labels = np.array([1, 1, 0, 0])
plot_data(data, labels)
image.png
上面用代码构造了这个数据并简单展示一下。然后我们用sklearn里的线性模型来进行求解
from sklearn import linear_model
clf = linear_model.LinearRegression()
clf.fit(data, labels)
print(clf.coef_)
print(clf.intercept_)
out:
[0. 0.]
0.5
发现 的值 是0(向量), 的值是0.5,也就是说不论输入数据是什么,输出都是0.5。怎么形象的理解这个问题呢,简单说就是一个线性模型根本无法区分开这个数据(有人说,不对啊,svm就可以,我试过了,svm确实在构造一个超平面,但是它可以使用核函数把数据维度提高成线性可分的去解决非线性问题。单纯的线性模型是不可以的。)我们观察上图发现,当 时,模型输出必须随着 的增大而增大。当 时,模型必须随着 的增大而减小。线性模型不能使用 的值来改变 的值
那么如何解决这个问题呢。传统机器学习那些处理非线性的思路是可行的,还有一种方法就是使用一个模型来学习一个不同的特征空间,然后在这个空间上使用线性模型可以表示这个解。
g1.gv.png具体来说,到这里就可以引入简单的前馈神经网络了。(如下图)它有一个隐藏层并且隐藏层中包含两个单元。这个前馈神经网络通过函数 计算隐藏单元的向量 。
这些隐藏单元的值随后被用作第二层的输入。第二层就是这个网络的输出层。输出层仍然是一个线性回归模型,只不过变量是 而非 。那么现在就可以用
来表示网络
g2.gv.png还有种简单的表示方法,如下
现在要考虑的问题是 用什么样的模型表示,假设它依旧是线性回归模型那么 ,如果不考虑 截距的话, 将表示成 ,这仍旧是一个线性函数,我们前面讨论到问题的所在就是线性所致,所以需要一个非线性的变换,在神经网络的通常做法加入一个激活函数(关于激活函数会在后面详细介绍)来使之前的仿射变换变成非线性的,也就是
这个用一个比较通用的激活函数ReLU(rectified linear unit)。,那么现在我们的网络就可以表示成
现在继续讨论xor问题,我们确定解的参数(至于参数为什么会是这个,也在后续说明,这里只是解释这个问题)
用X作为输入矩阵,当经过ReLU后,数据变成
我们用上面的分析,绘制下图
new_data = np.array([[1,0],[1,0],[0,0],[2,1]])
plot_data(new_data, labels)
image.png
我们发现数据集变成线性可分的了,继续用线性回归模型去计算现在的数据,也得到了结果。
clf.fit(new_data, labels)
print(clf.coef_)
print(clf.intercept_)
out:
[ 1. -2.]
6.661338147750939e-16
回顾这个过程,不论是损失函数的选择,模型的选择,或者是激活函数,再到后面的给定的参数,都有点“假”,这也是我们后面要详细说的东西。
[这是一个正在更新的东西,因为课比较多,人比较懒,也懂的少,所以比较慢]