Using torch.nn.functional.grid_sample
2020-03-17
Daniel开峰
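FlowNet2 computes the optical flow between two consecutive frames of a video. The function below uses that flow to warp the previous frame.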
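import torch

# get_grid is a helper defined elsewhere; it is not shown in this post.
def resample(self, image, flow):
    '''
    Warp the previous frame with the optical flow.
    image: previous frame, torch.Size([1, 3, 256, 256])
    flow: optical flow in pixels, torch.Size([1, 2, 256, 256])
    final_grid: sampling grid, torch.Size([1, 256, 256, 2]) after permute
    '''
    b, c, h, w = image.size()
    # Base grid, assumed to hold normalized coordinates in [-1, 1], shape (b, 2, h, w)
    grid = get_grid(b, h, w, gpu_id=flow.get_device())
    # Convert pixel displacements to normalized units: x by (w - 1) / 2, y by (h - 1) / 2
    flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
    # grid_sample expects the grid as (N, H, W, 2), so move the coordinate axis last
    final_grid = (grid + flow).permute(0, 2, 3, 1).cuda(image.get_device())
    # align_corners=True matches the (w - 1) / 2 scaling above; newer PyTorch versions default to False
    output = torch.nn.functional.grid_sample(image, final_grid, mode='bilinear', padding_mode='border', align_corners=True)
    return output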
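The post does not show `get_grid`, so the following is only a minimal sketch of how the pieces might fit together. It assumes `get_grid` returns a `(b, 2, h, w)` tensor of normalized coordinates in [-1, 1], with x in channel 0 and y in channel 1. The `device` parameter, the standalone (non-method) `resample`, and the toy 4x4 input are illustrative assumptions, not the author's code.

```python
import torch
import torch.nn.functional as F


def get_grid(batch, rows, cols, device="cpu"):
    """Hypothetical stand-in for the post's get_grid helper.

    Returns a (batch, 2, rows, cols) tensor: channel 0 is x, channel 1 is y,
    both normalized to [-1, 1].
    """
    xs = torch.linspace(-1.0, 1.0, cols, device=device)
    ys = torch.linspace(-1.0, 1.0, rows, device=device)
    xs = xs.view(1, 1, 1, cols).expand(batch, 1, rows, cols)
    ys = ys.view(1, 1, rows, 1).expand(batch, 1, rows, cols)
    return torch.cat([xs, ys], dim=1)


def resample(image, flow):
    """Warp image with a pixel-unit flow field, following the post's steps."""
    b, c, h, w = image.size()
    grid = get_grid(b, h, w, device=image.device)
    # Pixel displacements -> normalized units (consistent with align_corners=True)
    flow = torch.cat(
        [flow[:, 0:1] / ((w - 1.0) / 2.0), flow[:, 1:2] / ((h - 1.0) / 2.0)],
        dim=1,
    )
    final_grid = (grid + flow).permute(0, 2, 3, 1)  # (b, h, w, 2)
    return F.grid_sample(
        image, final_grid, mode="bilinear", padding_mode="border", align_corners=True
    )


# Toy input: a single-channel 4x4 "frame" with values 0..15
image = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

# Zero flow should reproduce the input
zero_flow = torch.zeros(1, 2, 4, 4)
identity = resample(image, zero_flow)

# A flow of +1 in x samples each output pixel from its right-hand neighbor
shift_flow = torch.zeros(1, 2, 4, 4)
shift_flow[:, 0] = 1.0
shifted = resample(image, shift_flow)
```

With zero flow the output matches the input. With a flow of +1 in x, each pixel takes the value of its right-hand neighbor, and the last column repeats itself because `padding_mode='border'` clamps out-of-range coordinates to the edge.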
References:
1. Crop pooling
2. What is the equivalent of torch.nn.functional.grid_sample in Tensorflow / Numpy?