基于视觉的多分辨率地图构建与定位程序说明

2019-01-02  本文已影响0人  Omar_4321

一.安装

库:

安装numpy、matplotlib、sklearn、scipy、PIL、opencv、pickle、pytorch(高于等于0.4)
代码在CycleGAN and pix2pix in PyTorch基础上编写。
Python版本为3.6(使用3.5和3.7也能运行),在Windows和Ubuntu下都能运行,windows下可能会报Lambda表达式打包的错误。

文件目录结构:

C:\CODE_P
├─alexnet 中间层特征可视化
├─Checkpoints 保存训练模型
├─Cluster 聚类相关程序
│ │ cluster_img.py 对图像聚类
│ │ gen_cell.py 生成Cell,主要为Cell后处理程序
├─Data 数据准备模块
│ │ base_dataset.py 基类,在送入网络前进行处理
│ │ base_data_loader.py 基类
│ │ Dataset_gather.py 根据不同参数调用不同的数据的具体实现
│ │ data_loader.py 数据下载类,根据不同参数调用不同的数据
│ │ Data_manage.py 数据管理, 读取、生成路径等
├─figure 保存图表
├─label 保存所有label
├─Models 网络模型相关
│ │ base_model.py 基类
│ │ double_threads.py 双线程示例
│ │ layers_trans.py 替换字符串string中指定位置p的字符为c,用于批量转换模型各层的名字
│ │ models.py 根据参数进行模型选择
│ │ model_set.py 网络调用、优化函数定义、前向和反向传播及损失值计算
│ │ networks.py 网络模型定义
│ │ resnet_layer_trans.txt

├─Options
│ │ options_set.py 参数定义
├─pre_model_state_dict
│ resnet18-5c106cde.pth 预训练模型
├─Result 结果
├─runs
│ │ 无视这个文件夹 │
├─util 计算图像中值
│ compute_image_mean.py│
├─Visualization 可视化相关程序
│ ├─test
└─

二.总体流程

下图是总体流程,三个部分分别为 特征提取、构建Cell和训练定位网络并定位三个部分

程序流程图

主程序

opt = BaseOptions().parse()#导入配置参数
clu = clusterdata()#实例化clusterdata类
datareader = dataread(opt)#实例化数据读取类
[gps_x,gps_y] = datareader.get_gps()#读取数据的GPS信息
c = dataset.num_img#各个车道的图像数量,左中右对应c[0]、c[1]、c[2]
ll2 = clu.cluster_sequence(length,200)#根据图像序列平均划分200个cell
ll3 = clu.cluster_sequence(length,600)
ll4 = clu.cluster_sequence(length,900)

#标注三个车道的图像为0、1、2
three_cla = numpy.zeros(length,dtype=int)
three_cla[0:c[0]] = 0
three_cla[c[0]:c[0]+c[1]] = 1
three_cla[c[0]+c[1]:c[0]+c[1]+c[2]] = 2
three_l = numpy.array(three_cla)
#numpy.savetxt('3.txt',three_l,fmt='%d')

#平均划分的CELL标号写入txt
f =open('label/seq200.txt','w')  
for j in range(len(img_dir)):
    text = str(img_dir[j][37:]) + ' ' + str(int(ll2[j]))                                
    f.write(text)
    f.write('\n')
f.close()
f =open('label/seq600.txt','w')  
for j in range(len(img_dir)):
    text = str(img_dir[j][37:]) + ' ' + str(int(ll3[j]))                                
    f.write(text)
    f.write('\n')
f.close()
f =open('label/seq900.txt','w')  
for j in range(len(img_dir)):
    text = str(img_dir[j][37:]) + ' ' + str(int(ll4[j]))                                
    f.write(text)
    f.write('\n')
f.close()
#将数据随机划分为无序的训练集和测试集
split(opt,'seq200')
split(opt,'seq600')
split(opt,'seq900')

