Code Implementation (5): Relation Network for Few-Shot Learning
2019-08-06
续袁
1. Requirements
(1) PyTorch 1.0
(2) Python 3.6
(3) numpy
(4) scipy
(5) matplotlib
(6) torchvision
(7) PIL (on newer Python versions, install Pillow instead)
1.2 Installing torchvision
conda install torchvision -c pytorch
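Once the packages are installed, a short script can confirm that each one is importable before you run the training scripts. This is a minimal sketch that uses only the standard library's importlib.util.find_spec. It checks that each package is present, not which version is installed, and the helper name check_packages is made up for illustration. Note that Pillow is imported under the name PIL.

```python
import importlib.util

# Import names of the packages listed above (Pillow is imported as "PIL")
REQUIRED = ["torch", "torchvision", "numpy", "scipy", "matplotlib", "PIL"]


def check_packages(names):
    """Hypothetical helper: map each top-level module name to whether it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages(REQUIRED).items():
        print(f"{name:12s} {'ok' if found else 'MISSING'}")
```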
2. Running the code
2.1 Problem: the original code runs on the GPU; switch it to the CPU
# Step 1: comment out the following two lines
# feature_encoder.cuda(GPU)
# relation_network.cuda(GPU)
# Step 2: pass map_location='cpu' to torch.load. Without it, loading the GPU-saved models fails with:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.
# Fix: add map_location='cpu' to both torch.load calls
encoder_path = "./models/omniglot_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
if os.path.exists(encoder_path):
    feature_encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))
    print("load feature encoder success")
relation_path = "./models/omniglot_relation_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
if os.path.exists(relation_path):
    relation_network.load_state_dict(torch.load(relation_path, map_location='cpu'))
    print("load relation network success")
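Instead of commenting out .cuda(GPU) and adding map_location='cpu' by hand, you can make the loading step device-agnostic so that the same script runs on both GPU and CPU machines. The sketch below is an illustration under stated assumptions, not the repository's code. load_if_exists is a hypothetical helper, and a small nn.Linear stands in for feature_encoder and relation_network. It relies only on standard PyTorch calls: torch.device, torch.cuda.is_available, torch.load with map_location, and Module.to.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_if_exists(model, path, device):
    """Hypothetical helper: load a state_dict into `model` if `path` exists.

    map_location remaps tensors saved on a CUDA device onto `device`,
    so checkpoints written on a GPU machine also load on a CPU-only one.
    Returns True when a checkpoint was loaded.
    """
    if not os.path.exists(path):
        return False
    model.load_state_dict(torch.load(path, map_location=device))
    return True


# Round trip with a stand-in model (in the repo: feature_encoder / relation_network)
model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "demo.pkl")
    torch.save(model.state_dict(), ckpt)

    restored = nn.Linear(4, 2).to(device)  # used instead of restored.cuda(GPU)
    loaded = load_if_exists(restored, ckpt, device)
    missing = load_if_exists(restored, os.path.join(tmp, "missing.pkl"), device)

print("device:", device, "loaded:", loaded, "missing checkpoint loaded:", missing)
```

In the training scripts, the corresponding change would be to replace feature_encoder.cuda(GPU) with feature_encoder.to(device), and do the same for relation_network and the tensors fed to them. This is a suggestion and has not been tested against the repository.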
2.2 Problem: KeyError: '..\datas\omniglot_resized'
Cause: Linux and Windows use different path separators.
Fix: in get_class, change the split separator from '/' to '\\':
def get_class(self, sample):
    return os.path.join(*sample.split('\\')[:-1])
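You can reproduce this failure on any operating system, because Python ships both path flavours: os.path is ntpath on Windows and posixpath on Linux/macOS. The sketch makes one assumption, suggested by the '..\datas\omniglot_resized' prefix in the error: class folders are built with os.path.join from a base such as '../datas/omniglot_resized/'. The alphabet, character, and file names are made up. The sketch also shows a hypothetical portable alternative, dirname. It keeps the class folder intact on both systems. The '\\' version works only on Windows: on Linux the split yields a single piece, and os.path.join() is then called with no arguments.

```python
import ntpath      # the os.path implementation used on Windows
import posixpath   # the os.path implementation used on Linux/macOS


def get_class_split(sample, sep, pathmod):
    # Same logic as get_class: split on a hard-coded separator,
    # drop the file name, and re-join the remaining parts
    return pathmod.join(*sample.split(sep)[:-1])


def get_class_dirname(sample, pathmod):
    # Hypothetical portable variant: dirname understands the platform's separators
    return pathmod.dirname(sample)


base = "../datas/omniglot_resized/"

# Windows: os.path.join inserts backslashes after the base folder
win_class = ntpath.join(base, "Alphabet_A", "character01")
win_sample = ntpath.join(win_class, "0001.png")

# Linux: everything is joined with forward slashes
posix_class = posixpath.join(base, "Alphabet_A", "character01")
posix_sample = posixpath.join(posix_class, "0001.png")

print(get_class_split(win_sample, "/", ntpath))    # ..\datas\omniglot_resized  (the bad key)
print(get_class_split(win_sample, "\\", ntpath))   # ../datas/omniglot_resized/Alphabet_A\character01
print(get_class_dirname(win_sample, ntpath))       # same as the line above
print(get_class_dirname(posix_sample, posixpath))  # ../datas/omniglot_resized/Alphabet_A/character01
```

The dirname variant works because, for a class folder c that does not end in a separator, os.path.dirname(os.path.join(c, filename)) returns c itself. In the repository, that would mean writing return os.path.dirname(sample) in get_class. Treat this as an untested suggestion.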
2.3 Problem: error message:
File "/LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 193, in main
torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1))
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'
Fix: add batch_labels = batch_labels.long() before that line.
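scatter_ uses its index argument to choose which positions to fill. PyTorch requires that index to be an int64 ("Long") tensor, so it rejects int32 labels. One plausible reason the labels arrive as int32 is that NumPy's default integer type on Windows was 32-bit before NumPy 2.0. That cause is an assumption and has not been verified in this repository. The sketch uses made-up sizes and labels to build the same one-hot matrix as the line in the traceback. It also checks the result against torch.nn.functional.one_hot, which expects int64 labels too.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes, not the repository's settings
CLASS_NUM = 5
BATCH_NUM_PER_CLASS = 2

# Labels arriving as int32 ("scalar type Int"), as in the error above
batch_labels = torch.tensor([0, 1, 2, 3, 4] * BATCH_NUM_PER_CLASS, dtype=torch.int32)

# scatter_ requires an int64 ("Long") index tensor, so convert first
batch_labels = batch_labels.long()

# Row i gets a 1 in column batch_labels[i]: a one-hot target matrix
one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(
    1, batch_labels.view(-1, 1), 1
)

# Built-in equivalent (also needs int64 labels), cast to float to match
same = torch.equal(one_hot_labels, F.one_hot(batch_labels, CLASS_NUM).float())
print(one_hot_labels.shape, same)  # torch.Size([10, 5]) True
```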
2.4 Problem: IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
Error message:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
As the message suggests, change the line to:
print("episode:", episode + 1, "loss", loss.item())
and the error goes away.
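Why .item() is needed: in older PyTorch (before 0.4), a loss was a one-element tensor, so loss.data[0] returned its value. A loss with the default mean reduction is now a 0-dim (scalar) tensor, which has no dimension to index. The toy values below are made up, and nn.MSELoss is used because it is the loss the Relation Network paper trains with.

```python
import torch
import torch.nn as nn

# Toy prediction/target pair (made-up values); default reduction is "mean"
mse = nn.MSELoss()
loss = mse(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0]))

print(loss.dim())    # 0: a scalar tensor has no dimension to index
print(loss.item())   # 2.0, i.e. ((1 - 1)**2 + (2 - 4)**2) / 2

try:
    loss.data[0]     # pre-0.4 idiom; raises IndexError on a 0-dim tensor
except IndexError as err:
    print("IndexError:", err)
```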
2.5 Problem: error message:
File "C:/Users/xpb/PycharmProjects/LearningToCompare_FSL-master/omniglot/omniglot_train_one_shot.py", line 268, in <listcomp>
rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(CLASS_NUM)]
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'
Fix: add the following lines before that one:
predict_labels = predict_labels.long()
test_labels = test_labels.long()
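The == inside the list comprehension compares two tensors. On the PyTorch 1.0 setup described here, both operands had to have the same dtype, so the fix casts both sides to int64. Newer PyTorch releases generally promote mixed integer types in comparisons, so the cast may no longer be required there, but it does no harm. The sketch assumes predictions are an argmax over relation scores via torch.max(..., 1). The scores and labels are hypothetical. It also shows a vectorized way to count correct predictions.

```python
import torch

CLASS_NUM = 5  # illustrative

# Hypothetical relation scores: 5 query images (rows) against 5 classes (columns)
relations = torch.tensor([
    [0.9, 0.1, 0.0, 0.0, 0.0],
    [0.2, 0.7, 0.1, 0.0, 0.0],
    [0.0, 0.1, 0.3, 0.6, 0.0],
    [0.0, 0.0, 0.1, 0.8, 0.1],
    [0.0, 0.0, 0.0, 0.2, 0.8],
])
test_labels = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int32)  # int32, as in the error

# Predicted class = column with the highest relation score (returned as int64 indices)
_, predict_labels = torch.max(relations, 1)

# Cast both sides to int64 ("Long") before comparing
predict_labels = predict_labels.long()
test_labels = test_labels.long()

rewards = [1 if predict_labels[j] == test_labels[j] else 0 for j in range(CLASS_NUM)]
total_rewards = int((predict_labels == test_labels).sum())  # same as sum(rewards)

print(rewards)        # [1, 1, 0, 1, 1]
print(total_rewards)  # 4
```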
3. Code walkthrough
Code
[1] floodsung/LearningToCompare_FSL
[2] prolearner/LearningToCompareTF
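References
[1] Introduction to the torchvision library (translation)
[2] Pytorch: torchvision, the computer vision toolkit
[3] Python---python3.7.0---How to install PIL
[4] Python image processing: a detailed introduction to each PIL module
Troubleshooting
[0] Learning to Compare: Relation Network source code debugging
[1] On slashes in file paths when reading files in Python
[2] Converting backslashes in a path to '/' in Python
[3] Joining paths in Python: how to use os.path.join()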
[4] RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1
[5] GPU computation in Pytorch (CUDA)
Paper
[1] Learning to Compare: Relation Network for Few-Shot Learning