BCEWithLogitsLoss参数weight

2021-10-07  本文已影响0人  三方斜阳
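1. weight: 逐元素的损失权重，形状与 target 相同（或可广播到 target 的形状）。每个元素的二元交叉熵先乘以对应权重，再按默认的 reduction='mean' 取均值：$loss = \frac{1}{N}\sum_{i,j} -w_{ij}\,[\,y_{ij}\log\sigma(x_{ij}) + (1-y_{ij})\log(1-\sigma(x_{ij}))\,]$。下面用代码验证这一计算方式：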
import torch
import torch.nn as nn
# Demo: show that the `weight` argument of nn.BCEWithLogitsLoss is an
# element-wise multiplier on the per-element binary cross-entropy, by
# comparing the library result with a manual re-computation.

# Raw logits (scores before the sigmoid).
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# so any later code that refers to it keeps working.
input = torch.tensor([[-0.4089, -1.2471, 0.5907],
                      [-0.4897, -0.8267, -0.7349],
                      [0.5241, -0.1246, -0.4751]])
# Sigmoid probabilities, needed only for the manual check further down;
# the library loss is fed the raw logits.
m = nn.Sigmoid()
S_input = m(input)

# Binary (0/1) targets with the same shape as `input`.
target = torch.FloatTensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]])

w = [0.1, 0.9]  # weights for label 0 and label 1
# Weight matrix: look up w[label] for every entry of `target` in one indexing
# operation instead of a nested Python loop over single elements.
weight = torch.tensor(w)[target.long()]
print(weight)

# Library result: weighted BCE computed from the logits.
BCEWithLogitsLoss = nn.BCEWithLogitsLoss(weight=weight)
loss = BCEWithLogitsLoss(input, target)
print(loss)

# Manual check: weighted binary cross-entropy on the sigmoid probabilities,
# evaluated for all elements at once and then summed.
per_element = -weight * (
    target * torch.log(S_input) + (1 - target) * torch.log(1 - S_input)
)
loss = per_element.sum()
# Divide by the element count to get the mean, which matches the value
# printed for the library loss above.
print(loss / S_input.numel())

输出：
tensor([[0.1000, 0.9000, 0.9000],
        [0.1000, 0.1000, 0.9000],
        [0.9000, 0.1000, 0.9000]])
tensor(0.4711)
tensor(0.4711)
上一篇 下一篇

猜你喜欢

热点阅读