入门 tensorflow(二)

2017-08-03  本文已影响28人  Kean_L_C

线性回归

张量（tensor）的定义：
TensorFlow operations (also called ops for short) can take any number of inputs and produce any number of outputs. For example, the addition and multiplication ops each take two inputs and produce one output. Constants and variables take no input (they are called source ops). The inputs and outputs are multidimensional arrays, called tensors (hence the name “tensor flow”).

# encoding: utf-8
"""
@version: python3.5.2 
@author: kaenlee  @contact: lichaolfm@163.com
@software: PyCharm Community Edition
@time: 2017/8/3 11:33
purpose:
"""
import numpy as np
import pandas as pd
import tensorflow as tf
import sklearn.datasets as dt


# Load the California housing data: `x` holds the features, `y` the target.
housing = dt.fetch_california_housing()
x, y = housing.data, housing.target
m, n = x.shape  # m rows (samples), n feature columns
print(m, n)

# Prepend a column of ones so that theta[0] plays the role of the intercept.
housing_bias = np.column_stack((np.ones(shape=(m, 1)), x))
# Reshape the target into an (m, 1) column vector for matrix products.
y = y.reshape(-1, 1)
# Input tensors: the bias-augmented feature matrix X and the target column Y.
X = tf.constant(housing_bias, dtype=tf.float32, name="X")
Y = tf.constant(y, dtype=tf.float32, name="Y")
XT = tf.transpose(X)

# Closed-form least-squares solution (the normal equation) for theta:
#     theta = (X^T X)^{-1} X^T Y
# Solving the linear system (X^T X) theta = X^T Y is numerically more stable
# in float32 than forming the explicit inverse and multiplying by it.
theta = tf.matrix_solve(tf.matmul(XT, X), tf.matmul(XT, Y),
                        name="theta_normal_eq")

# The `with` block closes the session on exit; no explicit close() needed.
with tf.Session() as sess:
    theta_value = sess.run(theta)
    print(theta_value)

# Solve the same regression with batch gradient descent.
#
# With raw, unscaled features, a fixed learning rate of 0.01 makes the
# updates diverge and theta becomes NaN. Standardizing each feature column
# to zero mean and unit variance keeps the descent stable.
# NOTE: theta learned here is expressed in the standardized feature space,
# so it is not directly comparable to the normal-equation result above.
x_mean = x.mean(axis=0)
x_std = x.std(axis=0)
x_std[x_std == 0] = 1.0  # guard constant columns against division by zero
# The bias column of ones is added after scaling so it is left untouched.
scaled_housing_bias = np.c_[np.ones(shape=(m, 1)), (x - x_mean) / x_std]
X_scaled = tf.constant(scaled_housing_bias, dtype=tf.float32,
                       name="X_scaled")

learning_rate = 0.01
n_epochs = 1000

# Random initial values for theta, drawn uniformly from [-1, 1).
theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0), name="theta")
y_pred = tf.matmul(X_scaled, theta, name="predictions")
# Use the float32 target tensor Y, not the float64 NumPy array y.
error = y_pred - Y
mse = tf.reduce_mean(tf.square(error), name="mse")
# Gradient of the MSE w.r.t. theta: (2 / m) * X^T (X theta - y).
# 2.0 forces float division even under Python 2.
gradients = 2.0 / m * tf.matmul(tf.transpose(X_scaled), error)
training_op = tf.assign(theta, theta - learning_rate * gradients)  # update step

# Equivalent alternative using TensorFlow's built-in optimizer:
# optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# training_op = optimizer.minimize(mse)

init = tf.global_variables_initializer()
# The `with` block closes the session on exit; no explicit close() needed.
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        if epoch % 100 == 0:
            print("Epoch", epoch, "MSE =", mse.eval())
        sess.run(training_op)
    best_theta = theta.eval()
    print(best_theta)

PS：直接用原始特征做梯度下降时，结果会变成 NaN。原因是各特征未做标准化、量纲差异很大，在固定学习率 0.01 下迭代发散。解决办法是先将每个特征标准化（减去均值、除以标准差），或者显著减小学习率。

上一篇 下一篇

猜你喜欢

热点阅读