pytorch : CrossEntropyLoss 应用于语

2018-11-10  本文已影响496人  月牙眼的楼下小黑

作 者: 月牙眼的楼下小黑
联 系: zhanglf_tmac (Wechat)
声 明: 欢迎转载本文中的图片或文字,请说明出处


pytorch 官方文档中对 CrossEntropyLoss()的介绍,会产生一种错觉: pytorch中的CrossEntropyLoss似乎无法应用于多类别的图像语义分割任务。

其实: pytorch中的CrossEntropyLoss 是可以直接应用于语义分割任务的。

我们不妨假设一个分割网络的输出形状为: (channel = 3, width = 2, height = 2) ,即 2 x 2 分辨率的图像,其中每个像素可能属于 {0,1,2} 三类中的其中一类。

import torch
from torch import nn
from torch.autograd import Variable

input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)

loss = nn.CrossEntropyLoss()
print('loss: ', loss(input, target))
input: Variable containing:
(0 ,0 ,.,.) = 
  1  1
  1  1

(0 ,1 ,.,.) = 
  1  1
  1  1

(0 ,2 ,.,.) = 
  1  1
  1  1
[torch.FloatTensor of size 1x3x2x2]

target: Variable containing:
(0 ,.,.) = 
  0  1
  1  0
[torch.LongTensor of size 1x2x2]

loss:  Variable containing:
 1.0986
[torch.FloatTensor of size 1]

我们讨论一下两个细节:

问题1: 输出的 loss 形状为什么是 1x 1

默认情况下,即 size_average = True, loss 会在每个 mini-batch(小批量) 上取平均值. 如果字段 size_average 被设置为 False, loss 将会在每个 mini-batch(小批量) 上累加, 而不会取平均值.

那么这个 mini_batch_size 等于几呢? 在程序中,网络输出形状为 4-d Tensor: ( batch_size, channel, width, height)。 注意: mini_batch_size != batch_size, 而是: mini_batch_size = batch_size * width * height.

这非常好理解,因为语义分割本质上是 pixel-level classification, 所以 mini_batch_size 就等于一个 batch 图像中的 像素总数

我们可以将上面代码中 loss 参数 size_average 设为 False , 做个简单的验证:

import torch
from torch import nn

input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)

loss = nn.CrossEntropyLoss(size_average=False)
print('loss', loss(input, target))

此时输出的 loss 值为: 4.3944, 正好是 1.09861 x 2 x 2 倍。

问题2:如何得到每个 pixel 的 loss ?

只需将loss 参数 reduce 设为 False 即可。若网络输出形状为 4-d Tensor: ( batch_size, channel, width, height), 此时 loss 函数会返回一个 3-d Tensor:batch_size, width, height), 每个元素对应一个 pixelloss 值。

import torch
from torch import nn

input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))

loss = nn.CrossEntropyLoss(reduce=False)
print('loss: ', loss(input, target))
loss: Variable containing:
(0 ,.,.) = 
  1.0986  1.0986
  1.0986  1.0986
[torch.FloatTensor of size 1x2x2]
上一篇下一篇

猜你喜欢

热点阅读