Chapter 6: Neural Networks

2020-02-17  晨光523152

6.2 Fully Connected Layers

Figure: Neural Network Example 1

The diagram was drawn with the following tool:
http://alexlenail.me/NN-SVG/index.html

6.2.1 Tensor-Based Implementation

In TensorFlow, implementing a fully connected layer only requires defining the weight tensor \mathbf{W} and the bias tensor \mathbf{b}, then calling the tf.matmul() function (or the equivalent @ operator).

For example, let \mathbf{x}\in \mathbb{R}^{2\times 784}, the weight matrix \mathbf{W}\in \mathbb{R}^{784\times 256}, and the bias vector \mathbf{b}\in \mathbb{R}^{256}:

import tensorflow as tf

x = tf.random.normal([2, 784])                             # batch of 2 input vectors
w = tf.Variable(tf.random.normal([784, 256], stddev=0.1))  # weight matrix W
b = tf.Variable(tf.zeros([256]))                           # bias vector b
y = x @ w + b       # @ is shorthand for tf.matmul(x, w)
y = tf.nn.relu(y)   # ReLU activation
y
# Output (truncated; the full tensor has shape (2, 256)):
<tf.Tensor: id=32, shape=(2, 256), dtype=float32, numpy=
array([[7.71913290e-01, 0.00000000e+00, 0.00000000e+00, ...,
        3.09266973e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 1.71956635e+00, 8.52033377e-01, ...,
        4.00854206e+00, 0.00000000e+00, 6.26695919e+00]],
      dtype=float32)>

6.2.2 Layer-Based Implementation

Use the tensorflow.keras.layers.Dense layer:

from tensorflow.keras import layers

x = tf.random.normal([4, 28 * 28])
fc = layers.Dense(512, activation='relu')  # create a Dense layer with 512 output nodes
h1 = fc(x)                                 # call it on the input to run the layer
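After the first call, the layer's parameters have been created and can be inspected through the standard Dense attributes kernel and bias:

fc.kernel.shape         # TensorShape([784, 512]), the weight matrix
fc.bias.shape           # TensorShape([512]), the bias vector
fc.trainable_variables  # [kernel, bias]: what an optimizer would update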

6.3 Neural Networks

By stacking the fully connected layer from Figure "Neural Network Example 1" layer upon layer (ensuring that each layer's output node count matches the next layer's input node count), a network of any depth can be built. A network composed of neurons in this way is called a neural network, as shown in the figure below.


Figure: multi-layer neural network

The diagram was drawn with the following tool:
http://alexlenail.me/NN-SVG/index.html

A multi-layer neural network can be implemented in the following two ways.

6.3.1 Tensor-Based Implementation

The weight matrix \mathbf{W} and bias vector \mathbf{b} of every layer need to be defined:


# Hidden layer 1 tensors
w1 = tf.Variable(tf.random.normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))

# Hidden layer 2 tensors
w2 = tf.Variable(tf.random.normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))

# Hidden layer 3 tensors
w3 = tf.Variable(tf.random.normal([128, 64], stddev=0.1))
b3 = tf.Variable(tf.zeros([64]))

# Output layer tensors (the 64 -> 10 mapping is the output layer, not a hidden layer)
w4 = tf.Variable(tf.random.normal([64, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]))

x = tf.random.normal([2, 784])   # a batch of inputs
with tf.GradientTape() as tape:  # record operations for automatic differentiation
    y1 = tf.nn.relu(x @ w1 + b1)
    y2 = tf.nn.relu(y1 @ w2 + b2)
    y3 = tf.nn.relu(y2 @ w3 + b3)
    out = y3 @ w4 + b4           # no activation on the output layer (cf. 6.3.2)
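The tape records the forward pass so that gradients can be queried afterwards. A minimal sketch, assuming a hypothetical one-hot target y_true, of computing a loss inside the tape and extracting the gradients:

y_true = tf.one_hot(tf.constant([1, 3]), depth=10)  # hypothetical labels
with tf.GradientTape() as tape:
    y1 = tf.nn.relu(x @ w1 + b1)
    y2 = tf.nn.relu(y1 @ w2 + b2)
    y3 = tf.nn.relu(y2 @ w3 + b3)
    out = y3 @ w4 + b4
    loss = tf.reduce_mean(tf.square(y_true - out))  # MSE loss
# gradients of the loss w.r.t. every trainable tensor
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3, w4, b4])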

6.3.2 Layer-Based Implementation

Approach A: create each layer object individually, then chain the calls by hand.

fc1 = layers.Dense(256, activation='relu')  # hidden layer 1
fc2 = layers.Dense(128, activation='relu')  # hidden layer 2
fc3 = layers.Dense(64, activation='relu')   # hidden layer 3
fc4 = layers.Dense(10, activation=None)     # output layer, no activation
x = tf.random.normal([4, 28 * 28])
h1 = fc1(x)
h2 = fc2(h1)
h3 = fc3(h2)
h4 = fc4(h3)

Approach B: wrap the layers in a tf.keras.Sequential container, which calls them in order automatically.

model = tf.keras.Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation=None)
])
y = model(x)  # one call runs the whole stack
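Once the container has been called on data, its variables exist and the architecture can be inspected; for example:

model.summary()                 # lists each Dense layer with its parameter count
len(model.trainable_variables)  # 8: one kernel and one bias per Dense layer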

6.3.3 Optimization Objective

The computation from a neural network's input to its output is called forward propagation (the process of the data tensor flowing from the first layer to the output layer).
The last step of forward propagation is computing the error; the gradient descent algorithm then updates the parameters iteratively.
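A minimal sketch of one such update step, reusing the Sequential model above and assuming a hypothetical one-hot target y_true:

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
y_true = tf.one_hot(tf.constant([0, 1, 2, 3]), depth=10)  # hypothetical labels
with tf.GradientTape() as tape:
    out = model(x)                                  # forward propagation
    loss = tf.reduce_mean(tf.square(y_true - out))  # error computation
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))  # one gradient-descent update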

6.4 Activation Functions

See my earlier summary of activation functions:
https://www.jianshu.com/p/4f1a82fe723a

6.5 Output Layer Design

The last layer of a neural network performs the same dimension transformation and feature extraction as the hidden layers, but it also serves as the output layer. Whether to use an activation function here, and which type, must be decided according to the specific task.

Several common output types include:

6.5.1 The Real Number Space

Predicting a sine curve, a person's age, or stock movements all produce outputs in the whole or part of the continuous real number space, so the output layer can omit the activation function.

6.5.2 The [0, 1] Interval

Image generation, binary classification, and similar problems all have output values in [0,1].

In machine learning, pixel values are usually normalized to the [0,1] interval. If the output layer's values were used directly, they would range over the entire real number space. To map the values into the valid range [0,1], a suitable activation function \sigma must be added after the output layer; the Sigmoid function has exactly this property.
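For instance, a minimal sketch of squashing raw outputs with tf.sigmoid:

z = tf.random.normal([2, 10])  # raw output values, anywhere on the real line
p = tf.sigmoid(z)              # every element now lies in (0, 1)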

6.5.3 The [0, 1] Interval, Summing to 1

Each output value out_{i}\in [0,1], and all output values sum to 1; this setting is most common in multi-class classification.

Use the Softmax function.
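Softmax is defined as \sigma(z_{i}) = \frac{e^{z_{i}}}{\sum_{j}e^{z_{j}}}; a quick sketch with tf.nn.softmax:

z = tf.constant([[2.0, 1.0, 0.1]])  # raw logits
p = tf.nn.softmax(z)                # ≈ [[0.659, 0.242, 0.099]], sums to 1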

6.5.4 The [-1, 1] Interval

If the output values should be distributed in [-1,1], simply apply the tanh activation function.
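For example (values shown are approximate):

z = tf.linspace(-5.0, 5.0, 5)  # [-5, -2.5, 0, 2.5, 5]
tf.tanh(z)                     # ≈ [-0.9999, -0.9866, 0, 0.9866, 0.9999]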

6.6 Error Calculation

Common error functions include mean squared error, cross-entropy, KL divergence, and the Hinge loss.
Mean squared error is mainly used for regression problems; cross-entropy is mainly used for classification problems.

6.6.1 Mean Squared Error

The mean squared error (MSE) function maps the output vector and the ground-truth vector to two points in a Cartesian coordinate system, and measures the gap between the two vectors by the squared Euclidean distance between those points (averaged over the dimensions):
MSE:=\frac{1}{d_{out}}\sum_{i=1}^{d_{out}}(y_{i}-o_{i})^{2}

from tensorflow import keras

o = tf.random.normal([2, 10])        # network outputs
y = tf.constant([1, 3])              # true labels
y_onehot = tf.one_hot(y, depth=10)   # one-hot encode the labels
loss = keras.losses.MSE(y_onehot, o) # per-sample MSE, shape [2]
# or the class interface (y_true comes first); returns the batch mean
criterion = keras.losses.MeanSquaredError()
criterion(y_onehot, o)

6.6.2 Cross-Entropy

In information theory, entropy is also called information entropy (Shannon entropy). The larger the entropy, the greater the uncertainty and the greater the amount of information.
The entropy of a distribution P(i) is defined as:
H(P):=-\sum_{i}P(i)\log_{2}P(i)
e.g. in a 4-class problem, if a sample's true label is class 4, its one-hot encoding is [0,0,0,1]: the classification is fully determined, the uncertainty is 0, and the entropy is 0.
If the predicted probability distribution is [0.1,0.1,0.1,0.7], its entropy is about 1.356.
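A quick check of that value (converting the natural log to base 2):

p = tf.constant([0.1, 0.1, 0.1, 0.7])
-tf.reduce_sum(p * tf.math.log(p) / tf.math.log(2.0))  # ≈ 1.356, matching the figure above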

Building on entropy, cross-entropy is defined as:
H(p,q):=-\sum_{i}p(i)\log_{2}q(i)

Through a simple transformation, cross-entropy decomposes into the sum of the entropy of p, H(p), and the KL divergence between p and q:
H(p,q)=H(p)+D_{KL}(p|q)
where the KL divergence is defined as:
D_{KL}(p|q) = \sum_{x\in X}p(x)\log\left(\frac{p(x)}{q(x)}\right)

Note that neither cross-entropy nor the KL divergence is symmetric:
\begin{split} H(p,q) & \ne H(q,p)\\ D_{KL}(p|q) & \ne D_{KL}(q|p) \end{split}

Cross-entropy is a good measure of the difference between two distributions. In particular, in classification, when the label distribution p is the one-hot encoding of y, H(y)=0, and therefore:
H(y,o) = H(y) + D_{KL}(y|o) = D_{KL}(y|o)
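So with one-hot labels, minimizing the cross-entropy H(y,o) is equivalent to minimizing the KL divergence between label and prediction. A minimal sketch with Keras (which uses the natural logarithm rather than \log_{2}; this only changes the scale):

z = tf.random.normal([2, 10])                         # raw logits
y_onehot = tf.one_hot(tf.constant([1, 3]), depth=10)  # one-hot labels
# from_logits=True applies softmax internally, which is numerically stabler
loss = keras.losses.categorical_crossentropy(y_onehot, z, from_logits=True)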

6.7 Types of Neural Networks

The fully connected layer is the most basic type of network.
Drawback: a large number of parameters (when processing data with long feature vectors).

6.7.1 Convolutional Neural Networks

For image classification: AlexNet, VGG, GoogLeNet, ResNet, DenseNet, etc.

For object detection: RCNN, Fast RCNN, Faster RCNN, Mask RCNN, etc.

6.7.2 Recurrent Neural Networks

Because convolutional networks lack a memory mechanism and the ability to handle variable-length sequence signals, they are not well suited to natural language tasks; recurrent neural networks have proven effective for these.
Representative models: RNN, LSTM, Seq2Seq, GNMT, GRU, bidirectional RNN.

6.7.3 Attention(-Mechanism) Networks

The introduction of Attention overcame drawbacks of RNNs such as unstable training and poor parallelizability.
Representative models: Transformer, GPT, BERT, GPT-2.

6.7.4 Graph Neural Networks

For data with irregular spatial topology, such as social networks, communication networks, and protein molecular structures, CNNs and RNNs do not perform well.
Representative models: GCN, GAT, EdgeConv, DeepGCN, etc.

Reference: https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book
