fastai深度学习官方教程代码笔记Lesson1

2019-05-17  本文已影响0人  阿垃垃圾君

最近开始学习fastai,感觉这是一个对于初学者比较友好的库,官方提供了详细的视频教程和代码。
这个笔记基于我对于课程内容和代码的理解,对一些代码的使用进行一些说明,也是为了帮助我更好的理解课程的内容。
代码笔记会按照课程来完成,每课一篇。

课程代码:https://course.fast.ai/start_kaggle.html

首先使用resNet34网络进行训练和测试

#下面的代码用于修改代码后自动重载%aimport下的模块

%reload_ext autoreload

%autoreload 2

#这句代码一般用于jupyter,是的matplotlib直接在python的console下显示图像

%matplotlib inline

from fastai import *

from fastai.vision import *

#这个参数用于在之后将数据分批的时候设置每个批次的大小,当内存不足时可以将这个值减小

bs = 64

# bs = 16  # uncomment this line if you run out of memory even after clicking Kernel->Restart

#help可以查看具体方法的参数和作用

help(untar_data)

#从url下载数据并解压

path = untar_data(URLs.PETS); path

path.ls()

#分别获取图片路径和注解路径

path_anno = path/'annotations'

path_img = path/'images'

#从图片路径读取图片文件

fnames = get_image_files(path_img)

fnames[:5]

#这里是设置随机种子数,应该是为了保证每次训练的参数一致

np.random.seed(2)

#从文件名获取图片标签的正则

pat = re.compile(r'/([^/]+)_\d+.jpg$')

# 将图片根据文件名标记,并通过get_transforms()函数进行变换,使得图片归一化为224大小的图片

# 正规化图片,使得图片的个个通道下的数值基于0到255且不过亮或过暗

data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs, num_workers=0).normalize(imagenet_stats)

data.show_batch(rows=3, figsize=(7,6))

print(data.classes)

#data.c返回数据的类别数

len(data.classes),data.c

#创建resnet34模型的cnn网络

learn = create_cnn(data, models.resnet34, metrics=error_rate)

#fit_one_cycle是一种通过逐渐增大学习率进行学习的方式,这里传入的数字是学习的轮数epoch

learn.fit_one_cycle(4)

#保存模型

learn.save('stage-1')

#获取最多分类错误的数据的混淆矩阵

interp = ClassificationInterpretation.from_learner(learn)

#获取分错的最多的数据以及这些数据的

losses,idxs = interp.top_losses()

len(data.valid_ds)==len(losses)==len(idxs)

#展示预测错误最多的图片

interp.plot_top_losses(9, figsize=(15,11), heatmap=False)

#显示函数的文档

doc(interp.plot_top_losses)

#展示预测混淆矩阵

interp.plot_confusion_matrix(figsize=(12,12), dpi=60)

#展示分错最多的混淆矩阵

interp.most_confused(min_val=2)

#将训练好的模型的参数解冻,使其可以再次训练

learn.unfreeze()

learn.fit_one_cycle(1)

#读取之前训练的模型

learn.load('stage-1');

#寻找学习率,展示学习率与误差的关系,用以寻找较好的学习率区间

learn.lr_find()

learn.recorder.plot()

#尝试使用误差较小的学习率区间

learn.unfreeze()

learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))

下面采用resNet50网络进行训练

#重新构建数据集,尺寸大小为299,每个批次的尺寸为原来的一般,并且正规化图像
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),
                                   size=299, bs=bs//2, num_workers=0).normalize(imagenet_stats)

#构建resNet50层的网络
learn = create_cnn(data, models.resnet50, metrics=error_rate)

#寻找学习率,并展示不同学习率下的结果
learn.lr_find()
learn.recorder.plot()

#进行8轮学习
learn.fit_one_cycle(8)

#保存当前模型
learn.save('stage-1-50')

#将参数解冻,并且使用调整过的学习率进行3轮训练
learn.unfreeze()
learn.fit_one_cycle(3, max_lr=slice(1e-6,1e-4))

#读取之前保存的模型
learn.load('stage-1-50');

#获取分类结果中错误程度最严重的结果的混淆矩阵,。展示混淆矩阵中错误超过2个的分错统计,比如将灰熊(tag=1)分成黑熊(tag=2)的有10张,就有[1,2,10]
interp = ClassificationInterpretation.from_learner(learn)
interp.most_confused(min_val=2)

最后还介绍了其他数据格式构建数据集的方法,课程中用了手写数字的数据集MNIST

#首先下载并解压数据集
path = untar_data(URLs.MNIST_SAMPLE); path

#从文件夹获取数据集,对数据集进行变换,尺寸为26
tfms = get_transforms(do_flip=False)
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=26, num_workers=0)

#显示变换后的数据
data.show_batch(rows=3, figsize=(5,5))

#创建resnet18网络,并且使用默认的学习率训练2轮
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit(2)

#读取标签csv文件,并显示前五行
df = pd.read_csv(path/'labels.csv')
df.head()

#从csv文件获取数据集,并且进行变换,尺寸为28
data = ImageDataBunch.from_csv(path, ds_tfms=tfms, size=28, num_workers=0)

#显示前三行数据,并显示所有类别
data.show_batch(rows=3, figsize=(5,5))
data.classes

#从df数据文件中获取数据集,变换尺寸为24,并展示类别
data = ImageDataBunch.from_df(path, df, ds_tfms=tfms, size=24, num_workers=0)
data.classes

#从df文件中查找文件路径,并且保存在fn_paths中
fn_paths = [path/name for name in df['name']]; fn_paths[:2]

#定义正则,并且通过文件名匹配正则获取数据集,并且变换为尺寸24,最后显示类别
pat = r"/(\d)/\d+\.png$"
data = ImageDataBunch.from_name_re(path, fn_paths, pat=pat, ds_tfms=tfms, size=24, num_workers=0)
data.classes

#通过lambda方程来匹配文件名构建数据集,尺寸变换为24,并显示类别
data = ImageDataBunch.from_name_func(path, fn_paths, ds_tfms=tfms, size=24,
        label_func = lambda x: '3' if '/3/' in str(x) else '7', num_workers=0)
data.classes

#通过文件名来标记数据的标签
labels = [('3' if '/3/' in str(x) else '7') for x in fn_paths]
labels[:5]

#通过给定的标签列表和文件路径列表构建对应标签的数据集,尺寸变换为24,最后显示类别
data = ImageDataBunch.from_lists(path, fn_paths, labels=labels, ds_tfms=tfms, size=24, num_workers=0)
data.classes
上一篇下一篇

猜你喜欢

热点阅读