TFLearn的简单实例

2019-09-27  本文已影响0人  何小有

通过使用TFLearn实现简单的逻辑非(NOT)、逻辑或(OR)、逻辑与(AND)和异或(XOR)运算符模型,我们可以一览TFLearn和TensorFlow中的一些基础知识。

首先,要在我们的demo文件中导入TFLearn和TensorFlow的python库,并声明使用UTF-8编码。

# -*- coding: utf-8 -*-

import tensorflow as tf
import tflearn

逻辑非(NOT)运算符

然后我们要准备好训练逻辑非(NOT)运算符模型的训练数据,如下所示,我们想告诉AI两个例子:当输入[0.]时,应该输出[1.];当输入[1.]时,应该输出[0.]

# 逻辑非(`NOT`)运算符
X = [[0.], [1.]]
Y = [[1.], [0.]]

如果我们要使用input_data方法建立一个逻辑非(NOT)运算符模型,要分几步走。第一步,建立一个None * 1维度大小的输入数据层。

维度

第二步,使用fully_connected方法建立全连接层,本文实例会使用的到fully_connected方法的全连接层参数列表如下(按序排列)。

  1. inputs:至少有两层张量且最后一维是静态值(必须)
  2. num_outputs:int型整数或者是long型,是层中输出单元的个数(必须)
  3. activation:激活函数,默认值是ReLU函数,如果将它设为None则会跳过且保持线性激活

而根据activation参数传递的激活函数,全连接层的作用也不同,linear函数是线性回归函数,sigmoid函数是神经网络激励函数,使我们的网络输出输出0~1之间的单个标量。

全连接层

第三步,使用regression方法,该方法通过回归算法进行定量输出,本文实例会使用的到regression方法的回归参数列表如下(按序排列)。

  1. incoming:至少有两层张量且最后一维是静态值(必须)
  2. optimizer:优化算法器,不同的算法效果不一,例如我们现在用的sgd算法每次更新时,会对每个样本进行梯度更新
  3. learning_rate:机器学习率,也叫指数衰减法,对于不同大小的数据集,调节不同的学习率,是训练模型中经常要干事情
  4. loss:损失函数,用来估量模型的预测值f(x)与真实值Y的不一致程度,它是一个非负实值函数,通常使用L(Y, f(x))来表示,损失函数越小,模型的准确性就越好。本文中的实例用到了两种算法,一是mean_square即均方误差算法,二是binary_crossentropy即二值交叉熵算法
# 图定义
with tf.Graph().as_default():
  g = tflearn.input_data(shape=[None, 1])
  g = tflearn.fully_connected(g, 128, activation='linear')
  g = tflearn.fully_connected(g, 128, activation='linear')
  g = tflearn.fully_connected(g, 1, activation='sigmoid')
  g = tflearn.regression(g, optimizer='sgd', learning_rate=2., loss='mean_square')

接下来我们使用DNN即深度神经网络去训练逻辑非(NOT)运算符模型,在下面fit方法中,n_epoch=100表示整个训练数据集将会使用100遍,snapshot_epoch=False表示不需要在每个周期都保存并评估模型。

  # 模型训练
  m = tflearn.DNN(g)
  m.fit(X, Y, n_epoch=100, snapshot_epoch=False)

训练完成后的输出内容如下所示:

$ python3 logical.py                      
2019-09-27 11:07:57.069657: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
---------------------------------
Run id: MHQXTU
Log directory: /tmp/tflearn_logs/
---------------------------------
Training samples: 2
Validation samples: 0
--
Training Step: 100  | total loss: 0.00023 | time: 0.002s
| SGD | epoch: 100 | loss: 0.00023 -- iter: 2/2
--

我们可以忽略第一行输出的警告,毕竟我们的电脑不用专门用于机器学习的。第二、三行告诉我们训练模型的日志文件的位置,通过tensorboard --logdir='/tmp/tflearn_logs'命令即可通过这些日志文件可视化我们的训练效果和性能。第四、五行则是Training samples(训练样本)和Validation samples(验证样本)的数量。最后的两行则分别是Training Step(训练步骤)、total loss(总损失)、time(时间)、优化算法器、epoch(周期)、loss(损失)和iter(迭代)。

接下来,就可以通过predict方法测试我们的逻辑非(NOT)运算符模型。

  # 测试模型
  print("测试非(`NOT`)运算符")
  print("NOT 0:", m.predict([[0.]]))
  print("NOT 1:", m.predict([[1.]]))

输入内容如下所示,现在逻辑非(NOT)运算符模型可以正常运作了。

测试非(`NOT`)运算符
NOT 0: [[0.9850776]]
NOT 1: [[0.01178647]]

逻辑或(OR)运算符

逻辑或(OR)运算符模型同上非常类似,代码如下所示。

# 逻辑或(`OR`)运算符
X = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]
Y = [[0.], [1.], [1.], [1.]]

