tensorboard使用指南

2019-04-17  本文已影响0人  zestloveheart

介绍

tensorboard可以可视化tensorflow运行过程。
本文有以下部分:

基础使用

这里假定已安装好anaconda,tensorflow,运行代码

import tensorflow as tf
a = tf.constant([1.0,2.0,3.0],name='input1')
b = tf.Variable(tf.random_uniform([3]),name='input2')
add = tf.add_n([a,b],name='addOP')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter("E:/TensorBoard",sess.graph)
    print(sess.run(add))
writer.close()

会将日志存储到 E:/TensorBoard文件夹内。
在控制台使用

tensorboard --logdir=E:/TensorBoard

如果启动成功,使用浏览器进入 localhost:6006即可查看运行结果。
如果不成功则看下面的错误处理部分

查看训练过程

这里假设之前的网络已经使用keras搭建完成,可以训练、预测。
(如果没有实现,可以参考最后的代码部分。)

先在model.fit之前加上以下代码,创建调用tensorboard的回调函数。

log_filepath = 'E:/TensorBoard/keras_log'
tb_cb = keras.callbacks.TensorBoard(log_dir=log_filepath, write_images=1, histogram_freq=1)

再在model.fit函数中加上callbacks参数,引入tensorboard回调函数:

callbacks=[tb_cb]

运行代码,重启tensorboard指定路径,进入网址查看。

错误处理

我启动时运行时遇到一个问题,有可能是因为tensorboard版本太高,我是1.13版本

OSError.png
参考了这篇博客,修改源码,然后解决的问题。

修改Anaconda3\Lib\site-packages\tensorboard\manager.py


修改manager.py的代码

把原先的两行serialize注释,换成下面的一行

serialize=lambda dt: int(dt.strftime("%S")),

重新运行即成功


image.png

代码

最后附上mnist fashion使用tensorboard的代码,可直接运行。

import tensorflow as tf
import keras
import numpy as np
import matplotlib.pyplot as plt
import os
fashion_mnist = keras.datasets.fashion_mnist
(train_data, train_labels), (test_data, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_data = train_data / 255.0
test_data = test_data / 255.0
print(train_data.shape)

print(train_labels)

def create_model():
    inputs = keras.layers.Input(shape=(28,28))
    h1 = keras.layers.Flatten()(inputs)
    h1 = keras.layers.Dense(64,activation=tf.nn.relu)(h1)
    h1 = keras.layers.Dense(64, activation=tf.nn.relu)(h1)
    h1 = keras.layers.Dense(64, activation=tf.nn.relu)(h1)
    outputs = keras.layers.Dense(10,activation=tf.nn.softmax)(h1)

    model = keras.Model(inputs= inputs,outputs = outputs)
    model.compile(optimizer=keras.optimizers.Adam(),
                  loss="sparse_categorical_crossentropy",metrics=['accuracy',])
    model.summary()
    return model

model = create_model()


checkpoint_path = "training_1\\fashion_mnist.ckpt"

if os.path.exists(checkpoint_path):
    model.load_weights(checkpoint_path)
else:
    log_filepath = 'E:/TensorBoard/keras_log'
    tb_cb = keras.callbacks.TensorBoard(log_dir=log_filepath, write_images=1, histogram_freq=1)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
    history = model.fit(train_data, train_labels, epochs=10, callbacks=[cp_callback,tb_cb],validation_split=0.3)
    with open('training_1\\fashion_mnist_history.txt', 'w') as f:
        f.write(str(history.history))


test_loss, test_acc = model.evaluate(test_data, test_labels)
print('Test accuracy:', test_acc)

predictions = model.predict(test_data)
上一篇 下一篇

猜你喜欢

热点阅读