UserWarning: Using a target size

2023-06-12  本文已影响0人  小黄不头秃

报错内容:
UserWarning: Using a target size (torch.Size([1, 224, 224])) that is different to the input size (torch.Size([1, 1, 224, 224])) is deprecated. Please ensure they have the same size.
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)

报错原因:是因为在相关的函数中两个矩阵的维度不一样所导致的。

解决办法:
使用torch.unsqueeze()或者torch.squeeze()进行升降维。
例如:

    net = UNet().to(device)
    net.train()
    loss_fn = nn.BCELoss()

    for i,(img,target) in enumerate(train_loader):
        img, target = img.to(device), target.to(device)
        y = net(img)
        loss = loss_fn(y, target.unsqueeze(dim=0))
上一篇 下一篇

猜你喜欢

热点阅读