(8)将 int list 转成 one-hot形式

2018-12-04  本文已影响0人  顽皮的石头7788121

代码如下:

>>>class_num = 10

>>>batch_size = 4

>>>label = torch.LongTensor(batch_size, 1).random_() % class_num

 3

 0

 0

 8

>>>one_hot = torch.zoros(batch_size, class_num).scatter_(1, label, 1)

 0     0     0     1     0     0     0     0     0     0

 1     0     0     0     0     0     0     0     0     0

 1     0     0     0     0     0     0     0     0     0

 0     0     0     0     0     0     0     0     1     0


    以10类别分类为例,lable=[3] 和label=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]是一致的.

    Tensor.scatter_(dim, index, value)

    从 value 中拿值,然后根据 dim 和 index 给自己的相应位置填上值

上一篇 下一篇

猜你喜欢

热点阅读