train_extract_features(opt.num_outputs)#训练特提取网络
extract_features(opt.num_outputs)#使用训练好的网络提取图像特征
label = clu.clu_features(900)#根据新特征聚类
numpy.save('clu_900.npy',label)#保存聚类结果#label = numpy.load('clu_900.npy')#下载聚类结果
#显示聚类结果的柱形图
plt.figure(2)
plt.bar(numpy.arange(len(label)),label,width = 1)
plt.show()
#实例化生成Cell的类
cell_gen = CELL(label,900,100,0.16,0.5,opt)  #the third para should < 0.25,else all cells will be 0
[cells_num,lane_cells_count] = cell_gen.gen_cell()#生成Cell,即对聚类结果进行后处理

#cells_num = 631

train_clustered_cell(cells_num)#训练定位网络
single_localization(cells_num)#单张图片定位

train_cells(cells_num)#分层定位网络训练
localization(cells_num)#分层定位网络单张图片定位

Data模块

CreateDataLoader函数准备数据

    train_str = 'train'
    test_str = 'test'
    dataset_train = CreateDataLoader(opt,train_str,isTrain = True)
    dataset_test = CreateDataLoader(opt,test_str,isTrain = False)

CreateDataLoader,实例化CustomDatasetDataLoader类并进行初始化

def CreateDataLoader(opt,phase,isTrain):
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt,phase,isTrain)
    return data_loader

CustomDatasetDataLoader类实现如下:重写BaseDataLoader类,并在初始化时通过调用CreateDataset函数选择数据集,再定义torch.utils.data.DataLoader中的参数,如batch大小,是否给顺序等

class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt,phase,isTrain):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt,phase,isTrain)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            #shuffle = opt.shuffle if isTrain else not opt.shuffle,
            #shuffle = False, 
            shuffle= isTrain,
            num_workers=int(opt.nThreads))
        print('-----------------dataloader------------------')
        #print(self.dataset)
    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data

CreateDataset函数实现如下:通过opt.dataset_mode参数选择数据集

def CreateDataset(opt,phase,isTrain):
    dataset = None
    if opt.dataset_mode == 'c3_Dataset':
        from Data.Dataset_gather import c3_Dataset
        dataset = c3_Dataset()
    elif opt.dataset_mode == 'seq_Dataset':
        from Data.Dataset_gather import seq_Dataset
        dataset = seq_Dataset()
    elif opt.dataset_mode == 'cells_Dataset':
        from Data.Dataset_gather import cells_Dataset
        dataset = cells_Dataset()
    elif opt.dataset_mode == 'clustered_cells_Dataset':
        from Data.Dataset_gather import clustered_cells_Dataset
        dataset = clustered_cells_Dataset()
    elif opt.dataset_mode == 'other_test_dataset':
        from Data.Dataset_gather import other_test_dataset
        dataset = other_test_dataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt,phase,isTrain)
    return dataset

以cells_Dataset数据为例,先下载对应的train.txt和test.txt以及单张图片定位的txt。调用get_transform_函数,定义数据处理过程,这个函数中的处理过程在自动调用__getitem__函数时会自动进行,如对图片进行剪裁、缩放等。__getitem__函数在

for i, data in enumerate(dataset_train):

循环中会在每一次迭代时自动调用,返回的data即为return的数据
for i, data in enumerate(dataset_train):

class cells_Dataset(BaseDataset):
    def initialize(self, opt ,phase ,isTrain):
        self.opt = opt
        self.root = opt.coderoot
        self.transform_flag = True
        str_train = '/label/cell_'+ str(opt.num_outputs) +'_train.txt'
        str_test = '/label/cell_'+ str(opt.num_outputs) +'_test.txt'
        str_localiza = '/label/cell_'+ str(opt.num_outputs) +'.txt'
        if(phase == 'train'):
            split_file = self.root + str_train
