人工智能/模式识别/机器学习精华专题深度学习·神经网络·计算机视觉大数据,机器学习,人工智能

CVPR 2017 Feedback-Network 的 pyt

2018-05-21  本文已影响12人  Meng_Blog

项目地址

我的github地址

目的

根据 Feedback-Network (CVPR 2017, Zamir et al.) 论文提出的反馈网络结构,对CIFAR100或类似数据集进行分类。当前实现了CIFAR100数据集上的训练和测试,基本达到论文效果。

结果

Requirements

步骤

使用Pytorch为工具,实现CIFAR100数据集分类。

方法

训练代码流程

  1. Hyper-params: 设置数据加载路径、模型保存路径、初始学习率等参数。
  2. Training parameters: 用于定义模型训练中的相关参数,例如最大迭代次数、优化器、损失函数、是否使用GPU等、模型保存频率等
  3. load data: 定义了用于读取数据,在其中实现了数据、标签读取及预处理过程。预处理过程在__getitem__中。
  4. models: 定义的FeedbackNet类,并实例化
  5. optimizer、criterion、lr_scheduler: 定义优化器为SGD优化器,损失函数为CrossEntropyLoss,学习率调整策略采用ReduceLROnPlateau。
  6. trainer: 定义了用于模型训练和验证的类Trainer,trainer为Trainer的实例化。在Trainer的构造函数中根据步骤二中的参数设定,对训练过程中的参数进行设置,包括训练数据、测试数据、模型、是否使用GPU等。
    Trainer中定义了训练和测试函数,分别为train()_val_one_epoch()train()函数中,根据设定的最大循环次数进行训练,每次循环调用_train_one_epoch()函数进行单步训练。

测试代码流程

  1. Test parameters: 用于定义模型测试中的相关参数
  2. models: 定义的FeedbackNet类,并实例化
  3. tester: 对测试类Tester实例化,Tester中主要进行模型加载函数与预测函数。
    _load_ckpt()函数加载模型;
    test()函数进行预测,其中定义了对单张图片进行预处理的过程,并输出预测结果。

参考

上一篇下一篇

猜你喜欢

热点阅读