目标检测从入门到入土

SR_ESRGAN运行指南

2019-05-27  本文已影响0人  zestloveheart

介绍

ESRGAN是一个较新的的低分辨率转高分辨率的GAN模型,在SRGAN的基础上做了增强。
其论文在ESRGAN论文
其代码在ESRGAN仓库,该仓库只提供了简单的demo测试代码。完整的训练和测试代码在BasicSR仓库中。
如果要进一步学习,给出2篇论文综述作为参考:
综述1
综述2

初次运行ESRGAN

  1. 安装环境
    conda install numpy
    pip install opencv-python==3.4.5.20
    conda install python-lmdb
    pip install tensorboardX
    
    # 进入 https://pytorch.org/get-started/locally/ 找到安装pytorch合适的指令。我这里原来是Linux conda python3.6 CUDA10
    conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
    # 然而由于conda镜像没了,需要用pip了
    pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
    pip3 install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp37-cp37m-linux_x86_64.whl
    
    
  2. 拉代码git clone https://github.com/xinntao/ESRGAN.git
  1. 下载模型到models中
    https://pan.baidu.com/s/1-Lh6ma-wXzfH8NqeBtPaFQ

  2. 运行下面的代码,结果在result中。
    python test.py models/RRDB_ESRGAN_x4.pth
    python test.py models/RRDB_PSNR_x4.pth

初次使用BasicSR测试ESRGAN(SRGAN)模型

  1. 拉代码 git clone https://github.com/xinntao/BasicSR.git
  2. 进入codes文件夹cd codes
  3. 修改 options/test/test_ESRGAN.json
    1. datasets dataroot_HR 将后面路径改为自己的训练数据文件夹,文件夹内存放的是png文件;或者改为lmdb文件。
    2. path root 改为自己的BasicSR项目路径
    3. 将刚刚在ESRGAN中用到的model放到pretrain_model_G的目录下面。
    4. 其他暂时不用动,我本机配置如下所示。
    {
        "name": "RRDB_ESRGAN_x4"
        , "suffix": "_ESRGAN"
        , "model": "srragan"
        , "scale": 4
        , "gpu_ids": [0]
    
        , "datasets": {
            "test_1": { // the 1st test dataset
            "name": "set5"
            , "mode": "LRHR"
            , "dataroot_HR": "/root/addition_store/DIV2K_train_HR"
            }
        }
    
        , "path": {
            "root": "/home/student_docker/zlh/BasicSR"
            , "pretrain_model_G": "../experiments/pretrained_models/RRDB_ESRGAN_x4.pth"
        }
    
        , "network_G": {
            "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
            , "norm_type": null
            , "mode": "CNA"
            , "nf": 64
            , "nb": 23
            , "in_nc": 3
            , "out_nc": 3
    
            , "gc": 32
            , "group": 1
        }
    }
    
  4. 运行测试代码 python test.py -opt options/test/test_ESRGAN.json
  5. 如果需要跑其他的测试代码,见其他测试
测试成功

训练ESRGAN(SRGAN)模型

准备数据(DIV2K)

  1. DIV2K official page或者百度云下载
  2. 有几个方法可以让IO速度变快
    1. 将HDD改成SSD
    2. 将图片数据集改成更小的子图切片(sub-images)。见3和4
    3. 将原始数据改成lmdb格式。见5和6
  3. 修改codes/scripts/extract_subimgs_single.py文件的路径
    input_folder = '/root/addition_store/DIV2K_train_HR' # 输入图片路径
    save_folder = '/root/addition_store/DIV2K_train_HR_sub' # 输出图片路径
    
  4. 运行 python scripts/extract_subimgs_single.py 执行切片操作
  5. 修改codes/scripts/create_lmdb.py
    img_folder = '/root/addition_store/DIV2K_train_HR_sub/*'  # glob matching pattern
    lmdb_save_path = '/root/addition_store/DIV2K_train_HR_sub.lmdb'  # must end with .lmdb
    mode = 2
    
  6. 运行 python scripts/create_lmdb.py 将数据改成lmdb格式

训练

  1. 修改options/train/train_ESRGAN.json

    "name": "002_RRDB_ESRGAN_x4_DIV2K"
    "train" "dataroot_HR": "/root/addition_store/DIV2K_train_HR_sub.lmdb"
    "val" "dataroot_HR": "/root/addition_store/DIV2K_valid_HR"
    "path" "root": "/home/student_docker/zlh/BasicSR"
    
  2. 运行 python train.py -opt options/train/train_ESRGAN.json

    运行成功
  3. tensorboard可视化 tensorboard --logdir=../tb_logger
    进入http://localhost.localdomain:6006可看到训练过程

    tensorboard可视化

参考

ESRGAN仓库
BasicSR仓库
代码结构介绍

上一篇下一篇

猜你喜欢

热点阅读