#            split_file.replace(''\'',''/'')
            isTrain = True
        elif(phase == 'test'):
            split_file = self.root + str_test
            isTrain = False
        else:
            #isTrain = True
            self.transform_flag = False
            split_file = self.root + str_localiza
        self.path = numpy.loadtxt(split_file, dtype=str, delimiter=' ', skiprows=0, usecols=(0))
        #self.path = [os.path.join(self.opt.dataroot, path) for path in self.path]
        self.path = [(self.opt.dataroot + path) for path in self.path]
        self.lane= numpy.loadtxt(split_file, dtype=float, delimiter=' ', skiprows=0, usecols=(1))
        self.cell= numpy.loadtxt(split_file, dtype=float, delimiter=' ', skiprows=0, usecols=(2))
        self.mean_image = numpy.load(os.path.join(self.opt.dataroot , 'mean_image.npy'))#下载中值文件
        self.size = len(self.path)
        print('len(self.path):{:}'.format(self.size))
        self.transform = get_transform_(opt,self.mean_image,self.transform_flag)#定义数据处理过程
        #self.num_outputs = opt.num_outputs
    def __getitem__(self, index):
        path = self.path[index % self.size]
        A_img = Image.open(path).convert('RGB')
        #A_img.save('pic/'+path[-9:])
        #print('************')
        cell = self.cell[index % self.size]
        lane = self.lane[index % self.size]
        img = self.transform(A_img)
        return {'img': img, 'cell': cell,
                'path': path,'lane':lane}

    def __len__(self):
        return self.size

    def name(self):
        return 'cells_Dataset'

get_transform_函数的定义如下:使用lambda表达式将函数打包到transforms,在每次执行上面的__getitem__函数时,这些lambda表达式封装的函数都会对每张图片进行处理。

def get_transform_(opt, mean_image,isTrain = True):
    transform_list = []
    transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC))
    transform_list.append(transforms.Lambda(lambda img: __subtract_mean(img, mean_image)))
    transform_list.append(transforms.Lambda(lambda img: __crop_image(img, opt.fineSize, isTrain)))
    transform_list.append(transforms.Lambda(lambda img: __to_tensor(img)))
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)

def __subtract_mean(img, mean_image):
    if mean_image is None:
        return numpy.array(img).astype('float')
    return numpy.array(img).astype('float') - mean_image.astype('float')

def __crop_image(img, size, isTrain):
    h, w = img.shape[0:2]
    # w, h = img.size
    if isTrain:
        if w == size and h == size:
            return img
        x = numpy.random.randint(0, w - size)
        y = numpy.random.randint(0, h - size)
    else:
        x = int(round((w - size) / 2.))
        y = int(round((h - size) / 2.))
    return img[y:y+size, x:x+size, :]
    # return img.crop((x, y, x + size, y + size))

def __to_tensor(img):
    return torch.from_numpy(img.transpose((2, 0, 1)))

Model模块

model = create_model(opt)

create_model创建model,根据 opt.model参数创建用于特征训练、定位训练和车道分类的网络

def create_model(opt,istest = False):
    model = None
    print(opt.model)
    if opt.model == 'RESNET18'
        from .model_set import RESNET18Model
        model = RESNET18Model():  #训练特征提取网络、定位网络
    elif opt.model == 'RESNET18_CELL':   
        from .model_set import RESNET18Model_CELL
        model = RESNET18Model_CELL() :#训练分层定位网络
    elif opt.model == 'RESNET18_3':   
        from .model_set import RESNET18Model_3
        model = RESNET18Model_3() #训练车道分类网络
    else:
        raise ValueError("Model [%s] not recognized." % opt.model)
    model.initialize(opt, istest)
    #print("model [%s] was created" % (model.name()))
    return model
    def save_network(self, network, network_label, epoch):
        save_filename = '%s_net_%s.pth' % (network_label, epoch)
        save_path = os.path.join(self.save_dir, '%s_%s'%(self.opt.dataset_mode,self.opt.num_outputs))
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        save_path = os.path.join(save_path,save_filename)
        torch.save(network.state_dict(), save_path)

以RESNET18Model类为例,讲解Model类的功能。RESNET18Model类重写了BaseModel类,BaseModel类中有个重要的函数实现,即save_network函数,在base_model文件中。

