Broadcast 自动扩展
2020-01-12 本文已影响0人
残剑天下论
关键点
Broadcast操作可以理解为unsqueeze()和expand()的组合操作,它有两个关键点:
-
在最前面插入维度,维度值为1
-
扩展维度值为1的维度到指定维度值
若a、b的维度如下
a: (4, 32, 14, 14)
b: (32, 1, 1)
则b需要扩展为b -> (1, 32, 1, 1) # 在最前面增加一个维度,且维度值为1
然后扩展到将维度值为1的维度扩展到指定值,
则b由(1, 32, 1, 1) -> (4, 32, 14, 14)
图示过程
第一行的两个Tensor形状都为(4, 3),不需要broadcasting,直接相加即可;
第二行的两个Tensor形状分别为(4, 3)、(1, 3),二者都是两个维度,不需要增加维度;只需要将后者维度值为1的维度扩展到指定值,则有扩展成(4, 3),然后相加
第三行的两个Tensor形状分别为(4, 1)、(1, 3),二者都是两个维度,不需要增加维度;只需要将两个Tensor的维度值为1的维度扩展,都被扩展成(4, 3)
为什么需要Broadcast
- 具有实际应用
例如:a = torch.randint(60, 90, size=(4, 32, 8))表示学生的成绩,size=(4, 32, 8)表示[classes, students, scores],即a表示4个班级的学生成绩,每个班级有32个学生,每个学生都有8门课程的成绩,学生的成绩都在60分到90分之间。现在我需要将所有学生的成绩全部加5分。这个问题可以表示为:
a = torch.randint(60, 90, size=(4, 32, 8)) # shape: [4, 32, 8]
b = torch.tensor([5.0]) # shape: [1]
# 如果不使用Broadcast技术,则需要
b = b.unsqueeze(0).unsqueeze(0)
b = b.expand(4, 32, 8)
a + b
# 如果使用Broadcast技术,只需要
a + b
- 节省内存
只需要存储[5.0]一个数即可,不必存储[4, 32, 8]这么大一个矩阵。
Broadcast需要满足一定条件
- 两个张量维数不一致时,右对齐,然后在维数少的张量最前面插入维度
例如
a -> [4, 32, 14, 14]
b -> [14, 14]
-> [1, 1, 14, 14]
-> [4, 32, 14, 14]
- 两个张量维数相等时,只需将维度值为1的维度扩增
a -> [4, 32, 14, 14]
b -> [1, 32, 1, 1] -> [4, 32, 14, 14]
- 不能扩增的情况
# a b张量维数相等,都为4,也没有维度值为1的维度,无法扩增,因此a + b时会报错
a -> [4, 32, 14, 14]
b -> [2, 32, 14, 14]
总的来说,Broadcast只能进行: (1)在最前面插入新维度(2)扩增维度值为1的维度
在处理图片中的用途
一个批量图片或者特征图可表示为[4, 3, 32, 32],表示4张图片,每张图片3个通道,宽高为32。
- 为每个通道增加基准
[4, 3, 32, 32] + [32, 32] - 为每个图片的三个通道增加不同的值,例如三个通道各自的均值
[4, 3, 32, 32] + [3, 1, 1] - 为每张图片增加一个亮度值
[4, 3, 32, 32] + [1, 1, 1, 1]
上述这一段写的不甚清楚,可能有误,自行理解。