
torch.flatten vs torch.nn.Flatten

2023-03-05  LabVIEW_Python

torch.flatten and torch.nn.Flatten both flatten a multi-dimensional Tensor. The difference lies in their signatures, in particular the default `start_dim`:

torch.flatten(input, start_dim=0, end_dim=-1)

class torch.nn.Flatten(start_dim=1, end_dim=-1)

A test program:

import torch

input_tensor = torch.randn(32, 4, 5, 5)
m = torch.nn.Flatten()  # instantiate Flatten (start_dim defaults to 1)
output1 = m(input_tensor)
print(output1.shape)    # batch dimension is preserved
output2 = torch.flatten(input_tensor)  # start_dim defaults to 0
print(output2.shape)    # the whole tensor is flattened

The output:

torch.Size([32, 100])
torch.Size([3200])

Because nn.Flatten defaults to start_dim=1, it keeps the batch dimension (32) and flattens only the remaining dimensions (4×5×5 = 100). torch.flatten defaults to start_dim=0, so it flattens the entire tensor into 32×100 = 3200 elements.
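The two APIs produce identical results once `start_dim` is made explicit. A minimal sketch of this equivalence (variable names are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

x = torch.randn(32, 4, 5, 5)

# nn.Flatten() with its default start_dim=1 is the same operation as
# torch.flatten with start_dim=1: both keep the batch dimension.
assert torch.equal(nn.Flatten()(x), torch.flatten(x, start_dim=1))

# torch.flatten with its default start_dim=0 flattens everything.
assert torch.flatten(x).shape == (32 * 4 * 5 * 5,)
```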

In addition, torch.nn.Flatten is well suited for use as a layer inside a neural network. Example:

# Method from an nn.Module subclass; assumes `import torch.nn as nn`
def _create_fcs(self, split_size, num_boxes, num_classes):
    S, B, C = split_size, num_boxes, num_classes
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(1024 * S * S, 4096),
        # Usually, dropout is placed on the fully connected layers only.
        # A rule of thumb is to set the keep probability (1 - drop probability)
        # to 0.5 when dropout is applied to fully connected layers.
        # https://stackoverflow.com/questions/46841362/where-dropout-should-be-inserted-fully-connected-layer-convolutional-layer
        nn.Dropout(0.5),
        nn.LeakyReLU(0.1),
        # The predictions are encoded as an S × S × (B ∗ 5 + C) tensor
        nn.Linear(4096, S * S * (B * 5 + C)),  # 7*7*(2*5+20) = 1470
    )
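To see the head in action, here is a minimal sketch (the `Head` wrapper class and the dummy input are assumptions for illustration, not part of the original model) that runs the Sequential above on a YOLO-style (N, 1024, 7, 7) feature map:

```python
import torch
import torch.nn as nn

class Head(nn.Module):  # hypothetical wrapper for the head above
    def __init__(self, split_size=7, num_boxes=2, num_classes=20):
        super().__init__()
        S, B, C = split_size, num_boxes, num_classes
        self.fcs = nn.Sequential(
            nn.Flatten(),                           # (N, 1024, S, S) -> (N, 1024*S*S)
            nn.Linear(1024 * S * S, 4096),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, S * S * (B * 5 + C)),   # 7*7*(2*5+20) = 1470
        )

    def forward(self, x):
        return self.fcs(x)

head = Head()
out = head(torch.randn(2, 1024, 7, 7))  # dummy backbone feature map
print(out.shape)  # torch.Size([2, 1470])
```

Because nn.Flatten keeps the batch dimension, the head maps each sample's feature map to one 1470-element prediction vector.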