class RESNET18Model(BaseModel):
    def name(self):
        return 'RESNET18'
    def initialize(self, opt,isTest = False):#调用resnet网络结构;定义优化方法为SGD;定义训练策略lr_scheduler.StepLR
        self.opt = opt
        BaseModel.initialize(self, opt)
        self.isTrain = not isTest
        self.net = networks.RESNET18(opt.num_outputs,isTest)#调用net为networks模块下的RESNET18网络
        if self.isTrain:
            self.old_lr = opt.lr
            self.criterion = torch.nn.CrossEntropyLoss()    #定义损失函数为交叉熵函数       
            self.optimizers = []
            self.optimizer_A = torch.optim.SGD(self.net.parameters() , lr = opt.lr , momentum = 0.9)#定义优化方法为SGD
            self.optimizers.append(self.optimizer_A)
            self.schedulers = lr_scheduler.StepLR(self.optimizer_A, step_size=10, gamma=0.9)#定义训练策略,每10个epoch学习率×0.9
    def set_input(self, input):#设置输出图像
        self.input_img = input['img']
        self.cell = input['cell']
        self.image_paths = input['path']        
    def forward(self):#推理函数
        self.input_img = Variable(self.input_img.float().cuda())
        [self.features,self.pred] = self.net(self.input_img)
        Z = F.softmax(self.pred,dim=1)#获得softmax输出
        _ , self.preds_= torch.max(Z, 1)#获得softmax输出中概率最大的类
    def extract_features(self):
        f = deepcopy(self.features.data.cpu().numpy())#提取特征
        return f
    def testnet(self):#测试,只推理不backward
        self.forward()
        
    def trainnet(self):#训练
        self.optimize_parameters()
    def get_pred_result(self):
        return self.preds_
    def get_image_paths(self):
        return self.image_paths
    def backward(self):#反向传播
        self.loss = self.criterion(self.pred,self.cell.long().cuda())
        self.loss.backward()
    def optimize_parameters(self):#训练优化
        self.forward()
        self.optimizer_A.zero_grad()
        self.backward()
        self.optimizer_A.step()
    def get_current_acc(self,opt):#得到每个batch的正确率
        self.cell = self.cell.long().cuda()
        self.running_corrects = int(torch.sum(self.preds_ == self.cell.data))
        return self.running_corrects    
    def get_current_loss(self,opt):#得到损失值
        self.loss = self.criterion(self.pred,self.cell.long().cuda())
        return float(self.loss)
    def save(self, epoch):
        self.save_network(self.net, 'RESNET18', epoch)
    def forward_singlepic(self):#单张图片推理
        self.input_img = Variable(self.input_img.float().cuda())
        [self.features,self.pred] = self.net(self.input_img)
        Z = F.softmax(self.pred,dim=1)
        _ , self.preds_= torch.max(Z, 1)
        return self.preds_

networks模块定义了各个网络的结构,其中class RESNET18(torch.nn.Module):是继承了torch.nn.Module类,重写了初始化函数 init和前向传播函数forward,在网络喂入图片数据后自动调用forward函数。 torch.load 返回的是一个 OrderedDict。关于模型和权重下载以及权重保存格式等,可以阅读这个博客。self.model.eval()将网络调到测试模式,测试模式时对ropout和batch normalization层的操作在训练和测试的时候是不一样的,具体讲解看这个博客

class RESNET18(torch.nn.Module):
    """Constructs a ResNet-18 model.
    """
    def __init__(self, num_output, isTest=False,  gpu_ids=[]):
        super(RESNET18, self).__init__()        
        self.model_name = 'resnet18'
        self.gpu_ids = gpu_ids
        state_dict = (torch.load('C:/code_p/Checkpoints/RESNET18/clustered_cells_Dataset_631/RESNET18_net_068.pth'))#预训练权重,其数据结构是每个键对应一个层
        self.model = ResNet(BasicBlock, [2, 2, 2, 2], num_output)#定义ResNet的具体网络结构
        pretrained = True
        if pretrained:
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[6:] # remove `module.`
                new_state_dict[name] = v
                #print(v.size())
            self.model.load_state_dict(new_state_dict,strict = True) #将权重下载到模型中,以模型各层的名字为准,名字不对应则报错,如果strict = False,名字不对应则直接略过。
        if isTest:
            self.model.eval()#在测试模式下
        self.model.eval()
        self.model = self.model.cuda()
        print(self.model)
    def forward(self, x):#前向传播函数
        out = self.model(x)
        return out 

还有网络结构的具体实现,这部分为官方对resnet18的实现源码。这里不讲解,可以去网上搜一下资料。

训练函数

以训练特征提取网络为例,先创建数据,然后创建模型,在每次epoch中进行一次训练和一次测试。

