TensorFlow2简单入门

TensorFlow2简单入门-图像加载及预处理

2021-01-16  本文已影响0人  K同学啊

作者:明天依旧可好


下载数据

import tensorflow as tf

import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
"""
输出:
C:\Users\Administrator\.keras\datasets\flower_photos
"""

可以通过C:\Users\Administrator.keras\datasets\flower_photos路径查找到下载的文件

#查看数据目录
for item in data_root.iterdir():
    print(item)
"""
输出:
C:\Users\Administrator\.keras\datasets\flower_photos\daisy
C:\Users\Administrator\.keras\datasets\flower_photos\dandelion
C:\Users\Administrator\.keras\datasets\flower_photos\LICENSE.txt
C:\Users\Administrator\.keras\datasets\flower_photos\roses
C:\Users\Administrator\.keras\datasets\flower_photos\sunflowers
C:\Users\Administrator\.keras\datasets\flower_photos\tulips
"""

flower_photos文件夹下包括5个文件夹和一个说明文件,5个文件夹中分别放有5个类别的数据(即对应着5种不同的标签。)

import random
#获取所有图片的路径
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
#将所有路径打乱
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
image_count
"""
输出:3670
"""
all_image_paths[:3]
"""
输出:
['C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\daisy\\11870378973_2ec1919f12.jpg',
 'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\roses\\8442304572_2fdc9c7547_n.jpg',
 'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\dandelion\\17574213074_f5416afd84.jpg']
"""

检查图片

from PIL import Image
import os

train_images = []
for image in all_image_paths[]:
    train_images.append(Image.open(os.path.join(image)))

将图片与标签同步从本地文件中拿出来。

import matplotlib.pyplot as plt

train_labels  = [pathlib.Path(path).parent.name for path in all_image_paths]

plt.figure(figsize=(20,10))
for i in range(20):
    plt.subplot(5,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    plt.xlabel(train_labels[i])
plt.show()

构建一个 tf.data.Dataset

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_root,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(192, 192),
  batch_size=20)

class_names = train_ds.class_names
print("\n",class_names)

train_ds
"""
输出:
Found 3670 files belonging to 5 classes.
Using 2936 files for training.

 ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
<BatchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
"""

train_ds = tf.keras.preprocessing.image_dataset_from_directory():将创建一个从本地目录读取图像数据的数据集。数据集对象可以直接传递到fit(),也可以在自定义低级训练循环中进行迭代。

import matplotlib.pyplot as plt

plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
"""
输出:
(30, 192, 192, 3)
(30,)
"""
上一篇下一篇

猜你喜欢

热点阅读