
Code Implementation (5): Relation Network for Few-Shot Learning

2019-08-06  续袁

1. Environment requirements

(1) PyTorch 1.0
(2) Python 3.6
(3) numpy
(4) scipy
(5) matplotlib
(6) torchvision
(7) PIL (on newer Python versions, install Pillow instead)

1.2 Installing torchvision

conda install torchvision -c pytorch

2. Running the code

2.1 Problem: the original program runs on the GPU; change it to run on the CPU

# Step 1: comment out the following two lines
# feature_encoder.cuda(GPU)
# relation_network.cuda(GPU)
# Step 2: add the argument map_location='cpu' to torch.load. Without it, loading the saved models fails with:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.
# Fix: pass map_location='cpu' to torch.load
    if os.path.exists(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
        feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location='cpu'))
        print("load feature encoder success")
    if os.path.exists(str("./models/omniglot_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
        relation_network.load_state_dict(torch.load(str("./models/omniglot_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location='cpu'))
        print("load relation network success")
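Instead of commenting out the .cuda() calls, a device-agnostic variant can let the same script run on both GPU and CPU machines. The following is a minimal sketch, not code from the original repository; `load_if_present` is a hypothetical helper name, and it assumes the .pkl files hold state_dicts saved with torch.save, as the loading code above implies.

```python
import os

import torch
import torch.nn as nn

# Use the GPU if one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_if_present(model: nn.Module, path: str) -> bool:
    """Hypothetical helper: move `model` to `device` and load a saved state_dict if the file exists."""
    model.to(device)
    if not os.path.exists(path):
        return False
    # map_location remaps tensors that were saved on a CUDA device onto `device`,
    # which avoids the "Attempting to deserialize object on a CUDA device" error.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return True
```

With this approach, `feature_encoder.cuda(GPU)` would become `feature_encoder.to(device)`, and each `torch.load(...)` call would receive `map_location=device` rather than the hard-coded 'cpu'.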

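2.2 Problem: KeyError: '..\datas\omniglot_resized'

Cause: Linux and Windows use different path separators ('/' vs '\').
Fix: in get_class, split on '\\' instead of '/'. Note that this makes the code Windows-specific.
def get_class(self, sample):
    return os.path.join(*sample.split('\\')[:-1])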
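Hard-coding either separator ties the code to one operating system. A more portable sketch, assuming the class folder is simply the parent directory of each sample path (which is what the split-and-join above computes), lets os.path handle the separator. `get_class` is shown here as a plain function rather than a method, and `build_label_map` is a hypothetical helper; the keys of the label dictionary must be normalized the same way so that lookups match.

```python
import os


def get_class(sample: str) -> str:
    # The class is the directory that contains the sample image.
    # os.path.dirname splits on the current OS's separators
    # ('/' on Linux/macOS, both '/' and '\\' on Windows);
    # normpath then rewrites the result with the native separator.
    return os.path.normpath(os.path.dirname(sample))


def build_label_map(class_folders):
    # Hypothetical helper: normalize the folder keys the same way,
    # so that get_class(sample) always finds a matching key.
    return {os.path.normpath(folder): label for label, folder in enumerate(class_folders)}
```

A lookup then reads `build_label_map(class_folders)[get_class(path)]` on either operating system.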

2.3 Problem: the following error is raised:

File "/LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 193, in main
    torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1))
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

Fix: add this line before the call, since scatter_ requires its index tensor to be of type Long (int64): batch_labels = batch_labels.long()
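The cast is needed because `scatter_` only accepts an int64 (Long) index tensor, while the labels here arrive as int32. A minimal sketch of the one-hot construction used on that line, with the cast folded in (`one_hot` is a hypothetical helper name, not part of the original script):

```python
import torch


def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # scatter_ requires an int64 (Long) index, so cast int32 labels first.
    index = labels.long().view(-1, 1)
    # Write a 1 into column labels[i] of row i; all other entries stay 0.
    return torch.zeros(index.size(0), num_classes).scatter_(1, index, 1)
```

PyTorch 1.1 and later also provide torch.nn.functional.one_hot, which likewise expects an int64 tensor but returns int64 values, so it would need a .float() before being compared with the float relation scores.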

2.4 Problem: IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

Error message:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
    print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

As the message suggests, change it to
print("episode:", episode + 1, "loss", loss.item())
and the error goes away.
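Since PyTorch 0.4 a reduced loss is a 0-dim tensor, so `.item()` is the way to read it as a Python number. The same applies when averaging losses over several episodes; a short sketch (`mean_loss` is a hypothetical helper, not from the original script):

```python
import torch


def mean_loss(losses):
    total = 0.0
    for loss in losses:
        # .item() turns a 0-dim tensor into a Python float.
        # Accumulating floats instead of tensors also avoids keeping
        # each step's autograd graph alive until the end of the loop.
        total += loss.item()
    return total / len(losses)
```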

2.5 Problem: comparing Int and Long tensors

  File "C:/Users/xpb/PycharmProjects/LearningToCompare_FSL-master/omniglot/omniglot_train_one_shot.py", line 268, in <listcomp>
    rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(CLASS_NUM)]
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

Fix: add these two lines before it:
predict_labels = predict_labels.long()
test_labels = test_labels.long()
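Once both tensors share a dtype, the list comprehension can also be replaced with a single elementwise comparison. A minimal sketch, assuming `predict_labels` and `test_labels` are 1-D tensors of equal length as in the one-shot script (`episode_rewards` is a hypothetical helper):

```python
import torch


def episode_rewards(predict_labels: torch.Tensor, test_labels: torch.Tensor) -> list:
    # Cast both sides to int64 so that == compares tensors of the same type.
    pred = predict_labels.long()
    target = test_labels.long()
    # pred == target gives a bool tensor (uint8 in older versions); .long() maps it to 0/1.
    return (pred == target).long().tolist()
```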

3. Code walkthrough

Code

[1] floodsung/LearningToCompare_FSL
[2] prolearner/LearningToCompareTF

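References

[1] An introduction to the torchvision library (translation)
[2] PyTorch: torchvision, the computer vision toolkit
[3] Python: how to install PIL on Python 3.7.0
[4] A detailed guide to the modules of PIL for image processing in Python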

Troubleshooting

[1] Learning to Compare: Relation Network source code debugging
[2] Slashes in file paths when reading files in Python
[3] Converting backslashes '\' in paths to '/' in Python
[4] Joining paths in Python: using os.path.join()
[5] RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1
[6] GPU computation in PyTorch (CUDA)

Paper

[1] Learning to Compare: Relation Network for Few-Shot Learning