def train_extract_features(num_outputs):
    opt = BaseOptions().parse()
    train_str = 'train'
    test_str = 'test'
    dataset_train = CreateDataLoader(opt,train_str,isTrain = True)#创建训练数据
    dataset_test = CreateDataLoader(opt,test_str,isTrain = False)#创建测试数据
    dataset_size_train = len(dataset_train)
    dataset_size_test = len(dataset_test)
    model = create_model(opt)
    Loss_list_train = []
    Loss_list_test = []
    Accuracy_list_train = []
    Accuracy_list_test = []
    for epoch in range(opt.num_epochs):#epoch 
        epoch_acc_train = 0
        epoch_acc_test = 0
        epoch_loss_train = 0
        epoch_loss_test = 0
        print('Training...')
        for i, data in enumerate(dataset_train): #iter
            #print('[%04d/%04d] ' % (i, len(dataset_train)/opt.batchSize), end='\r')
            model.set_input(data)#输入数据
            model.trainnet()#训练网络
            running_corrects = model.get_current_acc(opt)
            running_loss = model.get_current_loss(opt)
            epoch_acc_train = running_corrects + epoch_acc_train
            epoch_loss_train = running_loss + epoch_loss_train
            Loss_list_train.append(running_loss)
            data_batch_size = len(data['cell'])
            Accuracy_list_train.append(running_corrects/data_batch_size)
            #print(running_loss)
            #print(running_corrects)
            print('[%04d/%04d] ------------------  corrects: %04f-------------------' % (i, len(dataset_train)/opt.batchSize,epoch_acc_train/(i+1)/data_batch_size), end='\r')
        epoch_loss_train = epoch_loss_train/(i+1)
        epoch_acc_train = epoch_acc_train*100/dataset_size_train
        print(' Train epoch {:}:---- lr:{:} ----Acc: {:.4f}%  loss:{:.4f}' .format(epoch,opt.lr,epoch_acc_train,epoch_loss_train))
        print('Test...')
        for i, data in enumerate(dataset_test):
            model.set_input(data)
            istest = True
            model.testnet()
            running_corrects = model.get_current_acc(opt)
            running_loss = model.get_current_loss(opt)
            epoch_acc_test = running_corrects + epoch_acc_test
            epoch_loss_test = running_loss + epoch_loss_test
            Loss_list_test.append(running_loss)
            data_batch_size = len(data['cell'])
            Accuracy_list_test.append(running_corrects/data_batch_size)
            print('[%04d/%04d] ------------------  corrects: %04f-------------------' % (i, len(dataset_test)/opt.batchSize,epoch_acc_test/(i+1)/data_batch_size), end='\r')
        epoch_loss_test = epoch_loss_test/(i+1)
        epoch_acc_test = epoch_acc_test*100/dataset_size_test
        print(' Test epoch {:}:---- lr:{:} ----Acc: {:.4f}%  loss:{:.4f}' .format(epoch,opt.lr,epoch_acc_test,epoch_loss_test))
        model.save(epoch)
        if(epoch%1 == 0):
            numpy.save('D:/figure/Loss_list_train_600_.npy',Loss_list_train)#保存loss
            numpy.save('D:/figure/Loss_list_test_600_.npy',Loss_list_test)
            numpy.save('D:/figure/Accuracy_list_train_600_.npy',Accuracy_list_train)
            numpy.save('D:/figure/Accuracy_list_test_600_.npy',Accuracy_list_test)
        if((epoch_acc_train>99.9)&(epoch_acc_test>99.9)):
            break

GEN_CELL模块

GEN_CELL模块是对聚类后的cell进行后处理生成最终cell的模块

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 28 10:19:06 2018