# 图定义
with tf.Graph().as_default():
  g = tflearn.input_data(shape=[None, 2])
  g = tflearn.fully_connected(g, 128, activation='linear')
  g = tflearn.fully_connected(g, 128, activation='linear')
  g = tflearn.fully_connected(g, 1, activation='sigmoid')
  g = tflearn.regression(g, optimizer='sgd', learning_rate=2., loss='mean_square')

  # 模型训练
  m = tflearn.DNN(g)
  m.fit(X, Y, n_epoch=100, snapshot_epoch=False)

  # 测试模型
  print("测试或(`OR`)运算符")
  print("0 or 0:", m.predict([[0., 0.]]))
  print("0 or 1:", m.predict([[0., 1.]]))
  print("1 or 0:", m.predict([[1., 0.]]))
  print("1 or 1:", m.predict([[1., 1.]]))

逻辑与(AND)运算符

逻辑与(AND)运算符模型同上非常类似,代码如下所示。

# 逻辑与(`AND`)运算符
X = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]
Y = [[0.], [0.], [0.], [1.]]

# 图定义
with tf.Graph().as_default():
  g = tflearn.input_data(shape=[None, 2])
  g = tflearn.fully_connected(g, 128, activation='linear')
  g = tflearn.fully_connected(g, 128, activation='linear')
  g = tflearn.fully_connected(g, 1, activation='sigmoid')
  g = tflearn.regression(g, optimizer='sgd', learning_rate=2., loss='mean_square')

  # 模型训练
  m = tflearn.DNN(g)
  m.fit(X, Y, n_epoch=100, snapshot_epoch=False)

  # 测试模型
  print("测试与(`AND`)运算符")
  print("0 and 0:", m.predict([[0., 0.]]))
  print("0 and 1:", m.predict([[0., 1.]]))
  print("1 and 0:", m.predict([[1., 0.]]))
  print("1 and 1:", m.predict([[1., 1.]]))

逻辑异或(XOR)运算符

现在我们可以更进一步:具有多个优化器的图层组合,使用逻辑与非(NAND)和或(OR)运算符的乘积创建异或(XOR)运算符。

# 数据
X = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]
Y_nand = [[1.], [1.], [1.], [0.]]
Y_or = [[0.], [1.], [1.], [1.]]

下面的图层定义代码中,与之前不一样的地方是:将binary_crossentropy即二值交叉熵算法作为损失函数传递给loss参数;同时使用merge即合并模式函数,该函数的mode参数用于设置合并模式,这里用的是elemwise_mul模式,即输出按元素求和。

# 图定义
with tf.Graph().as_default():
  # 使用2个优化器构建网络
  g = tflearn.input_data(shape=[None, 2])
  # 逻辑与非(`NAND`)运算符定义
  g_nand = tflearn.fully_connected(g, 32, activation='linear')
  g_nand = tflearn.fully_connected(g_nand, 32, activation='linear')
  g_nand = tflearn.fully_connected(g_nand, 1, activation='sigmoid')
  g_nand = tflearn.regression(g_nand, optimizer='sgd', learning_rate=2., loss='binary_crossentropy')
  # 或(`OR`)运算符定义
  g_or = tflearn.fully_connected(g, 32, activation='linear')
  g_or = tflearn.fully_connected(g_or, 32, activation='linear')
  g_or = tflearn.fully_connected(g_or, 1, activation='sigmoid')
  g_or = tflearn.regression(g_or, optimizer='sgd', learning_rate=2., loss='binary_crossentropy')
  # 异或(`XOR`)合并逻辑与非(`NAND`)和或(`OR`)运算符
  g_xor = tflearn.merge([g_nand, g_or], mode='elemwise_mul')

最后我们开始训练逻辑异或(XOR)运算符模型,并测试模型的实际效果如何。

  # 训练
  m = tflearn.DNN(g_xor)
  m.fit(X, [Y_nand, Y_or], n_epoch=400, snapshot_epoch=False)

  # 测试
  print("测试异或(`XOR`)运算符")
  print("0 xor 0:", m.predict([[0., 0.]]))
  print("0 xor 1:", m.predict([[0., 1.]]))
  print("1 xor 0:", m.predict([[1., 0.]]))
  print("1 xor 1:", m.predict([[1., 1.]]))

输出的内容大概就是下面的那样。

---------------------------------
Run id: A2GWIP
Log directory: /tmp/tflearn_logs/
---------------------------------
Training samples: 8
Validation samples: 0
--
Training Step: 400  | total loss: 0.81700 | time: 0.004s
| SGD_0 | epoch: 400 | loss: 0.40862 -- iter: 4/4
| SGD_1 | epoch: 400 | loss: 0.40838 -- iter: 4/4
--
测试异或(`XOR`)运算符
0 xor 0: [[0.00057937]]
0 xor 1: [[0.99796593]]
1 xor 0: [[0.9979677]]
1 xor 1: [[0.00111361]]
上一篇下一篇

猜你喜欢

热点阅读