Broadcast 自动扩展

2020-01-12  本文已影响0人  残剑天下论

关键点

Broadcast操作可以理解为unsqueeze()和expand()的组合操作,它有两个关键点:

若a、b的维度如下
a: (4, 32, 14, 14)
b: (32, 1, 1)
则b需要扩展为b -> (1, 32, 1, 1)  # 在最前面增加一个维度,且维度值为1

然后扩展到将维度值为1的维度扩展到指定值,
则b由(1, 32, 1, 1) -> (4, 32, 14, 14)

图示过程

第一行的两个Tensor形状都为(4, 3),不需要broadcasting,直接相加即可;
第二行的两个Tensor形状分别为(4, 3)、(1, 3),二者都是两个维度,不需要增加维度;只需要将后者维度值为1的维度扩展到指定值,则有扩展成(4, 3),然后相加
第三行的两个Tensor形状分别为(4, 1)、(1, 3),二者都是两个维度,不需要增加维度;只需要将两个Tensor的维度值为1的维度扩展,都被扩展成(4, 3)

为什么需要Broadcast

a = torch.randint(60, 90, size=(4, 32, 8))  # shape: [4, 32, 8]
b = torch.tensor([5.0])  # shape: [1]

# 如果不使用Broadcast技术,则需要
b = b.unsqueeze(0).unsqueeze(0)
b = b.expand(4, 32, 8)
a + b

# 如果使用Broadcast技术,只需要
a + b

Broadcast需要满足一定条件

a -> [4, 32, 14, 14]
b ->        [14, 14] 
   -> [1, 1, 14, 14]   
   -> [4, 32, 14, 14]
a -> [4, 32, 14, 14]
b -> [1, 32, 1, 1] -> [4, 32, 14, 14]
# a b张量维数相等,都为4,也没有维度值为1的维度,无法扩增,因此a + b时会报错
a -> [4, 32, 14, 14]
b -> [2, 32, 14, 14]

总的来说,Broadcast只能进行: (1)在最前面插入新维度(2)扩增维度值为1的维度

在处理图片中的用途

一个批量图片或者特征图可表示为[4, 3, 32, 32],表示4张图片,每张图片3个通道,宽高为32。

上一篇 下一篇

猜你喜欢

热点阅读