Learning Tensorflow part 2

2020-08-21  本文已影响0人  轻骑兵1390

说明一些tensorflow课程中的用法.

1. 加载数据

def normalize(images, labels):
  images = tf.cast(images, tf.float32)
  images /= 255
  return images, labels

train_dataset =  train_dataset.map(normalize)
train_dataset =  train_dataset.cache()

normalize将图片数据由[0, 255]归一化到[0, 1].

2. 构造神经网络

l0 = tf.keras.layers.Flatten(input_shape = (28, 28, 1))

3. 构造误差函数

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

4. 训练

BATCH_SIZE = 32
train_dataset = train_dataset.cache().repeat().shuffle(num_train_examples).batch(BATCH_SIZE)
test_dataset = test_dataset.cache().batch(BATCH_SIZE)

model.fit(train_dataset, epochs=5, steps_per_epoch=math.ceil(num_train_examples/BATCH_SIZE))

这里要打乱原有的数据

上一篇 下一篇

猜你喜欢

热点阅读