tf.nn.in_top_k/tf.nn.top_k

2018-07-16  本文已影响0人  yalesaleng

tf.nn.in_top_k

correct = tf.nn.in_top_k(logits, labels, k)
其中:

logits: a tensor  of shape [batch_size, NUM_CLASSES]
labels: a tensor of shape [batch_size]

理解:

  • 1.对于logits的某行logits[i],
    找到其前k个最大的预测值的index_0, .., index_k-1,
    如果发现对应的labels[i]在{index_0, …, index_k-1},
    则返回True.(大致这个意思)
  • 2.当k=1时,等价于tf.equal(logits, labels)。
    但是,equal()函数中的logits和labels的shape必须一样。
    因此,通过read_data_sets(…, one_hot=False,…)读取数据时,
    必须使得one_hot=Ture(默认为False).

下面为源码中对此的解释:
This outputs a batch_size bool array, an entry out[i] is true if the prediction for the target class is among the top k predictions among all predictions for example i.


tf.nn.top_k

tf.nn.top_k(input, k=1, sorted=sorted, name)
在top_k()函数中,返回的是两个值:

  • values: input中最后一维部分的前k个最大的值。
  • indices:与values中各个值对应的索引。
    (sorted默认为1, 即根据输出values进行降序输出)。
上一篇下一篇

猜你喜欢

热点阅读