TensorFlow 1.0: Visualizing How w and the Loss Change During Training

2017-04-05  Double_E

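Monitoring the training process in TensorFlow 1.0. The script below fits y = w·x to noisy synthetic data and logs the estimated w and the loss as TensorBoard summaries.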

```python
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  3 19:15:24 2017

@author: Jhy_BUPT
README:

REF:

"""
import numpy as np
import tensorflow as tf  # written for TensorFlow 1.x (tf.Session, tf.placeholder)

sess = tf.Session()

batch_size = 50

# Synthetic data: y = true_w * x + Gaussian noise (sigma = 25)
x_data = np.arange(1000) / 10.0
true_w = 2
y_data = x_data * true_w + np.random.normal(loc=0.0, scale=25, size=1000)

# Random 90/10 train/test split
train_idx = np.random.choice(len(x_data), size=int(len(x_data) * 0.9), replace=False)
test_idx = np.setdiff1d(np.arange(len(x_data)), train_idx)

train_x, train_y = x_data[train_idx], y_data[train_idx]
test_x, test_y = x_data[test_idx], y_data[test_idx]

x = tf.placeholder(tf.float32, [None])
y_ = tf.placeholder(tf.float32, [None])

# Single trainable weight; the model is y = w * x (no bias term)
w = tf.Variable(tf.random_normal([1], dtype=tf.float32), name='weight')

y = tf.multiply(w, x)

# Mean absolute error (L1) loss
loss = tf.reduce_mean(tf.abs(y - y_))

optimizer = tf.train.GradientDescentOptimizer(0.001)
train_op = optimizer.minimize(loss)

# Summaries to track in TensorBoard: the estimated w and the loss
with tf.name_scope('weight_est'):
    tf.summary.scalar('w_est', tf.squeeze(w))

with tf.name_scope('loss'):
    # loss is a scalar, so a scalar summary gives a loss-vs-step curve
    tf.summary.scalar('Loss', loss)

summary_op = tf.summary.merge_all()

init = tf.global_variables_initializer()
sess.run(init)

# Create the writer once, outside the loop (adjust the path for your machine)
log_dir = 'C:\\tmp\\d44'
summary_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

for i in range(1000):
    batch_idx = np.random.choice(len(train_x), size=batch_size)
    xs = train_x[batch_idx]
    ys = train_y[batch_idx]
    _, train_loss, summary = sess.run([train_op, loss, summary_op],
                                      feed_dict={x: xs, y_: ys})
    # Fetch the tensor itself (not [loss]) so test_loss is a number, not a list
    test_loss = sess.run(loss, feed_dict={x: test_x, y_: test_y})
    if i % 10 == 0:
        print('Step: {}, Train Loss: {}, Test Loss: {}'.format(i, train_loss, test_loss))

    # Log this step's summaries with the single writer created above
    summary_writer.add_summary(summary, i)

# Flush pending events to disk and release resources
summary_writer.close()
sess.close()
```
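To see what the optimizer is doing without TensorFlow, the same mini-batch update can be sketched in plain NumPy. This is an illustrative re-implementation under stated assumptions, not the author's code: it uses NumPy's `default_rng` with a fixed seed instead of the global `np.random` state, uses `sign(w*x - y) * x` averaged over the batch as the (sub)gradient of the mean absolute error, and keeps a `history` list of `(step, w, loss)` tuples as a stand-in for the `w_est` and `Loss` summaries. `train_mae_sgd` is a hypothetical helper name.

```python
import numpy as np


def train_mae_sgd(steps=1000, batch_size=50, lr=0.001, seed=0):
    """Mini-batch (sub)gradient descent on mean |w*x - y|, logging w and the loss."""
    rng = np.random.default_rng(seed)

    # Same kind of synthetic data as the post: y = 2x + Gaussian noise (sigma = 25)
    x_data = np.arange(1000) / 10.0
    y_data = 2.0 * x_data + rng.normal(0.0, 25.0, size=1000)
    train_idx = rng.choice(1000, size=900, replace=False)
    train_x, train_y = x_data[train_idx], y_data[train_idx]

    w = rng.normal()  # random initial weight, like tf.random_normal([1])
    history = []      # (step, w, batch loss): what the summaries would record

    for step in range(steps):
        idx = rng.choice(len(train_x), size=batch_size)  # sampling with replacement
        xs, ys = train_x[idx], train_y[idx]
        residual = w * xs - ys
        loss = np.mean(np.abs(residual))
        # Subgradient of mean|w*x - y| with respect to w: mean(sign(w*x - y) * x)
        grad = np.mean(np.sign(residual) * xs)
        history.append((step, w, loss))  # logged before this step's update
        w -= lr * grad

    return w, history


w_final, history = train_mae_sgd()
print(w_final)
```

With these settings `w_final` should end up close to 2 and the batch losses in the later steps should hover around 20, mirroring the two TensorBoard curves; the exact numbers depend on the seed.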

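Over the 1000 training steps (each step uses one mini-batch of 50 samples), the estimated weight gradually approaches its true value of 2: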

[Figure: estimated weight w_est vs. training step in TensorBoard]
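Strictly speaking, gradient descent on this loss heads toward the w that minimizes the mean absolute error on the training data, which is close to, but generally not exactly, 2. For the one-parameter model y = w·x this minimizer has a closed form: for x_i ≠ 0, |y_i - w·x_i| = |x_i|·|y_i/x_i - w|, so the minimizer is a weighted median of the ratios y_i/x_i with weights |x_i| (points with x_i = 0 only add a constant |y_i|). The sketch below computes it; `weighted_median` and `lad_slope` are hypothetical helper names, not part of TensorFlow or NumPy.

```python
import numpy as np


def weighted_median(values, weights):
    """Smallest value v whose cumulative weight (in sorted order) reaches half the total."""
    order = np.argsort(values)
    v, wt = values[order], weights[order]
    cum = np.cumsum(wt)
    return v[np.searchsorted(cum, 0.5 * cum[-1])]


def lad_slope(x, y):
    """A minimizer of sum |y - w*x| over w for the no-intercept model y = w*x."""
    mask = x != 0  # x == 0 contributes |y|, which does not depend on w
    return weighted_median(y[mask] / x[mask], np.abs(x[mask]))


# Example on data generated like the post's (fixed seed for reproducibility)
rng = np.random.default_rng(1)
x = np.arange(1000) / 10.0
y = 2.0 * x + rng.normal(0.0, 25.0, size=1000)
print(lad_slope(x, y))
```

Applied to `train_x` and `train_y` from the script, `lad_slope` gives the value the `w_est` curve should settle near, up to the small oscillation caused by the fixed learning rate and random mini-batches.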

Over the same 1000 steps, the loss gradually levels off at around 20:

[Figure: loss vs. training step in TensorBoard]
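The plateau near 20 is what the noise level predicts. Once w is close to 2, each residual is essentially the Gaussian noise ε ~ N(0, 25²) added to `y_data`, and for a zero-mean normal variable E|ε| = σ·√(2/π) ≈ 0.798 × 25 ≈ 19.95. The short check below compares that formula with a Monte Carlo estimate; it is an illustration added here, with an arbitrary seed and sample size, not part of the original script.

```python
import math

import numpy as np

sigma = 25.0  # noise scale used to generate y_data in the script
analytic = sigma * math.sqrt(2.0 / math.pi)  # E|eps| for eps ~ N(0, sigma^2)

# Monte Carlo estimate of the same expectation
rng = np.random.default_rng(0)
eps = rng.normal(0.0, sigma, size=1_000_000)
monte_carlo = np.mean(np.abs(eps))

print(f"analytic: {analytic:.3f}, Monte Carlo: {monte_carlo:.3f}")
```

Because the loss is a mean absolute error rather than a squared error, it levels off near σ·√(2/π) instead of near σ or σ².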