26 | 使用PyTorch完成医疗图像识别大项目:分割模型实训
开始训练模型之前,我们需要先把之前的标注文件清理好。如下是原作给出的代码示例。
#在这个引入部分,有一个新的pylidc包需要安装,使用pip安装即可。这个包是专门用来处理LIDC数据集的,现在用的LUNA数据集就是在这个基础上加工的,关于这个包的说明很简单:A library for working with the LIDC dataset.
import torch
import SimpleITK as sitk
import pandas
import glob, os
import numpy
import tqdm
import pylidc
安装完之后,首先读取原来的标注文件。这个文件里记录了1000多个结节的坐标和直径信息。
annotations = pandas.read_csv('D:/lunadata/annotations.csv')
然后对我们的数据进行扫描,记录恶性数据,是否有缺失数据等等
malignancy_data = []
missing = []
spacing_dict = {}
scans = {s.series_instance_uid:s for s in pylidc.query(pylidc.Scan).all()}
suids = annotations.seriesuid.unique()
for suid in tqdm.tqdm(suids):
fn = glob.glob('D:/lunadata/subset*/{}.mhd'.format(suid))
if len(fn) == 0 or '*' in fn[0]:
missing.append(suid)
continue
fn = fn[0]
x = sitk.ReadImage(fn)
spacing_dict[suid] = x.GetSpacing()
s = scans[suid]
for ann_cluster in s.cluster_annotations():
is_malignant = len([a.malignancy for a in ann_cluster if a.malignancy >= 4])>=2
centroid = numpy.mean([a.centroid for a in ann_cluster], 0)
bbox = numpy.mean([a.bbox_matrix() for a in ann_cluster], 0).T
coord = x.TransformIndexToPhysicalPoint([int(numpy.round(i)) for i in centroid[[1, 0, 2]]])
bbox_low = x.TransformIndexToPhysicalPoint([int(numpy.round(i)) for i in bbox[0, [1, 0, 2]]])
bbox_high = x.TransformIndexToPhysicalPoint([int(numpy.round(i)) for i in bbox[1, [1, 0, 2]]])
malignancy_data.append((suid, coord[0], coord[1], coord[2], bbox_low[0], bbox_low[1], bbox_low[2], bbox_high[0], bbox_high[1], bbox_high[2], is_malignant, [a.malignancy for a in ann_cluster]))
这里能看到处理的进度条。
image.png
其中的miss用来记录是否有源文件(mhd文件损坏或者缺失),好在我这里是0缺失的。用原始数据信息去匹配我们读取的数据,
df_mal = pandas.DataFrame(malignancy_data, columns=['seriesuid', 'coordX', 'coordY', 'coordZ', 'bboxLowX', 'bboxLowY', 'bboxLowZ', 'bboxHighX', 'bboxHighY', 'bboxHighZ', 'mal_bool', 'mal_details'])
processed_annot = []
annotations['mal_bool'] = float('nan')
annotations['mal_details'] = [[] for _ in annotations.iterrows()]
bbox_keys = ['bboxLowX', 'bboxLowY', 'bboxLowZ', 'bboxHighX', 'bboxHighY', 'bboxHighZ']
for k in bbox_keys:
annotations[k] = float('nan')
for series_id in tqdm.tqdm(annotations.seriesuid.unique()):
# series_id = '1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860'
c = candidates[candidates.seriesuid == series_id]
a = annotations[annotations.seriesuid == series_id]
m = df_mal[df_mal.seriesuid == series_id]
if len(m) > 0:
m_ctrs = m[['coordX', 'coordY', 'coordZ']].values
a_ctrs = a[['coordX', 'coordY', 'coordZ']].values
#print(m_ctrs.shape, a_ctrs.shape)
matches = (numpy.linalg.norm(a_ctrs[:, None] - m_ctrs[None], ord=2, axis=-1) / a.diameter_mm.values[:, None] < 0.5)
has_match = matches.max(-1)
match_idx = matches.argmax(-1)[has_match]
a_matched = a[has_match].copy()
# c_matched['diameter_mm'] = a.diameter_mm.values[match_idx]
a_matched['mal_bool'] = m.mal_bool.values[match_idx]
a_matched['mal_details'] = m.mal_details.values[match_idx]
for k in bbox_keys:
a_matched[k] = m[k].values[match_idx]
processed_annot.append(a_matched)
processed_annot.append(a[~has_match])
else:
processed_annot.append(c)
processed_annot = pandas.concat(processed_annot)
processed_annot.sort_values('mal_bool', ascending=False, inplace=True)
processed_annot['len_mal_details'] = processed_annot.mal_details.apply(len)
我这块的输出显示没有需要丢掉的数据,那看起来LUNA数据集里提供的数据已经更新了。
image.png
最后把这个文件保存下来。
df_nona = processed_annot.dropna()
df_nona.to_csv('./data/part2/luna/annotations_with_malignancy.csv', index=False)
已经生成新的标注数据。
下面开始执行训练代码。首先还是创建缓存,结果这里遇到一个问题,代码接收的参数有问题,
在13章dset.py的49行,isMal_bool = {'False': False, 'True': True}[row[5]]
但实际上我们的文件里这一列存的是0.0和1.0,导致读取异常,把这里改成如下就能正常运行了。
isMal_bool = {'0.0': False, '1.0': True}[row[5]],接着启动缓存建设。
run('test13ch.prepcache.LunaPrepCacheApp')
在shell里面运行训练环节。
> python -m test13ch.training --epoch 20 --augmented final_seg
结果训练了一个epoch就内存溢出了。无奈,把batch size调小一点,从16改成了8个,这次就没问题了,我的设备还是不太行,真想买一台双3090Ti卡的机器。
image.png
这次看看效果。这里列出了第1,5,10,15,20个epoch的结果,可看到第1个epoch不管在训练集还是验证集的精确度很低,召回率还可以,在验证集上的fp(假阳性)达到了2442.7%,这主要是因为训练集使用的是裁剪后的小图片,而验证集使用的是完整的CT切片数据,所以假阳性很高也正常,多给出一些结果再让医生去看总比漏掉要好的多。
image.png
到了第5个epoch,精确度有所提升,训练集的f1达到0.71了
image.png
到了第10个epoch,又提升了一点点,但是验证集上给出的tp有些下降。
image.png
到15个epoch,在训练集上的效果持续提升,但是在验证集上的效果下降明显,tp值以及到了79%,说明这个时候已经出现了过拟合现象。
image.png image.png
下面去TensorBoard上去看看效果。蓝色是训练集,红色是验证集,首先是损失情况,在训练集上前期损失下降比较快,后面就比较平缓,在验证集上的损失变化不大。
image.png
然后是fn,fp,tp指
image.png
最后是f1 score,精确度,召回率,可以看到在经过了几个epoch之后验证集的召回率开始下降,出现了过拟合现象。
image.png
最后看一看导入TensorBoard的图像效果。带有label_x的表示这是一个标注图像,上面没有颜色的表名这个图像上都是无标注的,在对应的预测结果上,有一些橙色结果是假阳性预测,对于下面带绿色就是阳性标注及阳性预测结果。
image.png
image.png
image.png
image.png
看起来效果还不错,我们的这个模型就先训到这里,基本上可以满足我们的需求了。