ptorch top5实现

2020-05-06  本文已影响0人  vieo

参考1
参考2
参数:
def topk(self, k, key=None, split_every=None):
input (Tensor) – 输入张量
k (int) – “top-k”中的k
dim (int, optional) – 排序的维
largest (bool, optional) – 布尔值,控制返回最大或最小值
sorted (bool, optional) – 布尔值,控制返回值是否排序
out (tuple, optional) – 可选输出张量 (Tensor, LongTensor) output buffer

def evaluteTop1(model, loader):
    model.eval()

    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        #correct += torch.eq(pred, y).sum().item()
    return correct / total

def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1,5))
            y_resize = y.view(-1,1)
             _  , pred = logits.topk(maxk, 1, True, True)
            correct += torch.eq(pred, y_resize).sum().float().item()
     return correct / total
上一篇下一篇

猜你喜欢

热点阅读