@author: zs
"""
import numpy 
import os
from matplotlib import pyplot as plt
class CELL():
    def __init__(self,label,k,outline_range,outline_range_threshold,kkk,opt):
        ''''
        outline_range、outline_range_threshold分别是统计范围和小Cell的阈值,小于这个阈值则合并,kkk*outline_range是需要处理的范围,kk是个系数。
        ''''
        self.root = opt.coderoot
        self.outline_range = int(outline_range)
        self.outline_range_threshold = outline_range_threshold
        self.num = k
        self.kkk = kkk
        self.new_index = numpy.zeros(k)
        self.index = label
        self.dataset_size = len(label)
        self.new_label = numpy.zeros(label.shape)
        str_path = 'label/seq'+ str(opt.num_outputs) +'.txt'
        split_file = os.path.join(self.root , str_path)
        self.path = numpy.loadtxt(split_file, dtype=str, delimiter=' ', skiprows=0, usecols=(0))
        str_path = 'label/3.txt'
        split_file = os.path.join(self.root , str_path)
        self.lines = numpy.loadtxt(split_file, dtype=int, delimiter=' ', skiprows=0, usecols=(1))
        print(self.new_label.shape)
    def idx_transformation(self):#将聚类形成无序label转换成以图像序列为准的有序label
        n = 0
        for i in range(0,self.dataset_size):
            if(self.new_index[self.index[i]] == 0):
                self.new_index[self.index[i]] = n
                n+=1
        for i in range(0,self.dataset_size):
            self.new_label[i] = int(self.new_index[self.index[i]])-1
            #print(self.new_label[i])
        #self.lll = self.new_label.copy()
        
        save_dir = 'C:/code_p/label/new_label%d.txt'%(self.num)
        numpy.savetxt(save_dir,self.new_label,fmt='%d')
        self.label_removed = self.new_label.copy()
    def check_same_cell_differentlane(self):#检查小的cell和包含不同车道图像的cell
        lane_cells_count = numpy.zeros(3,dtype=numpy.int)
        for j in range(len(self.lines)):
            lane_cells_count[self.lines[j]] += 1
        lane_cells_count[1] = lane_cells_count[1]+lane_cells_count[0]
        print(lane_cells_count)
        save_dir = 'C:/code_p/label/ttt.txt'
        f =open(save_dir,'w')  
        for j in range(len(self.new_label)):
            text =str(self.new_label[j])                                 
            f.write(text)
            f.write('\n')
        f.close
        for i in range(len(lane_cells_count)-1):
            if (self.new_label[lane_cells_count[i]] == self.new_label[lane_cells_count[i]+1]):
                for j in range(lane_cells_count[i]+1,len(self.new_label)):
                    self.new_label[j] += 1        
        
        
    def remove_smallcell(self):#移除小的cell
        print('remove small cell...')
        kkk = self.kkk
        for i in range(0,len(self.new_label)-self.outline_range):
            #print(i)
            global point_sample
            point_sample = numpy.zeros(int(2000),dtype=numpy.int)#保存一定范围内各类cell的数量,因为我们的cell数量这里不超过2000,因此长度设为2000,保证不会超出
            #ll_ = numpy.zeros(int(2000),dtype=numpy.int)
            #print(int(self.outline_range/2))
            #last = self.label_removed[i]
            #nn = 0
            #ll_[0] = self.label_removed[i]
            for j in range(0,self.outline_range):
#                if (self.label_removed[i+j] != last):
#                    nn += 1
#                    ll_[nn] = self.label_removed[i+j]
                    
                point_sample[int(self.label_removed[i+j])] += 1#在[i,i+outline_range]范围内统计每类标签的数量
                #last = self.label_removed[i+j]
            for ii in range(0,2000):
                flag_remove_once = 0
                if(point_sample[ii]<=self.outline_range*self.outline_range_threshold)and(point_sample[ii]>0):   #如果此类cell数量不是0并且小于阈值                 
                    for jj in range(int(self.outline_range/2-self.outline_range*kkk/2),int(self.outline_range/2+self.outline_range*kkk/2)):#对统计处理范围内的小cell进行合并
                        #print(ll.shape())
                        #print(ll_)
                        if (self.label_removed[i+jj] == ii):
                            print('remove%d'%(self.new_label[i+jj]))
                            self.label_removed[i+jj] = self.label_removed[i+jj-1]
                            flag_remove_once = 1
                            print(point_sample)
                            for iii in range(len(point_sample)):
                                if(point_sample[iii] >0):
                                    print(point_sample[iii])
                            #print(ll_)
#                        plt.bar(range(i+0,i+self.outline_range),self.label_removed[i:i+self.outline_range],width = 1)
#                        plt.show()
                if flag_remove_once:
                    i -= 1#保证可以处理交叉的CELL
                    break
        #for i in range(0,len(self.new_label)):
            #print(self.new_label[i],'---',self.lll[i]) 
            #if(abs(self.new_label[i]-self.lll[i])>0.1):
                #print('remove label: %d'%(self.lll[i]))
        save_dir = 'C:/code_p/label/label_removed%d.txt'%(self.num)
        numpy.savetxt(save_dir,self.label_removed,fmt='%d')                  
    def checkandsort_cell(self):#对重复出现的大cell赋予新的标号
#        for i in range(0,len(self.new_label)):
#            #print(self.new_label[i],'---',self.lll[i]) 
#            if(abs(self.new_label[i]-self.lll[i])>0.1):
#                print('remove label')#: %d'%(self.lll[i]))
        print('check and sort cell...')
        cell= numpy.zeros(len(self.label_removed),dtype=numpy.int)
        last = 0
        class_plus = 0
        reco = numpy.zeros(1000,dtype=numpy.int)
        self.count = 0
        for i in range(0,len(self.label_removed)):
            n = int(self.label_removed[i])
            if(reco[n]> 0)and(abs(n - last)>0.1):
                #print('check repeated cell: %d:'%(n))
                self.count += 1
            if(abs(n-last)>0.1):
                class_plus += 1
            #print(class_plus)
            cell[i] = class_plus
            last = n
            reco[n] += 1
        print('check repeated  %d cells'%(self.count))
        return cell
    def cou_cell(self):#统计处理前后的cell数量变化
        nn = int(self.new_label[-1]+1)
        print(nn)
        cm = numpy.zeros(nn,dtype=numpy.int)
        for i in range(0,len(self.new_label)):
            cm[int(self.new_label[i])-1] += 1
        print('before remove max: %d'%(max(cm)))
        print('before remove min: %d'%(min(cm)))
        #print(cm)
        cm = numpy.zeros(nn,dtype=numpy.int)
        for i in range(0,len(self.label_removed)):
            cm[int(self.label_removed[i])-1] += 1
        print('after remove max: %d'%(max(cm)))
        print('after remove min: %d'%(min(cm)))
        #print(cm)
    def gen_cell(self):
        self.idx_transformation()
        #self.check_same_cell_differentlane()
        self.remove_smallcell()
        
        self.cou_cell()
        cell = self.checkandsort_cell()
        save_dir = 'C:/code_p/label/cell_%d.txt'%(cell[-1]+1)
        self.path
        lane_cells_count = numpy.zeros(3,dtype=numpy.int)
        f =open(save_dir,'w')  
        for j in range(len(cell)):
            lane_cells_count[self.lines[j]] += 1
            text = self.path[j]+' '+str(self.lines[j])+' ' +str(cell[j])                                 
            f.write(text)
            f.write('\n')
        f.close
        print(lane_cells_count)
        lane_branch_start_cell = [0,cell[lane_cells_count[0]],cell[lane_cells_count[0]+lane_cells_count[1]]]
        print(lane_branch_start_cell)
        lane_cells_cla = [cell[lane_cells_count[0]-1]+1,cell[lane_cells_count[1]+lane_cells_count[0]-1]-cell[lane_cells_count[0]]+1,cell[lane_cells_count[2]+lane_cells_count[1]+lane_cells_count[0]-1]-cell[lane_cells_count[1]+lane_cells_count[0]]+1]
        #lane_cells_cla = [lane_cells_cla[0],lane_cells_cla[1]-]
        
        numpy.save('C:/code_p/label/lane_cells_cla_900.npy',lane_cells_cla)
        numpy.save('C:/code_p/label/lane_branch_start_cell_900.npy',lane_branch_start_cell)

        return [cell[-1]+1,lane_cells_cla]

论文第四章是分层定位网络相关的内容,主要根据浅层的特征对数据进行预分类,这应该不会是你的下一步重点,如果你想了解,相关程序主要在以下两个类:
model_set 文件内的

class RESNET18Model_CELL(BaseModel):

networks文件内的

class ResNet_redefined_cell(nn.Module):
上一篇下一篇

猜你喜欢

热点阅读