TensorFlow2简单入门-图像加载及预处理
2021-01-16 本文已影响0人
K同学啊
作者:明天依旧可好
下载数据
import tensorflow as tf
import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
"""
输出:
C:\Users\Administrator\.keras\datasets\flower_photos
"""
可以通过C:\Users\Administrator.keras\datasets\flower_photos路径查找到下载的文件
#查看数据目录
for item in data_root.iterdir():
print(item)
"""
输出:
C:\Users\Administrator\.keras\datasets\flower_photos\daisy
C:\Users\Administrator\.keras\datasets\flower_photos\dandelion
C:\Users\Administrator\.keras\datasets\flower_photos\LICENSE.txt
C:\Users\Administrator\.keras\datasets\flower_photos\roses
C:\Users\Administrator\.keras\datasets\flower_photos\sunflowers
C:\Users\Administrator\.keras\datasets\flower_photos\tulips
"""
flower_photos文件夹下包括5个文件夹和一个说明文件,5个文件夹中分别放有5个类别的数据(即对应着5种不同的标签。)
import random
#获取所有图片的路径
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
#将所有路径打乱
random.shuffle(all_image_paths)
image_count = len(all_image_paths)
image_count
"""
输出:3670
"""
all_image_paths[:3]
"""
输出:
['C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\daisy\\11870378973_2ec1919f12.jpg',
'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\roses\\8442304572_2fdc9c7547_n.jpg',
'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\dandelion\\17574213074_f5416afd84.jpg']
"""
检查图片
from PIL import Image
import os
train_images = []
for image in all_image_paths[]:
train_images.append(Image.open(os.path.join(image)))
将图片与标签同步从本地文件中拿出来。
import matplotlib.pyplot as plt
train_labels = [pathlib.Path(path).parent.name for path in all_image_paths]
plt.figure(figsize=(20,10))
for i in range(20):
plt.subplot(5,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(train_images[i])
plt.xlabel(train_labels[i])
plt.show()
构建一个 tf.data.Dataset
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_root,
validation_split=0.2,
subset="training",
seed=123,
image_size=(192, 192),
batch_size=20)
class_names = train_ds.class_names
print("\n",class_names)
train_ds
"""
输出:
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
<BatchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory()
:将创建一个从本地目录读取图像数据的数据集。数据集对象可以直接传递到fit(),也可以在自定义低级训练循环中进行迭代。
import matplotlib.pyplot as plt
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
for i in range(20):
ax = plt.subplot(5, 10, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
-
dataset.take(1)
:取第一个元素构建dataset(是第一个元素,不是随机的一个),从文件中读取数据形成train_ds时是以为20为一个步长的,故这里的dataset.take(1)即前20个数据。 -
dataset.skip(2)
:跳过前2个元素后构建的dataset
for image_batch, labels_batch in train_ds:
print(image_batch.shape)
print(labels_batch.shape)
break
"""
输出:
(30, 192, 192, 3)
(30,)
"""