神经网络入门(7)
2019-07-13 本文已影响12人
zidea
MachineLearninginMarketing
# 形状 [307]
x = tf.placeholder(tf.float32, [None, 3072])
# [None]
y = tf.placeholder(tf.int64, [None])
# get_variable 表示如果已经定义了 w 就使用定义好的 w,如果没有定义新建一个变量 w
# (3072,1)
w = tf.get_variable('w', [x.get_shape()[-1], 1], initializer=tf.random_normal)
# (1,)
b = tf.get_variable('b', [1], initializer=tf.constant_initializer(0.0))
# y_ 将 w * x + b 现在
# [None,3072],[3072,1] = [None,1] (None,1)
y_ = tf.matmul(x, w) + b
# y_ 还只是一个内积值,我们可以将其变成一个概率值,变成概率值的方法是使用函数 sigmoid 中对其进行压缩
# [None,1]
p_y_1 = tf.nn.sigmoid(y_)
# 得到概率为 1 的值就可以和真正 y 进行差别分析,以为 y 的形状(None) 和 p_y_1(None,1)不一样需要进行形状修改
y_reshaped = tf.reshape(y, (-1, 1))
# 以为在 tensorFlow 对数据类型比较敏感,我们需要将 y_resphapded 类型从 int64 修改Wie float32
y_reshaped_float = tf.cast(y_reshaped, tf.float32)
# reduce_mean 是就均值而 square 是求平方
loss = tf.reduce_mean(tf.square(y_reshaped_float - p_y_1))
# 预测值通过将 p_y_1 和 0.5 进行比较得到 true 或 false 来表预测值
predict = p_y_1 > 0.5
# [1,0,1,1,0,0,1]
correct_prediction = tf.equal(tf.cast(predict, tf.int64), y_reshaped)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
class CifarData:
def __init__(self,filenames, need_shuffle):
all_data = []
all_labels = []
for filename in filenames:
data, labels = load_data(filename)
for item, label in zip(data, labels):
if label in [0,1]:
all_data.append(item)
all_labels.append(label)
self._data = np.vstack(all_data)
self._labels = np.hstack(all_labels)
self._num_examples = self._data.shape[0]
print(self._num_examples)
self._need_shuffle = need_shuffle
self._indicator = 0
if self._need_shuffle:
self._shuffle_data()
定义一个类 CifarData 来控制数据,need_shuffle 作为一个控制是否对数据进行重新排序(洗牌)的标识,当我们处理训练数据集时候可以通过开启 need_shuffle 来得到更多随机样本,而对于测试数据集则不会开启该开关。
all_data = []
all_labels = []
for filename in filenames:
data, labels = load_data(filename)
使用之间的 data_load 方法将文件中数据加载进来。
for item, label in zip(data, labels):
if label in [0,1]:
all_data.append(item)
all_labels.append(label)
因为我们处理二分类的数据集所有通过过滤得到标签为 0 或 1 的数据集
self._data = np.vstack(all_data)
self._labels = np.hstack(all_labels)
通过合并后转换为 numpy 中的矩阵,vstack 将按纵向进行合并形成一个矩阵,而 hstack 则是按横向进行合并形成矩阵。
self._num_examples = self._data.shape[0]
定义有多少个向量,然后就是定义 suffle 函数,
def _shuffle_data(self):
p = np.random.permutation(self._num_examples)
self._data = self._data[p]
self._labels = self._labels[p]
首先通过 random.permutation 得到一个排列,这个函数从 0 到 _num_examples 进行一个混排,然后用 p 对 data 和 p 集合进行洗牌。
def next_batch(self,batch_size):
end_indicator = self._indicator + batch_size
if end_indicator > self._num_examples:
if self._need_shuffle:
self._shuffle_data()
else:
raise Exception("have no more examples")
if end_indicator > self._num_examples:
raise Exception("batch size is larger than all examples")
batch_data = self._data[self._indicator:end_indicator]
batch_labels = self._labels[self._indicator:end_indicator]
self._indicator = end_indicator
return batch_data,batch_labels
定义 next_batch 样本,会返回 batch_size 个样本,
end_indicator = self._indicator + batch_size
if end_indicator > self._num_examples:
if self._need_shuffle:
self._shuffle_data()
self._indicator = 0
end_indicator = batch_size
else:
raise Exception("have no more examples")
如果 end_indicator 大于 self._num_examples 表示我们取值样本的截止位置超出了样本数,这时如果是训练数据集就需要重新洗牌然后继续获取数据,但如果不是训练数据集则抛出一个异常。
batch_data = self._data[self._indicator:end_indicator]
batch_labels = self._labels[self._indicator:end_indicator]
self._indicator = end_indicator
return batch_data, batch_labels
将这batch_size 间数据返回去。
self._data = np.vstack(all_data)
self._labels = np.hstack(all_labels)
# 测试
print(self._data.shape)
print(self._labels.shape)
self._num_examples = self._data.shape[0]
在这个位置输出一下,测试一下我们创建好的类是否正常工作
train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i)
for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
train_data = CifarData(train_filenames, True)
我们知道训练数据集应该有 50000 样本以为每一个 data_batch 有一个 10000 样本,而又 0 - 9 十个类别(也就是图片的类别)而因为过滤为 0,1 所以只要 10000 个数据
(10000, 3072)
(10000,)
batch_data, batch_labels = train_data.next_batch(10)
print(batch_data, batch_labels)
[[208 186 128 ... 100 97 97]
[ 55 59 65 ... 55 52 52]
[223 223 226 ... 61 58 52]
...
[160 111 71 ... 48 48 51]
[105 105 105 ... 50 50 49]
[252 248 248 ... 93 98 97]] [1 0 1 1 0 0 1 1 0 1]